
Class StratifiedSampler

Namespace
AiDotNet.Data.Sampling
Assembly
AiDotNet.dll

A sampler that ensures each class is represented proportionally in each epoch.

public class StratifiedSampler : DataSamplerBase, IStratifiedSampler, IDataSampler
Inheritance
DataSamplerBase
StratifiedSampler
Implements
IStratifiedSampler
IDataSampler

Remarks

StratifiedSampler maintains the class distribution from the original dataset during sampling. This is especially important for imbalanced datasets where some classes have many more samples than others.

For Beginners: When your dataset has unequal class sizes (e.g., 90% cats, 10% dogs), random sampling might sometimes produce batches with only cats. Stratified sampling ensures every batch has a similar ratio of cats to dogs as the full dataset.
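As a rough illustration of the risk, assume (hypothetically) a batch size of B = 8 and that each draw is independently a cat with probability 0.9, which approximates sampling from a large dataset. The chance that a purely random batch contains no dogs is then:

```latex
P(\text{batch of } B \text{ samples has no dogs}) = 0.9^{B}, \qquad 0.9^{8} \approx 0.43
```

Under these assumptions, roughly 43% of random batches would contain only cats, which is the situation stratified sampling is meant to avoid.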

Example:

using AiDotNet.Data.Sampling;

// 3 samples of class 0, 2 of class 1, 1 of class 2
var labels = new[] { 0, 0, 0, 1, 1, 2 };
var sampler = new StratifiedSampler(labels, numClasses: 3);
// Each epoch maintains the 3:2:1 ratio while shuffling within each class.
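The idea of "keep the ratio, shuffle within classes" can be sketched in a few steps: bucket sample indices by label, shuffle each bucket, then interleave the buckets so that every stretch of the epoch roughly matches the overall class ratio. This is a minimal, self-contained sketch (C# 9+ top-level statements) and not the AiDotNet implementation; the function name StratifiedOrder and the "relative position" interleaving rule are assumptions chosen for illustration.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical sketch of stratified ordering; NOT the AiDotNet implementation.
// Assumes every label is a valid 0-based class index below numClasses.
static List<int> StratifiedOrder(IReadOnlyList<int> classLabels, int numClasses, int? seed = null)
{
    var rng = seed.HasValue ? new Random(seed.Value) : new Random();

    // 1) Bucket sample indices by their class label.
    var byClass = new List<int>[numClasses];
    for (int c = 0; c < numClasses; c++) byClass[c] = new List<int>();
    for (int i = 0; i < classLabels.Count; i++) byClass[classLabels[i]].Add(i);

    var keyed = new List<(double Key, int Index)>();
    foreach (var bucket in byClass)
    {
        // 2) Fisher-Yates shuffle within the class.
        for (int i = bucket.Count - 1; i > 0; i--)
        {
            int j = rng.Next(i + 1);
            (bucket[i], bucket[j]) = (bucket[j], bucket[i]);
        }
        // 3) Give the k-th member of an n-member class the relative position (k + 0.5) / n,
        //    so each class is spread evenly across the whole epoch.
        for (int k = 0; k < bucket.Count; k++)
            keyed.Add(((k + 0.5) / bucket.Count, bucket[k]));
    }

    // OrderBy is a stable sort, so ties are broken by class order.
    return keyed.OrderBy(p => p.Key).Select(p => p.Index).ToList();
}

var labels = new[] { 0, 0, 0, 1, 1, 2 };
var order = StratifiedOrder(labels, numClasses: 3, seed: 42);

// The shuffle changes which sample of each class appears, but not the class pattern.
Console.WriteLine(string.Join(", ", order.Select(i => labels[i]))); // prints "0, 1, 0, 2, 1, 0"
```

Sorting by relative position is only one simple way to interleave classes; a real sampler may use a different scheme, for example drawing classes at random with probabilities proportional to their remaining counts.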

Constructors

StratifiedSampler(IEnumerable<int>, int, int?)

Initializes a new instance of the StratifiedSampler class.

public StratifiedSampler(IEnumerable<int> labels, int numClasses, int? seed = null)

Parameters

labels IEnumerable<int>

The class label for each sample, as a 0-based class index (0 to numClasses - 1).

numClasses int

The total number of classes.

seed int?

Optional random seed for reproducibility.

Exceptions

ArgumentNullException

Thrown when labels is null.

ArgumentOutOfRangeException

Thrown when numClasses is less than 2.
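The documented exception contract amounts to two guard clauses. The sketch below mirrors it with a hypothetical helper, ValidateSamplerArgs, which is not part of AiDotNet; the real constructor's check order and messages may differ.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical guard clauses mirroring the documented constructor contract.
static void ValidateSamplerArgs(IEnumerable<int> labels, int numClasses)
{
    // Documented: ArgumentNullException when labels is null.
    if (labels is null)
        throw new ArgumentNullException(nameof(labels));

    // Documented: ArgumentOutOfRangeException when numClasses is less than 2.
    if (numClasses < 2)
        throw new ArgumentOutOfRangeException(nameof(numClasses), numClasses,
            "A stratified sampler needs at least two classes.");
}

try
{
    ValidateSamplerArgs(new[] { 0, 0, 0 }, numClasses: 1);
}
catch (ArgumentOutOfRangeException ex)
{
    Console.WriteLine(ex.ParamName); // prints "numClasses"
}
```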

Properties

Labels

Gets or sets the class labels for each sample.

public IReadOnlyList<int> Labels { get; set; }

Property Value

IReadOnlyList<int>

Length

Gets the total number of samples this sampler will produce per epoch.

public override int Length { get; }

Property Value

int

Remarks

This may differ from the dataset size for oversampling or undersampling strategies.
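To make the remark concrete, the sketch below computes an epoch length for the 3:2:1 example under three strategies. The "proportional" case assumes StratifiedSampler yields each sample once per epoch, as the class example suggests; the oversampling and undersampling rules are common conventions shown for contrast, not AiDotNet APIs, and ClassCounts is a hypothetical helper.

```csharp
using System;
using System.Linq;

// Hypothetical helper: count how many samples carry each 0-based class label.
static int[] ClassCounts(int[] classLabels, int numClasses)
{
    var result = new int[numClasses];
    foreach (int label in classLabels) result[label]++;
    return result;
}

var labels = new[] { 0, 0, 0, 1, 1, 2 };
var counts = ClassCounts(labels, numClasses: 3);            // [3, 2, 1]

int proportional = labels.Length;                           // every sample once: 6
int oversampled = counts.Length * counts.Max();             // grow each class to the largest: 3 * 3 = 9
int undersampled = counts.Length * counts.Min();            // shrink each class to the smallest: 3 * 1 = 3

Console.WriteLine($"{proportional} {oversampled} {undersampled}"); // prints "6 9 3"
```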

NumClasses

Gets the total number of classes.

public int NumClasses { get; }

Property Value

int

Methods

GetIndicesCore()

Core implementation for generating indices. StratifiedSampler overrides this DataSamplerBase member to produce stratified indices.

protected override IEnumerable<int> GetIndicesCore()

Returns

IEnumerable<int>

An enumerable of sample indices.
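A caller typically groups the returned indices into mini-batches and looks up each index in the dataset. The sketch below shows that consumption step; epochIndices is a hard-coded, hypothetical stand-in for one epoch of a sampler's output, since the public method that exposes these indices is not documented here. It uses Enumerable.Chunk, which requires .NET 6 or later.

```csharp
using System;
using System.Linq;

// Hypothetical consumer of one epoch of indices; `epochIndices` stands in for sampler output.
var labels = new[] { 0, 0, 0, 1, 1, 2 };
int[] epochIndices = { 2, 3, 0, 5, 4, 1 };
const int batchSize = 2;

// Enumerable.Chunk (.NET 6+) splits the index stream into consecutive batches.
var batches = epochIndices.Chunk(batchSize).ToList();
foreach (int[] batch in batches)
{
    // Look up each sampled index's label to show the class mix of the batch.
    Console.WriteLine(string.Join(" ", batch.Select(i => labels[i])));
}
// prints:
// 0 1
// 0 2
// 1 0
```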