Class StratifiedBatchSampler

Namespace
AiDotNet.Data.Sampling
Assembly
AiDotNet.dll

A batch sampler that ensures each batch contains samples from all classes.

public class StratifiedBatchSampler : DataSamplerBase, IBatchSampler, IStratifiedSampler, IDataSampler
Inheritance
DataSamplerBase
StratifiedBatchSampler
Implements
IBatchSampler
IStratifiedSampler
IDataSampler

Remarks

StratifiedBatchSampler creates batches where each batch has approximately the same class distribution. This is useful for batch normalization layers or when batch-level statistics are important.

For Beginners: While StratifiedSampler ensures the overall epoch has the right class balance, StratifiedBatchSampler ensures EACH BATCH has balanced classes. This is helpful when:

- Using batch normalization (needs balanced statistics per batch)
- Doing contrastive learning (needs diverse samples in each batch)
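To make the idea of per-batch class balance concrete, here is a minimal, self-contained sketch of one way stratified batches could be built. It is not AiDotNet's implementation: the `StratifiedBatchSketch` class and its `MakeBatches` method are hypothetical names introduced only for illustration. The sketch shuffles the indices within each class, spreads every class evenly over the epoch, and then cuts the resulting order into batches.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical illustration only; not the library's actual algorithm.
public static class StratifiedBatchSketch
{
    public static List<int[]> MakeBatches(
        IReadOnlyList<int> labels, int numClasses, int batchSize, int seed)
    {
        var rng = new Random(seed);

        // Group sample indices by their class label.
        var perClass = new List<int>[numClasses];
        for (int c = 0; c < numClasses; c++) perClass[c] = new List<int>();
        for (int i = 0; i < labels.Count; i++) perClass[labels[i]].Add(i);

        // Shuffle within each class (Fisher-Yates) so batches vary with the seed.
        foreach (var list in perClass)
        {
            for (int i = list.Count - 1; i > 0; i--)
            {
                int j = rng.Next(i + 1);
                (list[i], list[j]) = (list[j], list[i]);
            }
        }

        // Give the k-th sample of a class the fractional position (k + 0.5) / classSize.
        // Sorting by that position interleaves the classes in proportion to their size.
        var ordered = perClass
            .SelectMany(list => list.Select((idx, k) => (Index: idx, Pos: (k + 0.5) / list.Count)))
            .OrderBy(t => t.Pos)
            .Select(t => t.Index)
            .ToList();

        // Cut the interleaved order into consecutive batches.
        var batches = new List<int[]>();
        for (int start = 0; start < ordered.Count; start += batchSize)
            batches.Add(ordered.Skip(start).Take(batchSize).ToArray());
        return batches;
    }
}
```

With 8 samples of class 0 and 4 samples of class 1 and a batch size of 6, this sketch produces two batches that each hold 4 samples of class 0 and 2 of class 1, mirroring the 2:1 ratio of the whole dataset.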

Constructors

StratifiedBatchSampler(IEnumerable<int>, int, int, bool, int?)

Initializes a new instance of the StratifiedBatchSampler class.

public StratifiedBatchSampler(IEnumerable<int> labels, int numClasses, int batchSize, bool dropLast = false, int? seed = null)

Parameters

labels IEnumerable<int>

The class label for each sample (0-indexed).

numClasses int

The total number of classes.

batchSize int

The number of samples in each batch.

dropLast bool

Whether to drop the final batch when it contains fewer than batchSize samples.

seed int?

Optional random seed for reproducibility.
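The effect of dropLast on the number of batches can be sketched with plain arithmetic. The helper below is hypothetical (`BatchCountSketch.CountBatches` is not part of AiDotNet) and assumes the conventional meaning of the flag: a trailing batch with fewer than batchSize samples is kept when dropLast is false and discarded when it is true.

```csharp
using System;

// Hypothetical helper for illustration; not part of the AiDotNet API.
public static class BatchCountSketch
{
    public static int CountBatches(int sampleCount, int batchSize, bool dropLast)
    {
        if (batchSize <= 0)
            throw new ArgumentOutOfRangeException(nameof(batchSize));

        int fullBatches = sampleCount / batchSize;        // batches with exactly batchSize samples
        bool hasPartial = sampleCount % batchSize != 0;   // leftover samples form a short batch

        // The short batch counts only when it exists and is not dropped.
        return (hasPartial && !dropLast) ? fullBatches + 1 : fullBatches;
    }
}
```

For example, 10 samples with a batch size of 4 give 3 batches when dropLast is false and 2 when it is true.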

Properties

BatchSize

Gets or sets the batch size for batch-level sampling.

public int BatchSize { get; set; }

Property Value

int

DropLast

Gets or sets whether to drop the last incomplete batch.

public bool DropLast { get; set; }

Property Value

bool

Labels

Gets or sets the class labels for each sample.

public IReadOnlyList<int> Labels { get; set; }

Property Value

IReadOnlyList<int>

Length

Gets the total number of samples this sampler will produce per epoch.

public override int Length { get; }

Property Value

int

Remarks

This may differ from the dataset size for oversampling or undersampling strategies.

NumClasses

Gets the number of unique classes.

public int NumClasses { get; }

Property Value

int

Methods

GetBatchIndices()

Returns an enumerable of index arrays, where each array represents one batch.

public IEnumerable<int[]> GetBatchIndices()

Returns

IEnumerable<int[]>

An enumerable of batch index arrays.
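A short usage sketch follows. It relies only on the constructor and the GetBatchIndices() method documented on this page, and it assumes the AiDotNet package is referenced in the project. The label values are made up for illustration. Because the exact batch contents depend on the library's shuffling, no expected output is shown.

```csharp
using System;
using AiDotNet.Data.Sampling;

// Illustrative labels: 4 samples of class 0, 2 of class 1, 2 of class 2.
int[] labels = { 0, 0, 0, 0, 1, 1, 2, 2 };

var sampler = new StratifiedBatchSampler(
    labels,
    numClasses: 3,
    batchSize: 4,
    dropLast: false,
    seed: 42); // fixed seed for reproducible batches

// Each int[] holds the dataset indices that make up one batch.
foreach (int[] batch in sampler.GetBatchIndices())
{
    Console.WriteLine(string.Join(", ", batch));
}
```

The indices in each batch can then be used to look up the corresponding samples and labels in the dataset.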

GetIndicesCore()

Core implementation for generating indices. Override this in derived classes.

protected override IEnumerable<int> GetIndicesCore()

Returns

IEnumerable<int>

An enumerable of sample indices.