Class RLDataLoaderBase<T>
Abstract base class for RL data loaders providing common reinforcement learning functionality.
public abstract class RLDataLoaderBase<T> : DataLoaderBase<T>, IRLDataLoader<T>, IDataLoader<T>, IResettable, ICountable, IBatchIterable<Experience<T, Vector<T>, Vector<T>>>
Type Parameters
T: The numeric type used for calculations, typically float or double.
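- Inheritance: DataLoaderBase<T> → RLDataLoaderBase<T>
- Implements: IRLDataLoader<T>, IDataLoader<T>, IResettable, ICountable, IBatchIterable<Experience<T, Vector<T>, Vector<T>>>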
Remarks
RLDataLoaderBase provides shared implementation for all RL data loaders, including:
- Environment interaction management
- Replay buffer management
- Episode running and experience collection
- Batch sampling for training
For Beginners: This base class handles common RL operations:
- Stepping through the environment and collecting experiences
- Storing experiences in a replay buffer
- Sampling batches for training
Concrete implementations extend this to work with specific environments or provide specialized experience collection strategies.
Constructors
RLDataLoaderBase(IEnvironment<T>, IReplayBuffer<T, Vector<T>, Vector<T>>, int, int, int, bool, int?)
Initializes a new instance of the RLDataLoaderBase class.
protected RLDataLoaderBase(IEnvironment<T> environment, IReplayBuffer<T, Vector<T>, Vector<T>> replayBuffer, int episodes = 1000, int maxStepsPerEpisode = 500, int minExperiencesBeforeTraining = 1000, bool verbose = true, int? seed = null)
Parameters
environment (IEnvironment<T>): The RL environment to interact with.
replayBuffer (IReplayBuffer<T, Vector<T>, Vector<T>>): The replay buffer for storing experiences.
episodes (int): Total number of episodes for training.
maxStepsPerEpisode (int): Maximum steps per episode (prevents infinite loops).
minExperiencesBeforeTraining (int): Minimum number of experiences needed before training can start.
verbose (bool): Whether to print progress to the console.
seed (int?): Optional random seed for reproducibility.
Fields
NumOps
Numeric operations helper for type T.
protected static readonly INumericOperations<T> NumOps
Field Value
- INumericOperations<T>
Properties
BatchSize
Gets or sets the batch size for iteration.
public override int BatchSize { get; set; }
Property Value
- int
CurrentEpisode
Gets the current episode number (0-indexed).
public int CurrentEpisode { get; }
Property Value
- int
Environment
Gets the environment that the agent interacts with.
public IEnvironment<T> Environment { get; }
Property Value
- IEnvironment<T>
Episodes
Gets the total number of episodes to run during training.
public int Episodes { get; }
Property Value
- int
HasNext
Gets whether there are more batches available in the current iteration.
public bool HasNext { get; }
Property Value
- bool
MaxStepsPerEpisode
Gets the maximum number of steps per episode (prevents infinite episodes).
public int MaxStepsPerEpisode { get; }
Property Value
- int
MinExperiencesBeforeTraining
Gets the minimum number of experiences required before training can begin.
public int MinExperiencesBeforeTraining { get; }
Property Value
- int
Remarks
For Beginners: Random sampling only helps once the buffer holds a reasonable number of experiences. Waiting for this minimum ensures the replay buffer contains enough diverse experiences that sampled batches are not dominated by a few early, closely related steps.
ReplayBuffer
Gets the replay buffer used for storing and sampling experiences.
public IReplayBuffer<T, Vector<T>, Vector<T>> ReplayBuffer { get; }
Property Value
- IReplayBuffer<T, Vector<T>, Vector<T>>
TotalCount
Gets the total number of samples in the dataset.
public override int TotalCount { get; }
Property Value
- int
TotalSteps
Gets the total number of steps taken across all episodes.
public int TotalSteps { get; }
Property Value
- int
Verbose
Gets whether training progress is printed to the console.
public bool Verbose { get; }
Property Value
- bool
Methods
AddExperience(Experience<T, Vector<T>, Vector<T>>)
Adds an experience to the replay buffer.
public void AddExperience(Experience<T, Vector<T>, Vector<T>> experience)
Parameters
experience (Experience<T, Vector<T>, Vector<T>>): The experience to add.
CanTrain(int)
Checks if there are enough experiences to begin training.
public bool CanTrain(int batchSize)
Parameters
batchSize (int): The desired batch size for training.
Returns
- bool
True if training can proceed, false if more experiences are needed.
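The readiness check can be sketched as follows. This is a minimal, hypothetical illustration: it assumes CanTrain requires the buffer to hold at least max(MinExperiencesBeforeTraining, batchSize) experiences, which is a plausible rule but is not confirmed by this reference. TrainingGate and its bufferCount parameter are illustrative names, not part of the library.

```csharp
using System;

// Hypothetical helper illustrating a replay-buffer readiness check.
// Assumption: training may start only when the buffer holds at least
// MinExperiencesBeforeTraining items AND enough items to fill one batch.
public sealed class TrainingGate
{
    public int MinExperiencesBeforeTraining { get; }

    public TrainingGate(int minExperiencesBeforeTraining)
    {
        if (minExperiencesBeforeTraining < 0)
            throw new ArgumentOutOfRangeException(nameof(minExperiencesBeforeTraining));
        MinExperiencesBeforeTraining = minExperiencesBeforeTraining;
    }

    // bufferCount: number of experiences currently stored in the replay buffer.
    public bool CanTrain(int bufferCount, int batchSize)
    {
        if (batchSize <= 0) throw new ArgumentOutOfRangeException(nameof(batchSize));
        return bufferCount >= Math.Max(MinExperiencesBeforeTraining, batchSize);
    }
}
```

Taking the maximum of the two thresholds means a batch size larger than the warm-up minimum still cannot be sampled from a buffer that is too small.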
GetBatches(int?, bool, bool, int?)
Iterates through all batches in the dataset using lazy evaluation.
public virtual IEnumerable<Experience<T, Vector<T>, Vector<T>>> GetBatches(int? batchSize = null, bool shuffle = true, bool dropLast = false, int? seed = null)
Parameters
batchSize (int?): Optional batch size override. Uses the default BatchSize if null.
shuffle (bool): Whether to shuffle data before batching. Default is true.
dropLast (bool): Whether to drop the last incomplete batch. Default is false.
seed (int?): Optional random seed for reproducible shuffling.
Returns
- IEnumerable<Experience<T, Vector<T>, Vector<T>>>
An enumerable sequence of batches using yield return for memory efficiency.
Remarks
This method provides a PyTorch-style iteration pattern using IEnumerable and yield return for memory-efficient lazy evaluation. Each call creates a fresh iteration, automatically handling reset and shuffle operations.
For Beginners: This is the recommended way to iterate through your data:
foreach (var batch in dataLoader.GetBatches(batchSize: 32, shuffle: true))
{
    // Each element is typed Experience<T, Vector<T>, Vector<T>>
    TrainOnBatch(batch); // placeholder for your own training step
}
Unlike GetNextBatch(), you don't need to call Reset(); each GetBatches() call starts a fresh iteration. Because of yield return, batches are generated on demand rather than all loaded into memory at once.
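The lazy, reshuffle-per-call behavior described above can be sketched with a C# iterator. This is a generic illustration over plain lists, not the library's implementation; BatchingSketch.GetBatches is a hypothetical helper, and it yields IReadOnlyList<TItem> chunks rather than the Experience<T, Vector<T>, Vector<T>> elements the real method returns.

```csharp
using System;
using System.Collections.Generic;

public static class BatchingSketch
{
    // Lazily yields batches of items. A fresh (optionally seeded) shuffle is
    // performed each time the returned sequence is enumerated.
    public static IEnumerable<IReadOnlyList<TItem>> GetBatches<TItem>(
        IReadOnlyList<TItem> items, int batchSize, bool shuffle = true,
        bool dropLast = false, int? seed = null)
    {
        // Validate eagerly, then defer the actual work to the iterator.
        if (batchSize <= 0) throw new ArgumentOutOfRangeException(nameof(batchSize));
        return Iterate();

        IEnumerable<IReadOnlyList<TItem>> Iterate()
        {
            // Work on an index permutation so the source list is never modified.
            var order = new int[items.Count];
            for (int i = 0; i < order.Length; i++) order[i] = i;

            if (shuffle)
            {
                var rng = seed.HasValue ? new Random(seed.Value) : new Random();
                // Fisher-Yates shuffle of the indices.
                for (int i = order.Length - 1; i > 0; i--)
                {
                    int j = rng.Next(i + 1);
                    (order[i], order[j]) = (order[j], order[i]);
                }
            }

            for (int start = 0; start < order.Length; start += batchSize)
            {
                int size = Math.Min(batchSize, order.Length - start);
                if (size < batchSize && dropLast) yield break; // skip the incomplete tail
                var batch = new TItem[size];
                for (int k = 0; k < size; k++) batch[k] = items[order[start + k]];
                yield return batch;
            }
        }
    }
}
```

Because the shuffle runs inside the iterator, every new enumeration of the returned sequence draws a new permutation (or the same one when a seed is given), which is why no explicit Reset() is needed.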
GetBatchesAsync(int?, bool, bool, int?, int, CancellationToken)
Asynchronously iterates through all batches with prefetching support.
public virtual IAsyncEnumerable<Experience<T, Vector<T>, Vector<T>>> GetBatchesAsync(int? batchSize = null, bool shuffle = true, bool dropLast = false, int? seed = null, int prefetchCount = 2, CancellationToken cancellationToken = default)
Parameters
batchSize (int?): Optional batch size override. Uses the default BatchSize if null.
shuffle (bool): Whether to shuffle data before batching. Default is true.
dropLast (bool): Whether to drop the last incomplete batch. Default is false.
seed (int?): Optional random seed for reproducible shuffling.
prefetchCount (int): Number of batches to prefetch ahead. Default is 2.
cancellationToken (CancellationToken): Token to cancel the iteration.
Returns
- IAsyncEnumerable<Experience<T, Vector<T>, Vector<T>>>
An async enumerable sequence of batches.
Remarks
This method enables async batch iteration with configurable prefetching, similar to PyTorch's num_workers or TensorFlow's prefetch(). Batches are prepared in the background while the current batch is being processed.
For Beginners: Use this for large datasets or when batch preparation is slow:
await foreach (var batch in dataLoader.GetBatchesAsync(prefetchCount: 2))
{
    // While this batch is processed, up to 2 upcoming batches are prepared in the background
    await TrainOnBatchAsync(batch); // placeholder for your own async training step
}
Prefetching helps hide data-loading latency. It is especially useful for:
- Large images that need decoding
- Data that requires preprocessing
- Slow storage (network drives, cloud storage)
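Prefetching of this kind can be sketched with a bounded System.Threading.Channels channel: a background task fills the channel with up to prefetchCount batches while the consumer processes the current one. Whether GetBatchesAsync is built on channels is not stated here; PrefetchSketch.PrefetchAsync is a hypothetical wrapper around any synchronous batch sequence.

```csharp
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;

public static class PrefetchSketch
{
    // Prepares up to prefetchCount batches on a background task while the
    // consumer works on the current one.
    public static async IAsyncEnumerable<TBatch> PrefetchAsync<TBatch>(
        IEnumerable<TBatch> source,
        int prefetchCount = 2,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        if (prefetchCount < 1) throw new ArgumentOutOfRangeException(nameof(prefetchCount));

        using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
        var channel = Channel.CreateBounded<TBatch>(prefetchCount);

        // Producer: WriteAsync waits once prefetchCount batches are queued.
        Task producer = Task.Run(async () =>
        {
            try
            {
                foreach (var batch in source)
                    await channel.Writer.WriteAsync(batch, cts.Token);
                channel.Writer.Complete();
            }
            catch (Exception ex)
            {
                // Forward failures (or cancellation) to the consumer side.
                channel.Writer.Complete(ex);
            }
        });

        try
        {
            await foreach (var batch in channel.Reader.ReadAllAsync(cts.Token))
                yield return batch;
        }
        finally
        {
            cts.Cancel();   // unblock the producer if the consumer stopped early
            await producer; // never faults: the producer catches its own exceptions
        }
    }
}
```

Cancelling the linked token in the finally block matters: without it, a consumer that stops early would leave the producer blocked on a full channel.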
GetNextBatch()
Gets the next batch of data.
public Experience<T, Vector<T>, Vector<T>> GetNextBatch()
Returns
- Experience<T, Vector<T>, Vector<T>>
The next batch of data.
Exceptions
- InvalidOperationException
Thrown when no more batches are available.
LoadDataCoreAsync(CancellationToken)
Core data loading implementation to be provided by derived classes.
protected override Task LoadDataCoreAsync(CancellationToken cancellationToken)
Parameters
cancellationToken (CancellationToken): Cancellation token for the async operation.
Returns
- Task
A task that completes when loading is finished.
Remarks
Derived classes must implement this to perform actual data loading:
- Load from files, databases, or remote sources
- Parse and validate data format
- Store in appropriate internal structures
OnReset()
Called after Reset() to allow derived classes to perform additional reset operations.
protected override void OnReset()
Remarks
Override this to reset any domain-specific state. The base indices are already reset when this is called.
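The hook ordering described above (base indices reset first, then OnReset) is an instance of the template-method pattern. A minimal sketch follows; LoaderSketchBase, EpisodeAwareLoader, and their members are hypothetical stand-ins for the real base and derived classes.

```csharp
// Hypothetical base class: Reset() restores its own iteration state first,
// then calls the virtual OnReset() hook for derived-class state.
public abstract class LoaderSketchBase
{
    protected int CurrentIndex { get; private set; }
    protected int BatchesReturned { get; private set; }

    public void Advance() { CurrentIndex++; BatchesReturned++; }

    public void Reset()
    {
        CurrentIndex = 0;      // base indices are reset before the hook runs
        BatchesReturned = 0;
        OnReset();
    }

    protected virtual void OnReset() { }
}

// Hypothetical derived class with domain-specific state to clear.
public sealed class EpisodeAwareLoader : LoaderSketchBase
{
    public int EpisodeCounter { get; set; }
    public int IndexSeenByHook { get; private set; } = -1;

    protected override void OnReset()
    {
        IndexSeenByHook = CurrentIndex; // already 0 when the hook runs
        EpisodeCounter = 0;             // reset domain-specific state
    }
}
```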
ResetTraining()
Resets the data loader state (clears the replay buffer and resets counters).
public void ResetTraining()
RunEpisode(IRLAgent<T>?)
Runs a single episode and collects experiences.
public virtual EpisodeResult<T> RunEpisode(IRLAgent<T>? agent = null)
Parameters
agent (IRLAgent<T>?): Optional agent to use for action selection. If null, random actions are used.
Returns
- EpisodeResult<T>
An episode result containing the total reward, the number of steps taken, and whether the episode was successful.
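The episode loop that RunEpisode performs can be sketched as below, under stated assumptions: the environment contract (ISimpleEnv with Reset and Step), the Transition and EpisodeSummary records, and the toy GoalAfterNStepsEnv are all hypothetical stand-ins for IEnvironment<T>, Experience<T, Vector<T>, Vector<T>>, and EpisodeResult<T>. This reference does not define what counts as "successful", so the sketch only records whether the environment ended the episode before the step limit.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical minimal environment contract (not the library's IEnvironment<T>).
public interface ISimpleEnv
{
    double[] Reset();
    (double[] NextState, double Reward, bool Done) Step(double[] action);
}

public sealed record Transition(double[] State, double[] Action, double Reward, double[] NextState, bool Done);
public sealed record EpisodeSummary(double TotalReward, int Steps, bool Terminated);

public static class EpisodeRunner
{
    // Runs one episode, storing every transition; stops when the environment
    // reports Done or when maxSteps is reached (truncation).
    public static EpisodeSummary RunEpisode(
        ISimpleEnv env, Func<double[], double[]> selectAction,
        List<Transition> buffer, int maxSteps)
    {
        double[] state = env.Reset();
        double total = 0;
        for (int step = 1; step <= maxSteps; step++)
        {
            double[] action = selectAction(state);
            var (next, reward, done) = env.Step(action);
            buffer.Add(new Transition(state, action, reward, next, done));
            total += reward;
            state = next;
            if (done) return new EpisodeSummary(total, step, true);
        }
        return new EpisodeSummary(total, maxSteps, false);
    }
}

// Toy environment for illustration: every step yields reward 1 and the
// episode ends after `length` steps, regardless of the action taken.
public sealed class GoalAfterNStepsEnv : ISimpleEnv
{
    private readonly int _length;
    private int _t;
    public GoalAfterNStepsEnv(int length) => _length = length;
    public double[] Reset() { _t = 0; return new double[] { 0 }; }
    public (double[] NextState, double Reward, bool Done) Step(double[] action)
    {
        _t++;
        return (new double[] { _t }, 1.0, _t >= _length);
    }
}
```

Passing a policy that returns random actions as selectAction corresponds to calling RunEpisode with agent set to null.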
RunEpisodes(int, IRLAgent<T>?)
Runs multiple episodes and collects experiences.
public IReadOnlyList<EpisodeResult<T>> RunEpisodes(int numEpisodes, IRLAgent<T>? agent = null)
Parameters
numEpisodes (int): Number of episodes to run.
agent (IRLAgent<T>?): Optional agent to use for action selection.
Returns
- IReadOnlyList<EpisodeResult<T>>
List of episode results.
SampleBatch(int)
Samples a batch of experiences from the replay buffer.
public IReadOnlyList<Experience<T, Vector<T>, Vector<T>>> SampleBatch(int batchSize)
Parameters
batchSize (int): Number of experiences to sample.
Returns
- IReadOnlyList<Experience<T, Vector<T>, Vector<T>>>
List of sampled experiences for training.
Remarks
For Beginners: Consecutive experiences are strongly correlated, so learning from them in order can make training unstable. Sampling randomly from past experiences breaks up that correlation and makes learning more stable.
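Uniform experience replay can be sketched as a fixed-capacity circular buffer with sampling without replacement. UniformReplayBuffer<TItem> is a hypothetical class, not the library's IReplayBuffer; whether the real buffer samples with or without replacement, or overwrites the oldest entry when full, is an assumption here.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical fixed-capacity replay buffer with uniform random sampling.
public sealed class UniformReplayBuffer<TItem>
{
    private readonly TItem[] _items;
    private readonly Random _rng;
    private int _next; // slot the next Add will write to

    public int Count { get; private set; }
    public int Capacity => _items.Length;

    public UniformReplayBuffer(int capacity, int? seed = null)
    {
        if (capacity <= 0) throw new ArgumentOutOfRangeException(nameof(capacity));
        _items = new TItem[capacity];
        _rng = seed.HasValue ? new Random(seed.Value) : new Random();
    }

    // Once full, the oldest experience is overwritten (circular buffer).
    public void Add(TItem item)
    {
        _items[_next] = item;
        _next = (_next + 1) % _items.Length;
        if (Count < _items.Length) Count++;
    }

    // Samples batchSize distinct items uniformly at random, breaking up the
    // temporal correlation between consecutive experiences.
    public IReadOnlyList<TItem> Sample(int batchSize)
    {
        if (batchSize <= 0 || batchSize > Count)
            throw new ArgumentOutOfRangeException(nameof(batchSize));
        // Partial Fisher-Yates over the occupied slots: no duplicates.
        var idx = new int[Count];
        for (int i = 0; i < Count; i++) idx[i] = i;
        var batch = new TItem[batchSize];
        for (int i = 0; i < batchSize; i++)
        {
            int j = _rng.Next(i, Count);
            (idx[i], idx[j]) = (idx[j], idx[i]);
            batch[i] = _items[idx[i]];
        }
        return batch;
    }
}
```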
SelectRandomAction()
Selects a random action for exploration.
protected virtual Vector<T> SelectRandomAction()
Returns
- Vector<T>
A random action vector.
Remarks
Thread Safety: This method uses locking to ensure thread-safe access to the random number generator.
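The locking described above can be sketched as follows, since System.Random is not safe to share across threads without synchronization. The sketch assumes a continuous action space with each dimension drawn uniformly from [-1, 1); the real method may instead produce discrete or differently scaled actions. RandomActionSource is a hypothetical class, and its SetSeed only mirrors the intent of this class's SetSeed(int).

```csharp
using System;

// Hypothetical thread-safe random action generator guarded by a lock.
public sealed class RandomActionSource
{
    private readonly object _gate = new object();
    private readonly int _actionSize;
    private Random _rng;

    public RandomActionSource(int actionSize, int? seed = null)
    {
        if (actionSize <= 0) throw new ArgumentOutOfRangeException(nameof(actionSize));
        _actionSize = actionSize;
        _rng = seed.HasValue ? new Random(seed.Value) : new Random();
    }

    // Re-seeding under the same lock makes subsequent draws reproducible.
    public void SetSeed(int seed)
    {
        lock (_gate) { _rng = new Random(seed); }
    }

    // Assumption: each action dimension is uniform in [-1, 1).
    public double[] SelectRandomAction()
    {
        var action = new double[_actionSize];
        lock (_gate)
        {
            for (int i = 0; i < action.Length; i++)
                action[i] = _rng.NextDouble() * 2.0 - 1.0;
        }
        return action;
    }
}
```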
SetSeed(int)
Sets the random seed for reproducible training.
public void SetSeed(int seed)
Parameters
seed (int): Random seed value.
TryGetNextBatch(out Experience<T, Vector<T>, Vector<T>>)
Attempts to get the next batch without throwing if unavailable.
public bool TryGetNextBatch(out Experience<T, Vector<T>, Vector<T>> batch)
Parameters
batch (Experience<T, Vector<T>, Vector<T>>): The batch if available; otherwise the default value.
Returns
- bool
True if a batch was available, false if iteration is complete.
Remarks
When false is returned, batch contains the default value of its type (Experience<T, Vector<T>, Vector<T>>, the TBatch type for this loader). Callers should check the return value before using batch.
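The relationship between HasNext, GetNextBatch, TryGetNextBatch, and Reset can be sketched with a simple cursor. BatchCursor<TBatch> is hypothetical and iterates over precomputed batches; it follows the common .NET Try-pattern (as in Dictionary.TryGetValue) rather than reproducing the library's code.

```csharp
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;

// Hypothetical cursor over precomputed batches.
public sealed class BatchCursor<TBatch>
{
    private readonly IReadOnlyList<TBatch> _batches;
    private int _position;

    public BatchCursor(IReadOnlyList<TBatch> batches) =>
        _batches = batches ?? throw new ArgumentNullException(nameof(batches));

    public bool HasNext => _position < _batches.Count;

    // Throwing variant: fails once iteration is complete.
    public TBatch GetNextBatch()
    {
        if (!HasNext) throw new InvalidOperationException("No more batches are available.");
        return _batches[_position++];
    }

    // Non-throwing variant: returns false and sets batch to default when exhausted.
    public bool TryGetNextBatch([MaybeNullWhen(false)] out TBatch batch)
    {
        if (!HasNext)
        {
            batch = default;
            return false;
        }
        batch = _batches[_position++];
        return true;
    }

    public void Reset() => _position = 0;
}
```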
UnloadDataCore()
Core data unloading implementation to be provided by derived classes.
protected override void UnloadDataCore()
Remarks
Derived classes should implement this to release resources:
- Clear internal data structures
- Release file handles or connections
- Allow garbage collection of loaded data