Class OptimizationDataBatcher<T, TInput, TOutput>
- Namespace: AiDotNet.Optimizers
- Assembly: AiDotNet.dll
Provides batch iteration utilities for optimization input data.
public class OptimizationDataBatcher<T, TInput, TOutput>
Type Parameters
- T: The numeric type used for calculations.
- TInput: The type of input data for the model.
- TOutput: The type of output data for the model.
- Inheritance: object → OptimizationDataBatcher<T, TInput, TOutput>
Remarks
OptimizationDataBatcher provides efficient batch iteration over optimization input data, integrating with the DataLoader batching infrastructure for consistent behavior.
For Beginners: When training machine learning models, you typically don't feed all your data at once. Instead, you break it into smaller "batches" and train on each batch. This class makes that process easy and efficient:
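using AiDotNet.Optimizers;

// inputData: an existing OptimizationInputData<float, Matrix<float>, Vector<float>> instance.
var batcher = new OptimizationDataBatcher<float, Matrix<float>, Vector<float>>(
    inputData, batchSize: 32, shuffle: true);
foreach (var (xBatch, yBatch, indices) in batcher.GetBatches())
{
    // Train on this batch. CalculateGradient and UpdateParameters stand in
    // for your own training code; they are not members of this class.
    var gradient = CalculateGradient(xBatch, yBatch);
    UpdateParameters(gradient);
}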
Constructors
OptimizationDataBatcher(OptimizationInputData<T, TInput, TOutput>, int, bool, bool, int?, IDataSampler?)
Initializes a new instance of the OptimizationDataBatcher class.
public OptimizationDataBatcher(OptimizationInputData<T, TInput, TOutput> inputData, int batchSize, bool shuffle = true, bool dropLast = false, int? seed = null, IDataSampler? sampler = null)
Parameters
- inputData (OptimizationInputData<T, TInput, TOutput>): The optimization input data containing the training, validation, and test sets.
- batchSize (int): Number of samples per batch.
- shuffle (bool): Whether to shuffle the data before batching. Default is true.
- dropLast (bool): Whether to drop the last incomplete batch. Default is false.
- seed (int?): Optional random seed for reproducibility.
- sampler (IDataSampler?): Optional custom sampler for advanced sampling strategies.
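The shuffle, dropLast, and seed parameters follow the usual mini-batch conventions. The following is a standalone sketch of those semantics, not the AiDotNet implementation: it assumes a Fisher-Yates shuffle driven by System.Random, and the BatchIndexSketch class and MakeBatches method are hypothetical names.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper illustrating shuffle / dropLast / seed semantics.
// This is NOT the AiDotNet implementation.
public static class BatchIndexSketch
{
    public static List<int[]> MakeBatches(int dataSize, int batchSize, bool shuffle, bool dropLast, int? seed)
    {
        if (batchSize <= 0) throw new ArgumentOutOfRangeException(nameof(batchSize));

        // Start with indices 0..dataSize-1 in their original order.
        var order = new int[dataSize];
        for (int i = 0; i < dataSize; i++) order[i] = i;

        if (shuffle)
        {
            // A fixed seed makes the permutation reproducible across runs.
            var rng = seed.HasValue ? new Random(seed.Value) : new Random();
            // Fisher-Yates: swap each position with a random position at or before it.
            for (int i = dataSize - 1; i > 0; i--)
            {
                int j = rng.Next(i + 1);
                (order[i], order[j]) = (order[j], order[i]);
            }
        }

        var batches = new List<int[]>();
        for (int start = 0; start < dataSize; start += batchSize)
        {
            int count = Math.Min(batchSize, dataSize - start);
            // Skip the trailing partial batch when dropLast is set.
            if (count < batchSize && dropLast) break;
            var batch = new int[count];
            Array.Copy(order, start, batch, 0, count);
            batches.Add(batch);
        }
        return batches;
    }
}
```

With 10 samples and a batch size of 4, this yields batches of 4, 4, and 2 indices, or only the two full batches when dropLast is true.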
Properties
BatchSize
Gets the number of samples per batch.
public int BatchSize { get; }
Property Value
- int
DataSize
Gets the total number of samples in the training data.
public int DataSize { get; }
Property Value
- int
NumBatches
Gets the number of batches per epoch.
public int NumBatches { get; }
Property Value
- int
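Assuming NumBatches follows the common convention (a trailing partial batch counts as a batch unless dropLast is set), it can be derived from DataSize and BatchSize with integer arithmetic. This convention is an assumption, not taken from the library source, and BatchCountSketch is a hypothetical name.

```csharp
using System;

public static class BatchCountSketch
{
    // Batches per epoch under the assumed convention.
    public static int NumBatches(int dataSize, int batchSize, bool dropLast)
    {
        if (batchSize <= 0) throw new ArgumentOutOfRangeException(nameof(batchSize));
        // Floor division counts only full batches;
        // ceiling division also counts a trailing partial batch.
        return dropLast ? dataSize / batchSize : (dataSize + batchSize - 1) / batchSize;
    }
}
```

For example, 1000 samples with a batch size of 32 give 32 batches (the last holding 8 samples), or 31 batches when dropLast is true.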
Methods
GetBatchIndices()
Iterates through training data in batches, returning only the indices.
public IEnumerable<int[]> GetBatchIndices()
Returns
- IEnumerable<int[]>
An enumerable of index arrays for each batch.
Remarks
This is useful when you need to calculate gradients or perform operations that only require the indices, not the actual data.
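As an illustration of an index-only workflow, here is a minimal sketch that assumes plain double[] arrays in place of TInput and TOutput and a one-parameter linear model y ≈ w·x. The gradient of the mean squared error over one batch needs only that batch's indices into the full arrays. IndexGradientSketch and Gradient are hypothetical names; with the real class, the index arrays would come from GetBatchIndices().

```csharp
// Hypothetical helper: MSE gradient computed from batch indices only.
public static class IndexGradientSketch
{
    // dL/dw for L = (1/n) * sum_i (w * x[i] - y[i])^2, summed over the
    // samples whose original indices appear in batchIndices.
    public static double Gradient(double[] x, double[] y, double w, int[] batchIndices)
    {
        double sum = 0;
        foreach (int i in batchIndices)
            sum += (w * x[i] - y[i]) * x[i];
        return 2.0 * sum / batchIndices.Length;
    }
}
```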
GetBatches()
Iterates through training data in batches.
public IEnumerable<(TInput XBatch, TOutput YBatch, int[] Indices)> GetBatches()
Returns
- IEnumerable<(TInput XBatch, TOutput YBatch, int[] Indices)>
An enumerable of tuples containing batch inputs, outputs, and indices.
Remarks
For Beginners: Each iteration gives you:
- xBatch: The input features for this batch.
- yBatch: The corresponding target values.
- indices: The original indices of the samples in this batch (useful for tracking).
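Because the indices refer to positions in the original training set, per-sample statistics can be stored in dataset order even when batches are shuffled. The following is a minimal sketch, assuming you already compute one loss value per sample in a batch; PerSampleTracker and its members are hypothetical and not part of AiDotNet.

```csharp
using System;

// Hypothetical helper: records the most recent loss of each sample,
// keyed by the sample's original index in the training set.
public sealed class PerSampleTracker
{
    private readonly double[] _lastLoss;

    public PerSampleTracker(int dataSize)
    {
        _lastLoss = new double[dataSize];
        // NaN marks samples that have not been seen yet.
        Array.Fill(_lastLoss, double.NaN);
    }

    // batchIndices[k] is the original index of the k-th sample in the batch,
    // and batchLosses[k] is that sample's loss.
    public void Record(int[] batchIndices, double[] batchLosses)
    {
        if (batchIndices.Length != batchLosses.Length)
            throw new ArgumentException("Indices and losses must have the same length.");
        for (int k = 0; k < batchIndices.Length; k++)
            _lastLoss[batchIndices[k]] = batchLosses[k];
    }

    public double LastLoss(int sampleIndex) => _lastLoss[sampleIndex];
}
```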
WithClassBalancing<TWeight>(IReadOnlyList<int>, int)
Creates a new batcher with weighted sampling for class balancing.
public OptimizationDataBatcher<T, TInput, TOutput> WithClassBalancing<TWeight>(IReadOnlyList<int> labels, int numClasses)
Parameters
- labels (IReadOnlyList<int>): The class labels for each sample.
- numClasses (int): The number of classes.
Returns
- OptimizationDataBatcher<T, TInput, TOutput>
A new OptimizationDataBatcher with weighted sampling.
Type Parameters
- TWeight: The numeric type for weights.
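This page does not say how the per-class weights are derived. A common choice, assumed here, is inverse class frequency: each sample gets weight N / (numClasses · count[label]), so every class carries the same total weight and a weighted sampler draws each class about equally often. ClassWeightSketch is a hypothetical name, and double stands in for TWeight.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper: inverse-frequency sample weights (assumed scheme).
public static class ClassWeightSketch
{
    // weight[i] = n / (numClasses * count[labels[i]])
    public static double[] SampleWeights(IReadOnlyList<int> labels, int numClasses)
    {
        var counts = new int[numClasses];
        foreach (int label in labels)
        {
            if (label < 0 || label >= numClasses)
                throw new ArgumentOutOfRangeException(nameof(labels), $"Label {label} is outside [0, {numClasses}).");
            counts[label]++;
        }

        // Every label present in the data has count >= 1, so no division by zero.
        var weights = new double[labels.Count];
        for (int i = 0; i < labels.Count; i++)
            weights[i] = (double)labels.Count / (numClasses * counts[labels[i]]);
        return weights;
    }
}
```

With labels {0, 0, 0, 1}, the three class-0 samples each get weight 2/3 and the single class-1 sample gets weight 2, so both classes total 2.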
WithCurriculumLearning<TDifficulty>(IEnumerable<TDifficulty>, int, CurriculumStrategy)
Creates a new batcher with curriculum learning.
public OptimizationDataBatcher<T, TInput, TOutput> WithCurriculumLearning<TDifficulty>(IEnumerable<TDifficulty> difficulties, int totalEpochs, CurriculumStrategy strategy = CurriculumStrategy.Linear)
Parameters
- difficulties (IEnumerable<TDifficulty>): Difficulty score for each sample (0 = easiest, 1 = hardest).
- totalEpochs (int): Total number of epochs for curriculum completion.
- strategy (CurriculumStrategy): The curriculum progression strategy. Default is CurriculumStrategy.Linear.
Returns
- OptimizationDataBatcher<T, TInput, TOutput>
A new OptimizationDataBatcher with curriculum learning.
Type Parameters
- TDifficulty: The numeric type for difficulty scores.
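One plausible reading of a linear curriculum, assumed here rather than taken from the library, is a difficulty threshold that rises linearly from 1/totalEpochs at the first epoch to 1 at the last, so only samples at or below the threshold are eligible for batching. LinearCurriculumSketch is a hypothetical name, and double stands in for TDifficulty.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper: sample eligibility under an assumed linear pacing.
public static class LinearCurriculumSketch
{
    // Indices of samples available at a 0-based epoch. Difficulties use the
    // documented scale (0 = easiest, 1 = hardest).
    public static List<int> AvailableIndices(IReadOnlyList<double> difficulties, int epoch, int totalEpochs)
    {
        if (totalEpochs <= 0) throw new ArgumentOutOfRangeException(nameof(totalEpochs));
        double threshold = Math.Min(1.0, (epoch + 1) / (double)totalEpochs);

        var available = new List<int>();
        for (int i = 0; i < difficulties.Count; i++)
            if (difficulties[i] <= threshold)
                available.Add(i);
        return available;
    }
}
```

With difficulties {0.1, 0.5, 0.9, 0.3} and two total epochs, the first epoch uses samples 0, 1, and 3 (threshold 0.5) and the second epoch uses all four.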
WithSampler(IDataSampler)
Creates a new batcher with a different sampler.
public OptimizationDataBatcher<T, TInput, TOutput> WithSampler(IDataSampler sampler)
Parameters
- sampler (IDataSampler): The new sampler to use.
Returns
- OptimizationDataBatcher<T, TInput, TOutput>
A new OptimizationDataBatcher with the specified sampler.