Class StratifiedBatchSampler
A batch sampler that aims to include samples from every class in each batch.
public class StratifiedBatchSampler : DataSamplerBase, IBatchSampler, IStratifiedSampler, IDataSampler
- Inheritance
  DataSamplerBase → StratifiedBatchSampler
- Implements
  IBatchSampler, IStratifiedSampler, IDataSampler
Remarks
StratifiedBatchSampler creates batches in which each batch has approximately the same class distribution as the full dataset. This is useful for models with batch normalization layers, or whenever batch-level statistics matter.
For Beginners: While StratifiedSampler ensures that the epoch as a whole has the right class balance, StratifiedBatchSampler ensures that EACH BATCH has balanced classes. This is helpful when:
- Using batch normalization (which needs balanced statistics per batch)
- Doing contrastive learning (which needs diverse samples in each batch)
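The page does not document the internal algorithm. As a minimal sketch of one common way to get per-batch stratification (not necessarily what StratifiedBatchSampler does), each sample can be bucketed by class, shuffled within its bucket, and given a fractional position within its class; sorting all samples by that position interleaves the classes evenly, and the result is then cut into consecutive batches. The names `StratifiedBatchSketch` and `MakeBatches` are hypothetical, and labels are assumed to lie in `[0, numClasses)`.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical sketch, NOT the library's implementation: stratified batching
// by interleaving classes according to each sample's fractional position.
public static class StratifiedBatchSketch
{
    public static List<int[]> MakeBatches(IReadOnlyList<int> labels, int numClasses,
                                          int batchSize, bool dropLast, int seed)
    {
        if (batchSize <= 0) throw new ArgumentOutOfRangeException(nameof(batchSize));
        var rng = new Random(seed);

        // 1. Bucket sample indices by class label (assumed 0-indexed and < numClasses).
        var buckets = new List<int>[numClasses];
        for (int c = 0; c < numClasses; c++) buckets[c] = new List<int>();
        for (int i = 0; i < labels.Count; i++) buckets[labels[i]].Add(i);

        // 2. Shuffle each bucket, then key every sample by (k + 0.5) / classSize,
        //    which spreads each class uniformly over [0, 1).
        var keyed = new List<(double Key, int Index)>();
        foreach (var bucket in buckets)
        {
            Shuffle(bucket, rng);
            for (int k = 0; k < bucket.Count; k++)
                keyed.Add(((k + 0.5) / bucket.Count, bucket[k]));
        }

        // 3. A stable sort by key interleaves the classes across the whole epoch.
        int[] order = keyed.OrderBy(p => p.Key).Select(p => p.Index).ToArray();

        // 4. Cut the interleaved order into consecutive batches.
        var batches = new List<int[]>();
        for (int start = 0; start < order.Length; start += batchSize)
        {
            int size = Math.Min(batchSize, order.Length - start);
            if (size < batchSize && dropLast) break; // skip the incomplete final batch
            var batch = new int[size];
            Array.Copy(order, start, batch, 0, size);
            batches.Add(batch);
        }
        return batches;
    }

    // Fisher-Yates shuffle in place.
    private static void Shuffle(List<int> list, Random rng)
    {
        for (int i = list.Count - 1; i > 0; i--)
        {
            int j = rng.Next(i + 1);
            (list[i], list[j]) = (list[j], list[i]);
        }
    }
}
```

Because the sort key depends only on a sample's rank within its own class, a class holding 25% of the data lands at roughly every fourth position, so any window of consecutive indices (and therefore any batch) approximates the dataset's class proportions.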
Constructors
StratifiedBatchSampler(IEnumerable<int>, int, int, bool, int?)
Initializes a new instance of the StratifiedBatchSampler class.
public StratifiedBatchSampler(IEnumerable<int> labels, int numClasses, int batchSize, bool dropLast = false, int? seed = null)
Parameters
labels IEnumerable<int>
The class label for each sample (0-indexed).
numClasses int
The total number of classes.
batchSize int
The number of samples in each batch.
dropLast bool
Whether to drop the last batch if it is incomplete (has fewer than batchSize samples).
seed int?
Optional random seed for reproducibility.
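Because `labels` must be 0-indexed integers, raw class names usually need encoding first. The sketch below assumes the labels start out as strings; `LabelEncoding` and `Encode` are hypothetical helpers, not part of the library. It assigns ids in order of first appearance and reports the number of distinct classes found.

```csharp
using System.Collections.Generic;

// Hypothetical helper (not part of the library): maps class names to the
// 0-indexed integer labels that the constructor's `labels` parameter expects.
public static class LabelEncoding
{
    public static (int[] Labels, int NumClasses) Encode(IEnumerable<string> classNames)
    {
        var ids = new Dictionary<string, int>();
        var labels = new List<int>();
        foreach (var name in classNames)
        {
            // Assign the next free id the first time a class name is seen.
            if (!ids.TryGetValue(name, out int id))
            {
                id = ids.Count;
                ids[name] = id;
            }
            labels.Add(id);
        }
        return (labels.ToArray(), ids.Count);
    }
}
```

The result could then be passed to the documented constructor, for example `new StratifiedBatchSampler(encoded.Labels, encoded.NumClasses, batchSize: 32, seed: 42)`.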
Properties
BatchSize
Gets or sets the batch size for batch-level sampling.
public int BatchSize { get; set; }
Property Value
int
DropLast
Gets or sets whether to drop the last incomplete batch.
public bool DropLast { get; set; }
Property Value
bool
Labels
Gets or sets the class labels for each sample.
public IReadOnlyList<int> Labels { get; set; }
Property Value
IReadOnlyList<int>
Length
Gets the total number of samples this sampler will produce per epoch.
public override int Length { get; }
Property Value
int
Remarks
This may differ from the dataset size when oversampling or undersampling strategies are used.
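The page does not say how DropLast affects Length. As a sketch under the assumption of PyTorch-style semantics (every sample is used once per epoch, and DropLast discards only the final incomplete batch), the batch and sample counts work out as follows. `EpochSize`, `BatchCount`, and `SampleCount` are hypothetical names.

```csharp
using System;

// Hypothetical sketch: ASSUMES each sample is used once per epoch and that
// dropLast discards only the final incomplete batch.
public static class EpochSize
{
    public static int BatchCount(int numSamples, int batchSize, bool dropLast)
    {
        if (batchSize <= 0) throw new ArgumentOutOfRangeException(nameof(batchSize));
        // Floor division when dropping the remainder, ceiling division otherwise.
        return dropLast ? numSamples / batchSize : (numSamples + batchSize - 1) / batchSize;
    }

    public static int SampleCount(int numSamples, int batchSize, bool dropLast)
    {
        // With dropLast, only full batches contribute samples.
        return dropLast ? BatchCount(numSamples, batchSize, true) * batchSize : numSamples;
    }
}
```

For example, 100 samples with a batch size of 32 would give 4 batches (100 samples) without DropLast and 3 batches (96 samples) with it.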
NumClasses
Gets the number of unique classes.
public int NumClasses { get; }
Property Value
int
Methods
GetBatchIndices()
Returns an enumerable of index arrays, where each array represents one batch.
public IEnumerable<int[]> GetBatchIndices()
Returns
- IEnumerable<int[]>
An enumerable of batch index arrays.
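To check how well stratification holds in practice, the index arrays returned by GetBatchIndices() can be turned into per-batch class counts. `BatchDiagnostics.ClassCounts` below is a hypothetical helper, not a library method; it assumes every label is in `[0, numClasses)`.

```csharp
using System.Collections.Generic;

// Hypothetical diagnostic (not part of the library): counts the samples of each
// class in one batch of indices, e.g. an int[] yielded by GetBatchIndices().
public static class BatchDiagnostics
{
    public static int[] ClassCounts(int[] batch, IReadOnlyList<int> labels, int numClasses)
    {
        var counts = new int[numClasses];
        foreach (int index in batch)
            counts[labels[index]]++; // assumes labels[index] is in [0, numClasses)
        return counts;
    }
}
```

Assuming the documented `Labels` and `NumClasses` properties reflect the constructor arguments, the helper could be applied to every batch: `foreach (int[] batch in sampler.GetBatchIndices()) Console.WriteLine(string.Join(", ", BatchDiagnostics.ClassCounts(batch, sampler.Labels, sampler.NumClasses)));`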
GetIndicesCore()
Core implementation for generating indices. Override this in derived classes.
protected override IEnumerable<int> GetIndicesCore()
Returns
- IEnumerable<int>
An enumerable of sample indices.
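The page does not state how the per-sample stream relates to the batch view. If, as an assumption, GetIndicesCore() yields the same indices as GetBatchIndices() in batch order, the relationship is a simple flatten, sketched below with the hypothetical `FlattenSketch.Flatten`.

```csharp
using System.Collections.Generic;

// Hypothetical sketch: ASSUMES the per-sample stream is the batch stream
// flattened in order (batch by batch, index by index).
public static class FlattenSketch
{
    public static IEnumerable<int> Flatten(IEnumerable<int[]> batches)
    {
        foreach (int[] batch in batches)
            foreach (int index in batch)
                yield return index; // lazily emit each index in its batch order
    }
}
```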