Class ExperienceReplayBuffer<T, TInput, TOutput>
Namespace: AiDotNet.ContinualLearning.Memory
Assembly: AiDotNet.dll
A memory buffer for storing examples from previous tasks for experience replay.
public class ExperienceReplayBuffer<T, TInput, TOutput>
Type Parameters
T: The numeric type used for calculations.
TInput: The input data type.
TOutput: The output data type.
Inheritance: object → ExperienceReplayBuffer<T, TInput, TOutput>
Remarks
For Beginners: Experience replay stores a small number of examples from previous tasks and intermixes them with new task data during training. This helps reduce catastrophic forgetting by reminding the model of what it learned before.
Why It Works: When a neural network learns a new task, it adjusts its weights to minimize error on that task. Without replay, these adjustments can increase error on previous tasks. When old examples are mixed into the new training data, the error being minimized includes error on previous tasks, so weight updates that hurt old tasks are penalized and the network is pushed to perform well on both old and new tasks.
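The mixing step described above can be sketched as follows. This is a minimal illustration, not the AiDotNet implementation: `ReplayMixing` and `MixBatch` are hypothetical names, and the sketch assumes the replayed examples have already been drawn from memory (for example, by a call such as `SampleBatch`).

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper (not AiDotNet API): combines current-task examples
// with examples replayed from memory into one training batch.
public static class ReplayMixing
{
    public static List<TItem> MixBatch<TItem>(
        IReadOnlyList<TItem> newTaskBatch,
        IReadOnlyList<TItem> replayed,
        Random rng)
    {
        // Start with every new example followed by every replayed example.
        var mixed = new List<TItem>(newTaskBatch.Count + replayed.Count);
        mixed.AddRange(newTaskBatch);
        mixed.AddRange(replayed);

        // Fisher-Yates shuffle so old and new examples are interleaved.
        for (int i = mixed.Count - 1; i > 0; i--)
        {
            int j = rng.Next(i + 1); // uniform index in [0, i]
            TItem tmp = mixed[i];
            mixed[i] = mixed[j];
            mixed[j] = tmp;
        }
        return mixed;
    }
}
```

Because a training step on the mixed batch computes its loss over both kinds of examples, errors on earlier tasks contribute to every gradient update.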
Memory Management: Because storing every example is impractical, the buffer keeps a bounded subset chosen by a sampling strategy. Different strategies suit different scenarios:
- Reservoir: keeps a uniform random sample of all examples seen so far; a good general-purpose default.
- ClassBalanced: keeps roughly equal numbers of examples per class; useful when class frequencies are imbalanced.
- Herding: selects exemplars whose mean approximates the class mean, as in iCaRL (Rebuffi et al., 2017).
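A uniform random sample from a stream of unknown length is the job of classic reservoir sampling (Vitter's Algorithm R). The sketch below assumes that is what the Reservoir strategy amounts to; `Reservoir<TItem>` and `Offer` are hypothetical names, not AiDotNet API.

```csharp
using System;
using System.Collections.Generic;

// Sketch of reservoir sampling (Algorithm R). After n items have been
// offered, each item is held with probability capacity / n.
// Hypothetical type, not the AiDotNet implementation.
public sealed class Reservoir<TItem>
{
    private readonly List<TItem> _items = new List<TItem>();
    private readonly int _capacity;
    private readonly Random _rng;
    private int _seen; // total number of items offered so far

    public Reservoir(int capacity, int? seed = null)
    {
        if (capacity <= 0) throw new ArgumentOutOfRangeException(nameof(capacity));
        _capacity = capacity;
        _rng = seed.HasValue ? new Random(seed.Value) : new Random();
    }

    public IReadOnlyList<TItem> Items => _items;
    public int Seen => _seen;

    public void Offer(TItem item)
    {
        _seen++;
        if (_items.Count < _capacity)
        {
            _items.Add(item); // fill phase: keep everything until full
            return;
        }
        // Draw j uniformly from [0, seen); replacing slot j when j < capacity
        // keeps the new item with probability capacity / seen.
        int j = _rng.Next(_seen);
        if (j < _capacity) _items[j] = item;
    }
}
```

The same counter plays the role that `TotalSamplesProcessed` suggests: the replacement probability depends on how many examples have been seen, not just on how many are stored.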
References:
- Chaudhry et al. "Continual Learning with Tiny Episodic Memories" (2019)
- Rebuffi et al. "iCaRL: Incremental Classifier and Representation Learning" (2017)
- Rolnick et al. "Experience Replay for Continual Learning" (2019)
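Herding, as used for exemplar selection in iCaRL (Rebuffi et al., 2017), greedily adds the example that keeps the mean of the selected exemplars closest to the class mean. The sketch below is an illustration under stated assumptions: it works on precomputed `double[]` feature vectors supplied by the caller, omits any feature normalization, selects without replacement, and `Herding.SelectExemplars` is a hypothetical name rather than AiDotNet API.

```csharp
using System;
using System.Collections.Generic;

// Sketch of herding exemplar selection (iCaRL-style). Returns indices of
// the chosen exemplars in the order they were picked.
// Hypothetical helper, not the AiDotNet implementation.
public static class Herding
{
    public static List<int> SelectExemplars(IReadOnlyList<double[]> features, int m)
    {
        var chosen = new List<int>();
        int n = features.Count;
        if (n == 0 || m <= 0) return chosen;
        int d = features[0].Length;

        // Class mean of all feature vectors.
        var mean = new double[d];
        foreach (var f in features)
            for (int t = 0; t < d; t++) mean[t] += f[t] / n;

        var used = new bool[n];
        var runningSum = new double[d]; // sum of features chosen so far

        for (int k = 1; k <= Math.Min(m, n); k++)
        {
            int best = -1;
            double bestDist = double.PositiveInfinity;
            for (int i = 0; i < n; i++)
            {
                if (used[i]) continue; // select without replacement
                double dist = 0;
                for (int t = 0; t < d; t++)
                {
                    // Distance between the class mean and the mean of the
                    // exemplars if candidate i were added as the k-th one.
                    double diff = mean[t] - (runningSum[t] + features[i][t]) / k;
                    dist += diff * diff;
                }
                if (dist < bestDist) { bestDist = dist; best = i; }
            }
            used[best] = true;
            chosen.Add(best);
            for (int t = 0; t < d; t++) runningSum[t] += features[best][t];
        }
        return chosen;
    }
}
```

For one-dimensional features {0, 10, 4, 6} (mean 5), the first pick is 4 (index 2, the first of two equally close values) and the second is 6 (index 3), because their running mean is exactly 5.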
Constructors
ExperienceReplayBuffer(int, MemorySamplingStrategy, ReplaySamplingStrategy, int?)
Initializes a new experience replay buffer.
public ExperienceReplayBuffer(int maxSize, MemorySamplingStrategy addStrategy = MemorySamplingStrategy.Reservoir, ReplaySamplingStrategy replayStrategy = ReplaySamplingStrategy.TaskBalanced, int? seed = null)
Parameters
maxSize (int): Maximum number of examples to store.
addStrategy (MemorySamplingStrategy): Strategy for selecting which examples to add.
replayStrategy (ReplaySamplingStrategy): Strategy for sampling during replay.
seed (int?): Random seed for reproducibility.
Properties
AddStrategy
Gets the sampling strategy used when adding examples.
public MemorySamplingStrategy AddStrategy { get; }
Property Value
- MemorySamplingStrategy
Count
Gets the current number of stored examples.
public int Count { get; }
Property Value
- int
EstimatedMemoryBytes
Gets the estimated memory usage in bytes.
public long EstimatedMemoryBytes { get; }
Property Value
- long
IsFull
Gets whether the buffer is at capacity.
public bool IsFull { get; }
Property Value
- bool
MaxSize
Gets the maximum capacity of the buffer.
public int MaxSize { get; }
Property Value
- int
ReplayStrategy
Gets the sampling strategy used during replay.
public ReplaySamplingStrategy ReplayStrategy { get; }
Property Value
- ReplaySamplingStrategy
TaskCount
Gets the number of distinct tasks represented in the buffer.
public int TaskCount { get; }
Property Value
- int
TotalReplaySamples
Gets the total number of samples returned via replay.
public int TotalReplaySamples { get; }
Property Value
- int
TotalSamplesProcessed
Gets the total number of samples processed (added) since creation.
public int TotalSamplesProcessed { get; }
Property Value
- int
Methods
AddTaskExamples(IDataset<T, TInput, TOutput>, int, int?)
Adds examples from a task to the buffer using the configured sampling strategy.
public void AddTaskExamples(IDataset<T, TInput, TOutput> taskData, int taskId, int? samplesPerTask = null)
Parameters
taskData (IDataset<T, TInput, TOutput>): The task data to sample from.
taskId (int): The task identifier.
samplesPerTask (int?): Number of samples to store from this task. If null, the buffer's capacity is distributed evenly across tasks.
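When `samplesPerTask` is null, examples are distributed evenly. A minimal sketch of one way to compute per-task quotas, assuming "evenly" means splitting `maxSize` across the tasks with any remainder going to the earliest tasks; `CapacitySplit.EvenQuotas` is a hypothetical helper, not part of AiDotNet.

```csharp
using System;

// Hypothetical helper: splits maxSize slots across taskCount tasks as
// evenly as possible.
public static class CapacitySplit
{
    public static int[] EvenQuotas(int maxSize, int taskCount)
    {
        if (taskCount <= 0) throw new ArgumentOutOfRangeException(nameof(taskCount));
        if (maxSize < 0) throw new ArgumentOutOfRangeException(nameof(maxSize));

        var quotas = new int[taskCount];
        int baseQuota = maxSize / taskCount;  // every task gets at least this many
        int remainder = maxSize % taskCount;  // leftover slots to hand out
        for (int i = 0; i < taskCount; i++)
            quotas[i] = baseQuota + (i < remainder ? 1 : 0); // first tasks take the extras
        return quotas;
    }
}
```

For example, a buffer of 10 slots shared by 3 tasks gives quotas of 4, 3 and 3, so the quotas always sum to `maxSize`.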
Clear()
Clears all stored examples.
public void Clear()
GetAll()
Gets all stored examples.
public IReadOnlyList<DataPoint<T, TInput, TOutput>> GetAll()
Returns
- IReadOnlyList<DataPoint<T, TInput, TOutput>>
GetStatistics()
Gets statistics about the buffer.
public BufferStatistics GetStatistics()
Returns
- BufferStatistics
GetTaskCounts()
Gets the count of examples per task.
public IReadOnlyDictionary<int, int> GetTaskCounts()
Returns
- IReadOnlyDictionary<int, int>
The number of stored examples for each task identifier.
GetTaskExamples(int)
Gets all examples for a specific task.
public IReadOnlyList<DataPoint<T, TInput, TOutput>> GetTaskExamples(int taskId)
Parameters
taskId (int): The task identifier.
Returns
- IReadOnlyList<DataPoint<T, TInput, TOutput>>
Read-only list of examples for the task.
RemoveTask(int)
Removes all examples from a specific task.
public void RemoveTask(int taskId)
Parameters
taskId (int): The identifier of the task to remove.
SampleBatch(int)
Samples a batch of examples from the buffer using the configured replay strategy.
public List<DataPoint<T, TInput, TOutput>> SampleBatch(int batchSize)
Parameters
batchSize (int): Number of examples to sample.
Returns
- List<DataPoint<T, TInput, TOutput>>
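With the default `ReplaySamplingStrategy.TaskBalanced`, a batch is presumably spread across the tasks held in memory. The sketch below assumes that reading: the batch is split as evenly as possible across tasks, and each task's share is drawn uniformly without replacement. `TaskBalancedSampler` is a hypothetical name, and the library may handle small tasks differently (here a task with too few examples simply yields a smaller batch).

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical task-balanced sampler, not the AiDotNet implementation.
public static class TaskBalancedSampler
{
    public static List<TItem> Sample<TItem>(
        IReadOnlyDictionary<int, List<TItem>> byTask, int batchSize, Random rng)
    {
        var result = new List<TItem>();
        var taskIds = byTask.Keys.OrderBy(id => id).ToList();
        if (taskIds.Count == 0 || batchSize <= 0) return result;

        int baseShare = batchSize / taskIds.Count;
        int extra = batchSize % taskIds.Count; // first tasks get one more

        for (int t = 0; t < taskIds.Count; t++)
        {
            // Copy so the stored memory is never reordered.
            var pool = new List<TItem>(byTask[taskIds[t]]);
            int share = Math.Min(baseShare + (t < extra ? 1 : 0), pool.Count);

            // Partial Fisher-Yates: the first `share` slots become a
            // uniform sample without replacement.
            for (int i = 0; i < share; i++)
            {
                int j = rng.Next(i, pool.Count); // uniform in [i, pool.Count)
                TItem tmp = pool[i];
                pool[i] = pool[j];
                pool[j] = tmp;
                result.Add(pool[i]);
            }
        }
        return result;
    }
}
```

Balancing by task rather than sampling uniformly over all stored examples keeps a task with many stored examples from crowding out the others in each replay batch.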
SampleFromTask(int, int)
Samples examples from a specific task.
public List<DataPoint<T, TInput, TOutput>> SampleFromTask(int taskId, int count)
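Parameters
taskId (int): The task identifier.
count (int): Number of examples to sample.
Returns
- List<DataPoint<T, TInput, TOutput>>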
UpdatePriority(int, double)
Updates the priority of a sample (for priority-based replay).
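public void UpdatePriority(int index, double priority)
Parameters
index (int): Index of the sample whose priority is updated.
priority (double): The new priority value.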
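This page does not specify how priorities are used during priority-based replay. A common scheme, assumed here, is to draw sample i with probability priority[i] / sum(priorities), with replacement, as in prioritized experience replay. `PrioritySampler` is a hypothetical class; only its `UpdatePriority(int, double)` shape mirrors the method above.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical priority-proportional sampler, not the AiDotNet implementation.
public sealed class PrioritySampler
{
    private readonly List<double> _priorities = new List<double>();
    private readonly Random _rng;

    public PrioritySampler(int? seed = null) =>
        _rng = seed.HasValue ? new Random(seed.Value) : new Random();

    // Registers a sample and returns its index.
    public int Add(double priority)
    {
        if (priority < 0) throw new ArgumentOutOfRangeException(nameof(priority));
        _priorities.Add(priority);
        return _priorities.Count - 1;
    }

    public void UpdatePriority(int index, double priority)
    {
        if (priority < 0) throw new ArgumentOutOfRangeException(nameof(priority));
        _priorities[index] = priority; // List indexer throws on a bad index
    }

    // Roulette-wheel selection over cumulative priorities.
    public int SampleIndex()
    {
        double total = 0;
        foreach (var p in _priorities) total += p;
        if (total <= 0) throw new InvalidOperationException("No sample has a positive priority.");

        double r = _rng.NextDouble() * total;
        double cumulative = 0;
        for (int i = 0; i < _priorities.Count; i++)
        {
            cumulative += _priorities[i];
            if (r < cumulative) return i;
        }
        // Floating-point round-off guard: return the last positive-priority sample.
        for (int i = _priorities.Count - 1; i >= 0; i--)
            if (_priorities[i] > 0) return i;
        throw new InvalidOperationException("Unreachable: total was positive.");
    }
}
```

Raising a sample's priority, for instance after a high training loss on it, makes it proportionally more likely to be replayed; a priority of zero removes it from replay without deleting it.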