
Class RLDataLoaderBase<T>

Namespace
AiDotNet.Data.Loaders
Assembly
AiDotNet.dll

Abstract base class for RL data loaders providing common reinforcement learning functionality.

public abstract class RLDataLoaderBase<T> : DataLoaderBase<T>, IRLDataLoader<T>, IDataLoader<T>, IResettable, ICountable, IBatchIterable<Experience<T, Vector<T>, Vector<T>>>

Type Parameters

T

The numeric type used for calculations, typically float or double.

Inheritance
RLDataLoaderBase<T>
Implements
IBatchIterable<Experience<T, Vector<T>, Vector<T>>>

Remarks

RLDataLoaderBase provides a shared implementation for all RL data loaders, including:

  • Environment interaction management
  • Replay buffer management
  • Episode running and experience collection
  • Batch sampling for training

For Beginners: This base class handles common RL operations:

  • Stepping through the environment and collecting experiences
  • Storing experiences in a replay buffer
  • Sampling batches for training

Concrete implementations extend this to work with specific environments or provide specialized experience collection strategies.

Constructors

RLDataLoaderBase(IEnvironment<T>, IReplayBuffer<T, Vector<T>, Vector<T>>, int, int, int, bool, int?)

Initializes a new instance of the RLDataLoaderBase class.

protected RLDataLoaderBase(IEnvironment<T> environment, IReplayBuffer<T, Vector<T>, Vector<T>> replayBuffer, int episodes = 1000, int maxStepsPerEpisode = 500, int minExperiencesBeforeTraining = 1000, bool verbose = true, int? seed = null)

Parameters

environment IEnvironment<T>

The RL environment to interact with.

replayBuffer IReplayBuffer<T, Vector<T>, Vector<T>>

The replay buffer for storing experiences.

episodes int

Total number of episodes for training.

maxStepsPerEpisode int

Maximum steps per episode (prevents infinite loops).

minExperiencesBeforeTraining int

Minimum experiences needed before training can start.

verbose bool

Whether to print progress to console.

seed int?

Optional random seed for reproducibility.

Fields

NumOps

Numeric operations helper for type T.

protected static readonly INumericOperations<T> NumOps

Field Value

INumericOperations<T>

Properties

BatchSize

Gets or sets the batch size for iteration.

public override int BatchSize { get; set; }

Property Value

int

CurrentEpisode

Gets the current episode number (0-indexed).

public int CurrentEpisode { get; }

Property Value

int

Environment

Gets the environment that the agent interacts with.

public IEnvironment<T> Environment { get; }

Property Value

IEnvironment<T>

Episodes

Gets the total number of episodes to run during training.

public int Episodes { get; }

Property Value

int

HasNext

Gets whether there are more batches available in the current iteration.

public bool HasNext { get; }

Property Value

bool

MaxStepsPerEpisode

Gets the maximum number of steps per episode (prevents infinite episodes).

public int MaxStepsPerEpisode { get; }

Property Value

int

MinExperiencesBeforeTraining

Gets the minimum number of experiences required before training can begin.

public int MinExperiencesBeforeTraining { get; }

Property Value

int

Remarks

For Beginners: We need some experiences before we can learn from random samples. This ensures the replay buffer has enough diverse experiences for effective learning.

ReplayBuffer

Gets the replay buffer used for storing and sampling experiences.

public IReplayBuffer<T, Vector<T>, Vector<T>> ReplayBuffer { get; }

Property Value

IReplayBuffer<T, Vector<T>, Vector<T>>

TotalCount

Gets the total number of samples in the dataset.

public override int TotalCount { get; }

Property Value

int

TotalSteps

Gets the total number of steps taken across all episodes.

public int TotalSteps { get; }

Property Value

int

Verbose

Gets whether to print training progress to console.

public bool Verbose { get; }

Property Value

bool

Methods

AddExperience(Experience<T, Vector<T>, Vector<T>>)

Adds an experience to the replay buffer.

public void AddExperience(Experience<T, Vector<T>, Vector<T>> experience)

Parameters

experience Experience<T, Vector<T>, Vector<T>>

The experience to add.

CanTrain(int)

Checks if there are enough experiences to begin training.

public bool CanTrain(int batchSize)

Parameters

batchSize int

The desired batch size for training.

Returns

bool

True if training can proceed, false if more experiences are needed.
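
The page does not state the exact rule CanTrain applies. A minimal sketch of one plausible rule, assuming training is allowed once the buffer holds at least MinExperiencesBeforeTraining experiences and at least one full batch, is shown below. WarmUpGate and its bufferCount parameter are hypothetical names used only for illustration.

```csharp
using System;

public static class WarmUpGate
{
    // ASSUMED rule (not confirmed by the library docs): train only when the
    // buffer holds at least the warm-up minimum AND at least one full batch.
    public static bool CanTrain(int bufferCount, int minExperiencesBeforeTraining, int batchSize)
        => bufferCount >= Math.Max(minExperiencesBeforeTraining, batchSize);

    public static void Main()
    {
        // Still warming up: 500 stored experiences, 1000 required.
        Console.WriteLine(CanTrain(bufferCount: 500, minExperiencesBeforeTraining: 1000, batchSize: 32));
        // Warm-up threshold reached.
        Console.WriteLine(CanTrain(bufferCount: 1000, minExperiencesBeforeTraining: 1000, batchSize: 32));
    }
}
```

A typical caller would check this gate once per step and skip the update while it returns false.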

GetBatches(int?, bool, bool, int?)

Iterates through all batches in the dataset using lazy evaluation.

public virtual IEnumerable<Experience<T, Vector<T>, Vector<T>>> GetBatches(int? batchSize = null, bool shuffle = true, bool dropLast = false, int? seed = null)

Parameters

batchSize int?

Optional batch size override. Uses default BatchSize if null.

shuffle bool

Whether to shuffle data before batching. Default is true.

dropLast bool

Whether to drop the last incomplete batch. Default is false.

seed int?

Optional random seed for reproducible shuffling.

Returns

IEnumerable<Experience<T, Vector<T>, Vector<T>>>

An enumerable sequence of batches using yield return for memory efficiency.

Remarks

This method provides a PyTorch-style iteration pattern using IEnumerable and yield return for memory-efficient lazy evaluation. Each call creates a fresh iteration, automatically handling reset and shuffle operations.

For Beginners: This is the recommended way to iterate through your data:

foreach (var batch in dataLoader.GetBatches(batchSize: 32, shuffle: true))
{
    // batch is an Experience<T, Vector<T>, Vector<T>>, not an (x, y) tuple.
    // model.TrainOnBatch is a placeholder for your own update step.
    model.TrainOnBatch(batch);
}

Unlike GetNextBatch(), you don't need to call Reset() - each GetBatches() call starts fresh. The yield return pattern means batches are generated on-demand, not all loaded into memory at once.
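
The lazy, restartable behaviour described above can be illustrated with a generic C# iterator. This is a sketch under stated assumptions, not the AiDotNet implementation: LazyBatching is a hypothetical helper, and it yields lists of items rather than Experience values.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class LazyBatching
{
    // Yields batches one at a time; nothing is built until the caller iterates.
    // Each call starts a fresh pass, so no explicit reset is needed.
    public static IEnumerable<IReadOnlyList<TItem>> GetBatches<TItem>(
        IReadOnlyList<TItem> items, int batchSize, bool shuffle = true,
        bool dropLast = false, int? seed = null)
    {
        if (batchSize <= 0) throw new ArgumentOutOfRangeException(nameof(batchSize));

        // Shuffle an index permutation so the source list is never reordered.
        int[] order = Enumerable.Range(0, items.Count).ToArray();
        if (shuffle)
        {
            var rng = seed.HasValue ? new Random(seed.Value) : new Random();
            for (int i = order.Length - 1; i > 0; i--)   // Fisher-Yates shuffle
            {
                int j = rng.Next(i + 1);
                (order[i], order[j]) = (order[j], order[i]);
            }
        }

        for (int start = 0; start < order.Length; start += batchSize)
        {
            int count = Math.Min(batchSize, order.Length - start);
            if (dropLast && count < batchSize) yield break;   // skip the short tail

            var batch = new List<TItem>(count);
            for (int k = 0; k < count; k++) batch.Add(items[order[start + k]]);
            yield return batch;
        }
    }

    public static void Main()
    {
        var data = Enumerable.Range(0, 10).ToList();
        foreach (var batch in GetBatches(data, batchSize: 4, shuffle: false))
            Console.WriteLine(string.Join(", ", batch));
    }
}
```

With 10 items and a batch size of 4, the final batch holds only 2 items unless dropLast is true, in which case it is skipped.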

GetBatchesAsync(int?, bool, bool, int?, int, CancellationToken)

Asynchronously iterates through all batches with prefetching support.

public virtual IAsyncEnumerable<Experience<T, Vector<T>, Vector<T>>> GetBatchesAsync(int? batchSize = null, bool shuffle = true, bool dropLast = false, int? seed = null, int prefetchCount = 2, CancellationToken cancellationToken = default)

Parameters

batchSize int?

Optional batch size override. Uses default BatchSize if null.

shuffle bool

Whether to shuffle data before batching. Default is true.

dropLast bool

Whether to drop the last incomplete batch. Default is false.

seed int?

Optional random seed for reproducible shuffling.

prefetchCount int

Number of batches to prefetch ahead. Default is 2.

cancellationToken CancellationToken

Token to cancel the iteration.

Returns

IAsyncEnumerable<Experience<T, Vector<T>, Vector<T>>>

An async enumerable sequence of batches.

Remarks

This method enables async batch iteration with configurable prefetching, similar to PyTorch's num_workers or TensorFlow's prefetch(). Batches are prepared in the background while the current batch is being processed.

For Beginners: Use this for large datasets or when batch preparation is slow:

await foreach (var batch in dataLoader.GetBatchesAsync(prefetchCount: 2))
{
    // While training on this batch, up to 2 more batches are being prepared.
    // model.TrainOnBatchAsync is a placeholder for your own update step.
    await model.TrainOnBatchAsync(batch);
}

Prefetching helps hide data loading latency, especially useful for:

  • Large images that need decoding
  • Data that requires preprocessing
  • Slow storage (network drives, cloud storage)
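
One common way to get this kind of background prefetching in .NET is a bounded channel between a producer task and the consumer. The sketch below shows that technique in isolation. Prefetching and PrefetchAsync are hypothetical names, and nothing here is claimed to match how GetBatchesAsync is implemented internally.

```csharp
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;

public static class Prefetching
{
    // Wraps a synchronous batch source so that roughly prefetchCount batches
    // are prepared ahead of the consumer on a background task.
    public static async IAsyncEnumerable<TBatch> PrefetchAsync<TBatch>(
        IEnumerable<TBatch> source, int prefetchCount = 2,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
        var token = cts.Token;

        // A bounded channel makes the producer wait once the buffer is full.
        var channel = Channel.CreateBounded<TBatch>(prefetchCount);

        _ = Task.Run(async () =>
        {
            try
            {
                foreach (var batch in source)
                    await channel.Writer.WriteAsync(batch, token);
                channel.Writer.TryComplete();
            }
            catch (Exception ex)
            {
                channel.Writer.TryComplete(ex);   // pass failures on to the reader
            }
        });

        try
        {
            await foreach (var batch in channel.Reader.ReadAllAsync(token))
                yield return batch;
        }
        finally
        {
            cts.Cancel();   // stop the producer if the consumer exits early
        }
    }

    public static async Task Main()
    {
        var batches = new[] { "batch-0", "batch-1", "batch-2" };
        await foreach (var b in PrefetchAsync(batches, prefetchCount: 2))
            Console.WriteLine(b);
    }
}
```

Because there is a single producer and a single channel, batches reach the consumer in source order. A larger prefetchCount trades memory for more tolerance of uneven batch preparation time.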

GetNextBatch()

Gets the next batch of data.

public Experience<T, Vector<T>, Vector<T>> GetNextBatch()

Returns

Experience<T, Vector<T>, Vector<T>>

The next batch of data.

Exceptions

InvalidOperationException

Thrown when no more batches are available.

LoadDataCoreAsync(CancellationToken)

Core data loading implementation to be provided by derived classes.

protected override Task LoadDataCoreAsync(CancellationToken cancellationToken)

Parameters

cancellationToken CancellationToken

Cancellation token for async operation.

Returns

Task

A task that completes when loading is finished.

Remarks

Derived classes must implement this to perform actual data loading:

  • Load from files, databases, or remote sources
  • Parse and validate data format
  • Store in appropriate internal structures

OnReset()

Called after Reset() to allow derived classes to perform additional reset operations.

protected override void OnReset()

Remarks

Override this to reset any domain-specific state. The base indices are already reset when this is called.

ResetTraining()

Resets the data loader state (clears buffer, resets counters).

public void ResetTraining()

RunEpisode(IRLAgent<T>?)

Runs a single episode and collects experiences.

public virtual EpisodeResult<T> RunEpisode(IRLAgent<T>? agent = null)

Parameters

agent IRLAgent<T>

Optional agent to use for action selection. If null, uses random actions.

Returns

EpisodeResult<T>

Episode result containing total reward, steps, and whether it was successful.
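
The episode loop described here (act, step, store, stop on termination or at the step cap) can be sketched with a toy environment. LineWorld, EpisodeRunner, and the tuple shapes below are hypothetical stand-ins; the real method works with IEnvironment<T>, IRLAgent<T>, and EpisodeResult<T>, whose members this page does not list.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical toy environment: walk along a line until position 5 is reached.
public sealed class LineWorld
{
    public int Position { get; private set; }

    public int Reset() => Position = 0;

    // action 1 moves right, any other action moves left (floored at 0).
    public (int NextState, double Reward, bool Done) Step(int action)
    {
        Position = Math.Max(0, Position + (action == 1 ? 1 : -1));
        bool done = Position >= 5;
        return (Position, done ? 1.0 : -0.01, done);
    }
}

public static class EpisodeRunner
{
    public static (double TotalReward, int Steps, bool Finished) RunEpisode(
        LineWorld env,
        List<(int State, int Action, double Reward, int NextState, bool Done)> buffer,
        int maxSteps, Random rng, Func<int, int>? policy = null)
    {
        int state = env.Reset();
        double total = 0;
        int steps = 0;
        bool done = false;

        // The step cap bounds episodes that never reach a terminal state.
        while (!done && steps < maxSteps)
        {
            // Fall back to a uniformly random action when no policy is supplied.
            int action = policy != null ? policy(state) : rng.Next(2);
            var (next, reward, finished) = env.Step(action);
            buffer.Add((state, action, reward, next, finished));   // store the transition
            total += reward;
            state = next;
            done = finished;
            steps++;
        }
        return (total, steps, done);
    }
}
```

Passing a policy that always returns 1 finishes in 5 steps, while a small maxSteps ends the episode early with Finished set to false.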

RunEpisodes(int, IRLAgent<T>?)

Runs multiple episodes and collects experiences.

public IReadOnlyList<EpisodeResult<T>> RunEpisodes(int numEpisodes, IRLAgent<T>? agent = null)

Parameters

numEpisodes int

Number of episodes to run.

agent IRLAgent<T>

Optional agent to use for action selection.

Returns

IReadOnlyList<EpisodeResult<T>>

List of episode results.

SampleBatch(int)

Samples a batch of experiences from the replay buffer.

public IReadOnlyList<Experience<T, Vector<T>, Vector<T>>> SampleBatch(int batchSize)

Parameters

batchSize int

Number of experiences to sample.

Returns

IReadOnlyList<Experience<T, Vector<T>, Vector<T>>>

List of sampled experiences for training.

Remarks

For Beginners: Instead of learning from experiences in order (which can cause issues), we randomly sample from past experiences. This makes learning more stable.
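
To make the idea concrete, here is a minimal sketch of a fixed-capacity buffer with uniform random sampling. Transition and UniformReplayBuffer are hypothetical stand-ins, not the AiDotNet Experience or IReplayBuffer types, and sampling with replacement is an assumption chosen for brevity.

```csharp
using System;
using System.Collections.Generic;

// Stand-in for an experience tuple; not the AiDotNet Experience type.
public record Transition(double[] State, int Action, double Reward, double[] NextState, bool Done);

public sealed class UniformReplayBuffer
{
    private readonly List<Transition> _items = new();
    private readonly int _capacity;
    private readonly Random _rng;
    private int _next;   // slot to overwrite once the buffer is full

    public UniformReplayBuffer(int capacity, int? seed = null)
    {
        if (capacity <= 0) throw new ArgumentOutOfRangeException(nameof(capacity));
        _capacity = capacity;
        _rng = seed.HasValue ? new Random(seed.Value) : new Random();
    }

    public int Count => _items.Count;

    public void Add(Transition t)
    {
        if (_items.Count < _capacity) _items.Add(t);
        else _items[_next] = t;            // overwrite the oldest entry
        _next = (_next + 1) % _capacity;
    }

    // Draws batchSize experiences uniformly at random, with replacement,
    // which breaks the time ordering of consecutive steps.
    public IReadOnlyList<Transition> Sample(int batchSize)
    {
        if (_items.Count == 0) throw new InvalidOperationException("Buffer is empty.");
        var batch = new List<Transition>(batchSize);
        for (int i = 0; i < batchSize; i++)
            batch.Add(_items[_rng.Next(_items.Count)]);
        return batch;
    }
}
```

Consecutive steps in an episode are strongly correlated, so drawing at random from a large pool gives the learner batches that look closer to independent samples.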

SelectRandomAction()

Selects a random action for exploration.

protected virtual Vector<T> SelectRandomAction()

Returns

Vector<T>

A random action vector.

Remarks

Thread Safety: This method uses locking to ensure thread-safe access to the random number generator.

SetSeed(int)

Sets the random seed for reproducible training.

public void SetSeed(int seed)

Parameters

seed int

Random seed value.
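
The effect of seeding can be seen with System.Random alone: two generators created with the same seed produce the same sequence within a given runtime. The sketch below uses a hypothetical SeedDemo helper and makes no claim about which generator the loader uses internally.

```csharp
using System;
using System.Linq;

public static class SeedDemo
{
    // Draws n pseudo-random action indices from a generator seeded with `seed`.
    public static int[] DrawActions(int seed, int n, int actionCount)
    {
        var rng = new Random(seed);
        return Enumerable.Range(0, n).Select(_ => rng.Next(actionCount)).ToArray();
    }

    public static void Main()
    {
        var a = DrawActions(seed: 42, n: 5, actionCount: 4);
        var b = DrawActions(seed: 42, n: 5, actionCount: 4);
        Console.WriteLine(a.SequenceEqual(b));   // same seed, same sequence
    }
}
```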

TryGetNextBatch(out Experience<T, Vector<T>, Vector<T>>)

Attempts to get the next batch without throwing if unavailable.

public bool TryGetNextBatch(out Experience<T, Vector<T>, Vector<T>> batch)

Parameters

batch Experience<T, Vector<T>, Vector<T>>

The batch if available, default otherwise.

Returns

bool

True if a batch was available, false if iteration is complete.

Remarks

When false is returned, batch contains the default value of the batch type, here Experience<T, Vector<T>, Vector<T>>. Callers should check the return value before using batch.
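
The usual way to consume a Try-style method is a while loop that only reads the out variable when the call returned true. The sketch below uses a hypothetical BatchCursor stand-in so that it is self-contained. Only the TryGetNextBatch and HasNext names come from this page.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in exposing the same Try-pattern as TryGetNextBatch.
public sealed class BatchCursor<TBatch>
{
    private readonly IReadOnlyList<TBatch> _batches;
    private int _index;

    public BatchCursor(IReadOnlyList<TBatch> batches) => _batches = batches;

    public bool HasNext => _index < _batches.Count;

    public bool TryGetNextBatch(out TBatch batch)
    {
        if (!HasNext)
        {
            batch = default!;     // placeholder value; callers must not use it
            return false;
        }
        batch = _batches[_index++];
        return true;
    }
}

public static class TryPatternDemo
{
    public static int Consume(BatchCursor<string> cursor)
    {
        int seen = 0;
        // batch is only read inside the body, where the call returned true.
        while (cursor.TryGetNextBatch(out var batch))
        {
            Console.WriteLine(batch);
            seen++;
        }
        return seen;
    }
}
```

Compared with GetNextBatch(), this avoids using an InvalidOperationException as the end-of-iteration signal.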

UnloadDataCore()

Core data unloading implementation to be provided by derived classes.

protected override void UnloadDataCore()

Remarks

Derived classes should implement this to release resources:

  • Clear internal data structures
  • Release file handles or connections
  • Allow garbage collection of loaded data