Class InputOutputDataLoaderBase<T, TInput, TOutput>

Namespace
AiDotNet.Data.Loaders
Assembly
AiDotNet.dll

Abstract base class for input-output data loaders providing common supervised learning functionality.

public abstract class InputOutputDataLoaderBase<T, TInput, TOutput> : DataLoaderBase<T>, IInputOutputDataLoader<T, TInput, TOutput>, IDataLoader<T>, IResettable, ICountable, IBatchIterable<(TInput Features, TOutput Labels)>, IShuffleable

Type Parameters

T

The numeric type used for calculations, typically float or double.

TInput

The input data type (e.g., Matrix<T>, Tensor<T>).

TOutput

The output data type (e.g., Vector<T>, Tensor<T>).

Inheritance
InputOutputDataLoaderBase<T, TInput, TOutput>
Implements
IInputOutputDataLoader<T, TInput, TOutput>
IBatchIterable<(TInput Features, TOutput Labels)>

Remarks

InputOutputDataLoaderBase provides the shared implementation for all supervised learning data loaders, including:

  • Feature (X) and label (Y) data management
  • Train/validation/test splitting
  • Shuffling and batching capabilities
  • Progress tracking through ICountable

For Beginners: This base class handles common input-output data operations:

  • Storing features (X) and labels (Y) for supervised learning
  • Splitting data into training, validation, and test sets
  • Shuffling data to improve training
  • Iterating through data in batches

Concrete implementations such as CsvDataLoader and ImageDataLoader extend this class to load specific data formats.

Fields

Indices

Indices for current data ordering (used for shuffling).

protected int[]? Indices

Field Value

int[]

LoadedFeatures

Storage for loaded feature data.

protected TInput? LoadedFeatures

Field Value

TInput

LoadedLabels

Storage for loaded label data.

protected TOutput? LoadedLabels

Field Value

TOutput

NumOps

Numeric operations helper for type T.

protected static readonly INumericOperations<T> NumOps

Field Value

INumericOperations<T>

Properties

BatchSize

Gets or sets the batch size for iteration.

public override int BatchSize { get; set; }

Property Value

int

FeatureCount

Gets the number of features per sample.

public abstract int FeatureCount { get; }

Property Value

int

Features

Gets all input features as a single data structure.

public TInput Features { get; }

Property Value

TInput

Remarks

This provides access to the complete feature set. For large datasets, prefer using batch iteration methods instead of loading everything at once.

HasNext

Gets whether there are more batches available in the current iteration.

public bool HasNext { get; }

Property Value

bool

IsShuffled

Gets whether the data is currently shuffled.

public bool IsShuffled { get; }

Property Value

bool

Labels

Gets all output labels as a single data structure.

public TOutput Labels { get; }

Property Value

TOutput

Remarks

This provides access to all labels. For large datasets, prefer using batch iteration methods instead of loading everything at once.

OutputDimension

Gets the number of output dimensions (1 for regression/binary classification, N for multi-class with N classes).

public abstract int OutputDimension { get; }

Property Value

int

Methods

ComputeSplitSizes(int, double, double)

Computes split sizes from ratios and total count.

protected static (int TrainSize, int ValidationSize, int TestSize) ComputeSplitSizes(int totalCount, double trainRatio, double validationRatio)

Parameters

totalCount int

Total number of samples.

trainRatio double

Training ratio.

validationRatio double

Validation ratio.

Returns

(int TrainSize, int ValidationSize, int TestSize)

A tuple containing train, validation, and test sizes.
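This page does not state how fractional sizes are rounded. As a minimal sketch, the helper below assumes the train and validation shares are truncated and the remainder goes to the test set, which guarantees the three sizes sum to totalCount. The name SplitSizes is hypothetical and is not part of the library.

```csharp
using System;

// Hypothetical stand-in for ComputeSplitSizes; the rounding rule is an assumption.
static (int TrainSize, int ValidationSize, int TestSize) SplitSizes(
    int totalCount, double trainRatio, double validationRatio)
{
    // Truncate the train and validation shares toward zero.
    int trainSize = (int)(totalCount * trainRatio);
    int validationSize = (int)(totalCount * validationRatio);

    // Assign the remainder to the test set so the sizes always sum to totalCount.
    int testSize = totalCount - trainSize - validationSize;
    return (trainSize, validationSize, testSize);
}

// 10 samples at 50% / 25%: the validation share of 2.5 truncates to 2,
// and the leftover sample lands in the test set.
Console.WriteLine(SplitSizes(10, 0.5, 0.25)); // prints "(5, 2, 3)"
```

Deriving the test size by subtraction, rather than from its own ratio, avoids losing or double-counting samples when the ratios do not divide the total evenly.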

ExtractBatch(int[])

Extracts a batch of features and labels at the specified indices.

protected abstract (TInput Features, TOutput Labels) ExtractBatch(int[] indices)

Parameters

indices int[]

The indices of samples to extract.

Returns

(TInput Features, TOutput Labels)

A tuple containing the features and labels for the batch.

Remarks

Derived classes must implement this to extract data based on their specific TInput and TOutput types.
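To illustrate what such an implementation has to do, here is a minimal sketch that gathers rows by index, assuming TInput is a jagged double[][] and TOutput is a double[]. The GatherRows helper is hypothetical; a real loader would do the equivalent on Matrix<T>, Vector<T>, or Tensor<T>.

```csharp
using System;
using System.Linq;

// Hypothetical row-gathering helper; plain arrays stand in for the library's types.
static (double[][] Features, double[] Labels) GatherRows(
    double[][] features, double[] labels, int[] indices)
{
    // Pick rows in the order given by indices, so a shuffled index
    // array produces a shuffled batch without moving the stored data.
    var batchFeatures = indices.Select(i => features[i]).ToArray();
    var batchLabels = indices.Select(i => labels[i]).ToArray();
    return (batchFeatures, batchLabels);
}

var allFeatures = new[] { new[] { 1.0, 2.0 }, new[] { 3.0, 4.0 }, new[] { 5.0, 6.0 } };
var allLabels = new[] { 10.0, 20.0, 30.0 };

var (xBatch, yBatch) = GatherRows(allFeatures, allLabels, new[] { 2, 0 });
Console.WriteLine(string.Join(", ", yBatch)); // prints "30, 10"
```

Feature rows and labels are selected with the same index array, which keeps each sample paired with its label.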

GetBatches(int?, bool, bool, int?)

Iterates through all batches in the dataset using lazy evaluation.

public virtual IEnumerable<(TInput Features, TOutput Labels)> GetBatches(int? batchSize = null, bool shuffle = true, bool dropLast = false, int? seed = null)

Parameters

batchSize int?

Optional batch size override. Uses default BatchSize if null.

shuffle bool

Whether to shuffle data before batching. Default is true.

dropLast bool

Whether to drop the last incomplete batch. Default is false.

seed int?

Optional random seed for reproducible shuffling.

Returns

IEnumerable<(TInput Features, TOutput Labels)>

An enumerable sequence of batches using yield return for memory efficiency.

Remarks

This method provides a PyTorch-style iteration pattern using IEnumerable and yield return for memory-efficient lazy evaluation. Each call creates a fresh iteration, automatically handling reset and shuffle operations.

For Beginners: This is the recommended way to iterate through your data:

foreach (var (xBatch, yBatch) in dataLoader.GetBatches(batchSize: 32, shuffle: true))
{
    // Train on this batch
    model.TrainOnBatch(xBatch, yBatch);
}

Unlike GetNextBatch(), you don't need to call Reset(): each GetBatches() call starts a fresh iteration. Because the method uses yield return, batches are generated on demand rather than all loaded into memory at once.
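The lazy batching described above can be sketched on plain index arrays. This is an illustration under assumptions, not the library's implementation: BatchIndices is a hypothetical helper that slices an ordering array (such as the Indices field) into batches and applies the dropLast rule.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical sketch: lazily slice an index ordering into batches.
static IEnumerable<int[]> BatchIndices(int[] order, int batchSize, bool dropLast)
{
    // This check runs on the first MoveNext, because iterator bodies are deferred.
    if (batchSize <= 0) throw new ArgumentOutOfRangeException(nameof(batchSize));

    for (int start = 0; start < order.Length; start += batchSize)
    {
        int size = Math.Min(batchSize, order.Length - start);

        // Only the final batch can be short; skip it when dropLast is set.
        if (dropLast && size < batchSize) yield break;

        var batch = new int[size];
        Array.Copy(order, start, batch, 0, size);
        yield return batch; // produced only when the caller asks for it
    }
}

var order = new[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
foreach (var indexBatch in BatchIndices(order, batchSize: 4, dropLast: false))
    Console.WriteLine(string.Join(",", indexBatch));
// prints "0,1,2,3", then "4,5,6,7", then "8,9"
```

With dropLast set to true, the same call yields only the two full batches.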

GetBatchesAsync(int?, bool, bool, int?, int, CancellationToken)

Asynchronously iterates through all batches with prefetching support.

public virtual IAsyncEnumerable<(TInput Features, TOutput Labels)> GetBatchesAsync(int? batchSize = null, bool shuffle = true, bool dropLast = false, int? seed = null, int prefetchCount = 2, CancellationToken cancellationToken = default)

Parameters

batchSize int?

Optional batch size override. Uses default BatchSize if null.

shuffle bool

Whether to shuffle data before batching. Default is true.

dropLast bool

Whether to drop the last incomplete batch. Default is false.

seed int?

Optional random seed for reproducible shuffling.

prefetchCount int

Number of batches to prefetch ahead. Default is 2.

cancellationToken CancellationToken

Token to cancel the iteration.

Returns

IAsyncEnumerable<(TInput Features, TOutput Labels)>

An async enumerable sequence of batches.

Remarks

This method enables async batch iteration with configurable prefetching, similar in spirit to the prefetching done by PyTorch's DataLoader (num_workers with prefetch_factor) or TensorFlow's Dataset.prefetch(). Batches are prepared in the background while the current batch is being processed.

For Beginners: Use this for large datasets or when batch preparation is slow:

await foreach (var (xBatch, yBatch) in dataLoader.GetBatchesAsync(prefetchCount: 2))
{
    // While training on this batch, the next 2 batches are being prepared
    await model.TrainOnBatchAsync(xBatch, yBatch);
}

Prefetching helps hide data loading latency, especially useful for:

  • Large images that need decoding
  • Data that requires preprocessing
  • Slow storage (network drives, cloud storage)
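One common way to build this kind of prefetching in .NET is a bounded channel fed by a background producer. The sketch below shows that general technique; it is an assumption about how prefetching can work, not a description of what GetBatchesAsync does internally. Prefetch and SlowBatches are hypothetical names.

```csharp
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;

// Hypothetical prefetching wrapper around any synchronous batch source.
static async IAsyncEnumerable<TBatch> Prefetch<TBatch>(
    IEnumerable<TBatch> source,
    int prefetchCount, // must be at least 1
    [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    // The bounded capacity caps how many prepared batches can wait in the queue.
    var channel = Channel.CreateBounded<TBatch>(prefetchCount);
    using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);

    // Producer: enumerates the (possibly slow) source on a thread-pool thread.
    var producer = Task.Run(async () =>
    {
        try
        {
            foreach (var item in source)
                await channel.Writer.WriteAsync(item, cts.Token);
            channel.Writer.TryComplete();
        }
        catch (Exception ex)
        {
            // Pass producer failures (including cancellation) to the reader side.
            channel.Writer.TryComplete(ex);
        }
    });

    try
    {
        // Consumer: yields batches as soon as they are available.
        await foreach (var batch in channel.Reader.ReadAllAsync(cancellationToken))
            yield return batch;
    }
    finally
    {
        // Stop the producer if the caller leaves the loop early, then wait for it.
        cts.Cancel();
        await producer;
    }
}

static IEnumerable<int> SlowBatches()
{
    for (int i = 0; i < 5; i++)
    {
        Thread.Sleep(50); // stand-in for decoding or preprocessing
        yield return i;
    }
}

await foreach (var ready in Prefetch(SlowBatches(), prefetchCount: 2))
    Console.WriteLine($"batch {ready}");
```

The bounded capacity provides backpressure: the producer pauses once prefetchCount batches are waiting, so memory use stays proportional to prefetchCount rather than to the dataset size.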

GetNextBatch()

Gets the next batch of data.

public (TInput Features, TOutput Labels) GetNextBatch()

Returns

(TInput Features, TOutput Labels)

The next batch of data.

Exceptions

InvalidOperationException

Thrown when no more batches are available.

InitializeIndices(int)

Initializes indices array after data is loaded.

protected void InitializeIndices(int count)

Parameters

count int

The number of samples in the dataset.

OnReset()

Called after Reset() to allow derived classes to perform additional reset operations.

protected override void OnReset()

Remarks

Override this to reset any domain-specific state. The base indices are already reset when this is called.

Shuffle(int?)

Shuffles the data order using the specified seed for reproducibility.

public virtual void Shuffle(int? seed = null)

Parameters

seed int?

Optional seed for reproducible shuffling. Same seed produces same order.

Remarks

For Beginners: The seed is like a "recipe" for the randomness. Using the same seed gives the same "random" order every time, which is useful for reproducing experiments or debugging.
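A seeded shuffle of an index array can be sketched with the Fisher-Yates algorithm. Whether the library uses this exact algorithm is not stated here, so treat it as an illustration of the seed behaviour; ShuffledIndices is a hypothetical helper.

```csharp
using System;
using System.Linq;

// Hypothetical seeded shuffle of sample indices (Fisher-Yates).
static int[] ShuffledIndices(int count, int? seed = null)
{
    var indices = new int[count];
    for (int i = 0; i < count; i++) indices[i] = i;

    // A fixed seed makes Random produce the same sequence on every run.
    var rng = seed.HasValue ? new Random(seed.Value) : new Random();

    // Walk backwards, swapping each slot with a random slot at or before it.
    for (int i = count - 1; i > 0; i--)
    {
        int j = rng.Next(i + 1); // 0 <= j <= i
        (indices[i], indices[j]) = (indices[j], indices[i]);
    }
    return indices;
}

var first = ShuffledIndices(8, seed: 42);
var second = ShuffledIndices(8, seed: 42);
Console.WriteLine(first.SequenceEqual(second)); // prints "True"
```

Shuffling an index array instead of the data itself is also what makes an operation like Unshuffle() cheap: restoring the order only requires resetting the indices to 0, 1, 2, and so on.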

Split(double, double, int?)

Creates a train/validation/test split of the data.

public abstract (IInputOutputDataLoader<T, TInput, TOutput> Train, IInputOutputDataLoader<T, TInput, TOutput> Validation, IInputOutputDataLoader<T, TInput, TOutput> Test) Split(double trainRatio = 0.7, double validationRatio = 0.15, int? seed = null)

Parameters

trainRatio double

Fraction of data for training (0.0 to 1.0).

validationRatio double

Fraction of data for validation (0.0 to 1.0).

seed int?

Optional random seed for reproducible splits.

Returns

(IInputOutputDataLoader<T, TInput, TOutput> Train, IInputOutputDataLoader<T, TInput, TOutput> Validation, IInputOutputDataLoader<T, TInput, TOutput> Test)

A tuple containing three data loaders: (train, validation, test).

Remarks

The test ratio is implicitly 1 - trainRatio - validationRatio.

For Beginners: Splitting data is crucial for evaluating your model:

  • Training set: Data the model learns from
  • Validation set: Data used to tune hyperparameters and prevent overfitting
  • Test set: Data used only once at the end to get an unbiased performance estimate

Common splits are 60/20/20 or 70/15/15 (train/validation/test).

TryGetNextBatch(out (TInput Features, TOutput Labels))

Attempts to get the next batch without throwing if unavailable.

public bool TryGetNextBatch(out (TInput Features, TOutput Labels) batch)

Parameters

batch (TInput Features, TOutput Labels)

The batch if available, default otherwise.

Returns

bool

True if a batch was available, false if iteration is complete.

Remarks

When false is returned, batch contains the default value of the (TInput Features, TOutput Labels) tuple. Callers should check the return value before using batch.

Unshuffle()

Restores the original (unshuffled) data order.

public virtual void Unshuffle()

ValidateSplitRatios(double, double)

Validates split ratios.

protected static void ValidateSplitRatios(double trainRatio, double validationRatio)

Parameters

trainRatio double

Training ratio.

validationRatio double

Validation ratio.

Exceptions

ArgumentException

Thrown when ratios are invalid.
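This page does not spell out which ratios count as invalid. The sketch below assumes a plausible rule set: trainRatio must be in (0, 1], validationRatio must be in [0, 1], and their sum must not exceed 1, so that the implicit test ratio is never negative. CheckSplitRatios is a hypothetical name, and the library's actual checks may differ.

```csharp
using System;

// Hypothetical validation; the exact rules are assumptions.
static void CheckSplitRatios(double trainRatio, double validationRatio)
{
    // Negated range checks also reject NaN, which fails every comparison.
    if (!(trainRatio > 0.0 && trainRatio <= 1.0))
        throw new ArgumentException("Train ratio must be in (0, 1].", nameof(trainRatio));

    if (!(validationRatio >= 0.0 && validationRatio <= 1.0))
        throw new ArgumentException("Validation ratio must be in [0, 1].", nameof(validationRatio));

    // A small tolerance absorbs floating-point error in the sum.
    if (trainRatio + validationRatio > 1.0 + 1e-9)
        throw new ArgumentException("Train and validation ratios must sum to at most 1.");
}

CheckSplitRatios(0.7, 0.15); // passes: leaves an implicit test ratio of about 0.15

try
{
    CheckSplitRatios(0.9, 0.2); // fails: the ratios sum to more than 1
}
catch (ArgumentException ex)
{
    Console.WriteLine(ex.Message);
}
```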