Class StratifiedSampler
A sampler that ensures each class is represented proportionally in each epoch.
public class StratifiedSampler : DataSamplerBase, IStratifiedSampler, IDataSampler
- Inheritance
- DataSamplerBase
StratifiedSampler
- Implements
- IStratifiedSampler
IDataSampler
Remarks
StratifiedSampler maintains the class distribution from the original dataset during sampling. This is especially important for imbalanced datasets where some classes have many more samples than others.
For Beginners: When your dataset has unequal class sizes (e.g., 90% cats, 10% dogs), random sampling might sometimes produce batches with only cats. Stratified sampling ensures every batch has a similar ratio of cats to dogs as the full dataset.
Example:
// Labels: [0, 0, 0, 1, 1, 2] (3 samples of class 0, 2 of class 1, 1 of class 2)
var labels = new[] { 0, 0, 0, 1, 1, 2 };
var sampler = new StratifiedSampler(labels, numClasses: 3);
// Each epoch maintains the 3:2:1 ratio while shuffling within classes
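The idea behind the example can be sketched without the library. The helper below, StratifiedOrder, is hypothetical and is not the library's implementation: it assumes one plausible strategy, which is to shuffle the indices inside each class and then interleave the classes by fractional position so that the class ratio holds throughout the epoch.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical helper (not part of the library): returns every sample index exactly once,
// shuffled within each class and interleaved so the classes are spread evenly.
static List<int> StratifiedOrder(IReadOnlyList<int> labels, int numClasses, int seed)
{
    var rng = new Random(seed);

    // Bucket sample indices by class label.
    var buckets = new List<int>[numClasses];
    for (int c = 0; c < numClasses; c++) buckets[c] = new List<int>();
    for (int i = 0; i < labels.Count; i++) buckets[labels[i]].Add(i);

    var keyed = new List<(double Key, int Class, int Index)>();
    for (int c = 0; c < numClasses; c++)
    {
        var bucket = buckets[c];

        // Fisher-Yates shuffle within the class.
        for (int i = bucket.Count - 1; i > 0; i--)
        {
            int j = rng.Next(i + 1);
            (bucket[i], bucket[j]) = (bucket[j], bucket[i]);
        }

        // The j-th sample of a class with n members gets position (j + 0.5) / n in [0, 1].
        for (int j = 0; j < bucket.Count; j++)
            keyed.Add(((j + 0.5) / bucket.Count, c, bucket[j]));
    }

    // Sorting by position interleaves the classes proportionally; ties go to the lower class id.
    return keyed.OrderBy(k => k.Key).ThenBy(k => k.Class).Select(k => k.Index).ToList();
}

var labels = new[] { 0, 0, 0, 1, 1, 2 };
var order = StratifiedOrder(labels, numClasses: 3, seed: 42);

// Which index is drawn from each class depends on the seed; the class pattern does not.
var classSequence = order.Select(i => labels[i]).ToArray();
Console.WriteLine(string.Join(", ", classSequence)); // 0, 1, 0, 2, 1, 0
```

In this sketch the seed only decides which sample of a class fills each slot, so two epochs with different seeds visit the classes in the same pattern but with different samples.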
Constructors
StratifiedSampler(IEnumerable<int>, int, int?)
Initializes a new instance of the StratifiedSampler class.
public StratifiedSampler(IEnumerable<int> labels, int numClasses, int? seed = null)
Parameters
labels IEnumerable<int>
The class label for each sample (0-indexed).
numClasses int
The total number of classes.
seed int?
Optional random seed for reproducibility.
Exceptions
- ArgumentNullException
Thrown when labels is null.
- ArgumentOutOfRangeException
Thrown when numClasses is less than 2.
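The documented argument checks and the effect of the optional seed can be illustrated with a small stand-alone sketch. ValidateLabels and CreateRandom are hypothetical helpers written for this illustration; they mirror the behaviour described above but are not taken from the library's source.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical stand-in for the constructor's documented argument checks.
static int[] ValidateLabels(IEnumerable<int> labels, int numClasses)
{
    if (labels is null) throw new ArgumentNullException(nameof(labels));
    if (numClasses < 2) throw new ArgumentOutOfRangeException(nameof(numClasses));
    return labels.ToArray();
}

// Hypothetical: a fixed seed gives a repeatable generator; null falls back to an unseeded one.
static Random CreateRandom(int? seed) => seed.HasValue ? new Random(seed.Value) : new Random();

var labels = ValidateLabels(new[] { 0, 0, 0, 1, 1, 2 }, numClasses: 3);

// Two generators built from the same seed produce the same sequence of draws.
var a = CreateRandom(7);
var b = CreateRandom(7);
bool sameSequence = Enumerable.Range(0, 5).All(_ => a.Next(labels.Length) == b.Next(labels.Length));
Console.WriteLine(sameSequence); // True

// A class count below 2 is rejected, matching the documented exception.
bool rejected = false;
try { ValidateLabels(new[] { 0, 0 }, numClasses: 1); }
catch (ArgumentOutOfRangeException) { rejected = true; }
Console.WriteLine(rejected); // True
```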
Properties
Labels
Gets or sets the class labels for each sample.
public IReadOnlyList<int> Labels { get; set; }
Property Value
- IReadOnlyList<int>
Length
Gets the total number of samples this sampler will produce per epoch.
public override int Length { get; }
Property Value
- int
Remarks
This may differ from the dataset size when an oversampling or undersampling strategy is used.
NumClasses
Gets the number of unique classes.
public int NumClasses { get; }
Property Value
- int
Methods
GetIndicesCore()
Core implementation for generating indices. Override this in derived classes.
protected override IEnumerable<int> GetIndicesCore()
Returns
- IEnumerable<int>
An enumerable of sample indices.
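A caller typically consumes the returned indices in batches. The sketch below shows one way to do that and to inspect the class mix of each batch. The index sequence is a hard-coded stand-in, since the actual order depends on the library and the seed, and Enumerable.Chunk assumes .NET 6 or later.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

var labels = new[] { 0, 0, 0, 1, 1, 2 };

// Stand-in for the sequence a sampler would yield; the real order is not documented here.
IEnumerable<int> indices = new[] { 2, 4, 0, 5, 3, 1 };

// Group the index stream into batches of three and count the classes in each batch.
var perBatchCounts = new List<int[]>();
foreach (int[] batch in indices.Chunk(3))
{
    var counts = new int[3];
    foreach (int i in batch) counts[labels[i]]++;
    perBatchCounts.Add(counts);
}

for (int b = 0; b < perBatchCounts.Count; b++)
    Console.WriteLine($"batch {b}: [{string.Join(", ", perBatchCounts[b])}]");
// batch 0: [2, 1, 0]
// batch 1: [1, 1, 1]
```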