
Class ExperienceReplayBuffer<T, TInput, TOutput>

Namespace
AiDotNet.ContinualLearning.Memory
Assembly
AiDotNet.dll

A memory buffer for storing examples from previous tasks for experience replay.

public class ExperienceReplayBuffer<T, TInput, TOutput>

Type Parameters

T

The numeric type used for calculations.

TInput

The input data type.

TOutput

The output data type.

Inheritance
object
ExperienceReplayBuffer<T, TInput, TOutput>

Remarks

For Beginners: Experience replay stores a small number of examples from previous tasks and mixes them into the new task's training data. This helps reduce catastrophic forgetting by reminding the model of what it learned before.

Why It Works: When a neural network learns a new task, it adjusts its weights to minimize error on that task. Without replay, nothing in the training objective stops those adjustments from increasing error on previous tasks. When old examples are mixed into the new training data, the loss includes errors on both old and new tasks, so the network is pushed to maintain performance on both.
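The mixing step can be sketched in a few lines of C#. This is a minimal illustration under stated assumptions, not AiDotNet's training loop: ReplayMixing, MixBatch, and replayCount are hypothetical names, and drawing replayed items uniformly with replacement is an assumption. With the real buffer, the replayed items would more likely come from SampleBatch than from the uniform draw used here.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper (not an AiDotNet API): combines the current task's
// minibatch with examples replayed from memory, then shuffles the result.
public static class ReplayMixing
{
    public static List<TItem> MixBatch<TItem>(
        IReadOnlyList<TItem> newTaskBatch,
        IReadOnlyList<TItem> memory,
        int replayCount,
        Random rng)
    {
        var batch = new List<TItem>(newTaskBatch);

        // Draw replayed examples uniformly at random (with replacement).
        // An empty memory (e.g. while training the first task) adds nothing.
        if (memory.Count > 0)
        {
            for (int i = 0; i < replayCount; i++)
                batch.Add(memory[rng.Next(memory.Count)]);
        }

        // Fisher-Yates shuffle so old and new examples are interleaved.
        for (int i = batch.Count - 1; i > 0; i--)
        {
            int j = rng.Next(i + 1);
            (batch[i], batch[j]) = (batch[j], batch[i]);
        }
        return batch;
    }
}
```

Because the loss is computed over the whole mixed batch, each gradient step sees errors from both the old and the new tasks.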

Memory Management: Because storing every example is impractical, the buffer keeps a fixed-size subset (at most maxSize examples) chosen by a sampling strategy. Different strategies suit different scenarios:

  • Reservoir: uniform random selection, so every candidate example has an equal chance of being kept; a good general-purpose default
  • ClassBalanced: keeps roughly equal numbers of examples per class; useful when classes are imbalanced
  • Herding: picks exemplars whose mean best approximates the class mean in feature space (from the iCaRL paper)
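Reservoir sampling is easiest to see in code. The sketch below is the classic reservoir algorithm (often called Algorithm R) in C#; ReservoirSampler, Offer, and Seen are hypothetical names, and AiDotNet's Reservoir strategy may differ in its details. Once n items have been offered (n at least the capacity), each of them is in the buffer with probability capacity / n.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical sketch of reservoir sampling (Algorithm R); not AiDotNet code.
public sealed class ReservoirSampler<TItem>
{
    private readonly List<TItem> _items;
    private readonly int _capacity;
    private readonly Random _rng;

    public ReservoirSampler(int capacity, int seed)
    {
        if (capacity <= 0) throw new ArgumentOutOfRangeException(nameof(capacity));
        _capacity = capacity;
        _items = new List<TItem>(capacity);
        _rng = new Random(seed);
    }

    public IReadOnlyList<TItem> Items => _items;

    // Total number of items offered so far (kept or not).
    public int Seen { get; private set; }

    public void Offer(TItem item)
    {
        Seen++;
        if (_items.Count < _capacity)
        {
            // Fill phase: keep everything until the buffer is full.
            _items.Add(item);
            return;
        }

        // Replacement phase: pick j uniformly from [0, Seen). The new item is
        // kept only if j lands inside the buffer, i.e. with probability
        // capacity / Seen, and then it evicts the item in slot j.
        int j = _rng.Next(Seen);
        if (j < _capacity)
            _items[j] = item;
    }
}
```

The appeal for continual learning is that the sampler never needs to know in advance how many examples will arrive.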

References:

  • Chaudhry et al. "Continual Learning with Tiny Episodic Memories" (2019)
  • Rebuffi et al. "iCaRL: Incremental Classifier and Representation Learning" (2017)
  • Rolnick et al. "Experience Replay for Continual Learning" (2019)

Constructors

ExperienceReplayBuffer(int, MemorySamplingStrategy, ReplaySamplingStrategy, int?)

Initializes a new experience replay buffer.

public ExperienceReplayBuffer(int maxSize, MemorySamplingStrategy addStrategy = MemorySamplingStrategy.Reservoir, ReplaySamplingStrategy replayStrategy = ReplaySamplingStrategy.TaskBalanced, int? seed = null)

Parameters

maxSize int

Maximum number of examples to store.

addStrategy MemorySamplingStrategy

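Strategy for selecting which examples to add. Defaults to MemorySamplingStrategy.Reservoir.

replayStrategy ReplaySamplingStrategy

Strategy for sampling during replay. Defaults to ReplaySamplingStrategy.TaskBalanced.

seed int?

Optional random seed for reproducible sampling. Defaults to null.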
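To make the Herding option of addStrategy concrete, here is a minimal sketch of iCaRL-style herding over precomputed feature vectors. It is hypothetical (HerdingSelector is not an AiDotNet type) and simplified: iCaRL also L2-normalizes features, which is omitted here, and already-chosen exemplars are excluded, a common implementation choice. Feature values are assumed to be finite.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical sketch of iCaRL-style herding; not AiDotNet code.
public static class HerdingSelector
{
    // Returns the indices of up to m exemplars, in selection order, chosen so that
    // the running mean of the selected feature vectors tracks the class mean.
    public static List<int> Select(IReadOnlyList<double[]> features, int m)
    {
        var chosen = new List<int>();
        int n = features.Count;
        if (n == 0 || m <= 0) return chosen;
        m = Math.Min(m, n);
        int d = features[0].Length;

        // Class mean in feature space.
        var mean = new double[d];
        foreach (var f in features)
            for (int i = 0; i < d; i++)
                mean[i] += f[i] / n;

        var used = new bool[n];
        var runningSum = new double[d]; // sum of already-selected feature vectors

        for (int k = 1; k <= m; k++)
        {
            int best = -1;
            double bestDist = double.PositiveInfinity;
            for (int c = 0; c < n; c++)
            {
                if (used[c]) continue;
                // Squared distance between the class mean and the mean we would
                // get by adding candidate c to the k-1 exemplars already chosen.
                double dist = 0;
                for (int i = 0; i < d; i++)
                {
                    double diff = mean[i] - (runningSum[i] + features[c][i]) / k;
                    dist += diff * diff;
                }
                if (dist < bestDist) { bestDist = dist; best = c; }
            }

            used[best] = true;
            chosen.Add(best);
            for (int i = 0; i < d; i++)
                runningSum[i] += features[best][i];
        }
        return chosen;
    }
}
```

Because exemplars are chosen greedily in order, keeping only the first m' < m of them still gives a good approximation of the class mean, which is convenient when the buffer has to shrink.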

Properties

AddStrategy

Gets the sampling strategy used when adding examples.

public MemorySamplingStrategy AddStrategy { get; }

Property Value

MemorySamplingStrategy

Count

Gets the current number of stored examples.

public int Count { get; }

Property Value

int

EstimatedMemoryBytes

Gets the estimated memory usage in bytes.

public long EstimatedMemoryBytes { get; }

Property Value

long

IsFull

Gets whether the buffer is at capacity.

public bool IsFull { get; }

Property Value

bool

MaxSize

Gets the maximum capacity of the buffer.

public int MaxSize { get; }

Property Value

int

ReplayStrategy

Gets the sampling strategy used during replay.

public ReplaySamplingStrategy ReplayStrategy { get; }

Property Value

ReplaySamplingStrategy

TaskCount

Gets the number of distinct tasks represented in the buffer.

public int TaskCount { get; }

Property Value

int

TotalReplaySamples

Gets the total number of samples returned via replay.

public int TotalReplaySamples { get; }

Property Value

int

TotalSamplesProcessed

Gets the total number of samples processed (added) since creation.

public int TotalSamplesProcessed { get; }

Property Value

int

Methods

AddTaskExamples(IDataset<T, TInput, TOutput>, int, int?)

Adds examples from a task to the buffer using the configured sampling strategy.

public void AddTaskExamples(IDataset<T, TInput, TOutput> taskData, int taskId, int? samplesPerTask = null)

Parameters

taskData IDataset<T, TInput, TOutput>

The task data to sample from.

taskId int

The task identifier.

samplesPerTask int?

Number of samples to store from this task. If null, the buffer's capacity is divided evenly across tasks.
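One plausible reading of the null case is an equal per-task share of maxSize. The sketch below assumes that reading; CapacitySplitter and EvenQuotas are hypothetical names, and giving the leftover slots to the lowest-numbered tasks is an arbitrary choice.

```csharp
using System;

// Hypothetical helper; not an AiDotNet API.
public static class CapacitySplitter
{
    // Splits maxSize slots across taskCount tasks as evenly as possible.
    // The first (maxSize % taskCount) tasks receive one extra slot.
    public static int[] EvenQuotas(int maxSize, int taskCount)
    {
        if (maxSize < 0) throw new ArgumentOutOfRangeException(nameof(maxSize));
        if (taskCount <= 0) throw new ArgumentOutOfRangeException(nameof(taskCount));

        int baseQuota = maxSize / taskCount;
        int remainder = maxSize % taskCount;
        var quotas = new int[taskCount];
        for (int t = 0; t < taskCount; t++)
            quotas[t] = baseQuota + (t < remainder ? 1 : 0);
        return quotas;
    }
}
```

Under this assumption, a buffer with maxSize 100 holding three tasks would give quotas of 34, 33, and 33; when a fourth task arrives, each task's share would drop to 25.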

Clear()

Clears all stored examples.

public void Clear()

GetAll()

Gets all stored examples.

public IReadOnlyList<DataPoint<T, TInput, TOutput>> GetAll()

Returns

IReadOnlyList<DataPoint<T, TInput, TOutput>>

GetStatistics()

Gets statistics about the buffer.

public BufferStatistics GetStatistics()

Returns

BufferStatistics

GetTaskCounts()

Gets the count of examples per task.

public IReadOnlyDictionary<int, int> GetTaskCounts()

Returns

IReadOnlyDictionary<int, int>

GetTaskExamples(int)

Gets all examples for a specific task.

public IReadOnlyList<DataPoint<T, TInput, TOutput>> GetTaskExamples(int taskId)

Parameters

taskId int

The task identifier.

Returns

IReadOnlyList<DataPoint<T, TInput, TOutput>>

Read-only list of examples for the task.

RemoveTask(int)

Removes all examples from a specific task.

public void RemoveTask(int taskId)

Parameters

taskId int

The task to remove.

SampleBatch(int)

Samples a batch of examples from the buffer using the configured replay strategy.

public List<DataPoint<T, TInput, TOutput>> SampleBatch(int batchSize)

Parameters

batchSize int

Number of examples to sample.

Returns

List<DataPoint<T, TInput, TOutput>>

A list of sampled data points.
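The name of the default replay strategy, ReplaySamplingStrategy.TaskBalanced, suggests that every stored task contributes a roughly equal share of each batch. A minimal sketch under that assumption follows; TaskBalancedSampler is a hypothetical name, and sampling with replacement inside each task is an assumption that lets small tasks still fill their share.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical sketch of task-balanced replay sampling; not AiDotNet code.
public static class TaskBalancedSampler
{
    public static List<TItem> Sample<TItem>(
        IReadOnlyDictionary<int, List<TItem>> examplesByTask, int batchSize, Random rng)
    {
        var result = new List<TItem>(Math.Max(batchSize, 0));
        // Only tasks that still have stored examples take part.
        var taskIds = examplesByTask.Where(kv => kv.Value.Count > 0)
                                    .Select(kv => kv.Key)
                                    .ToList();
        if (taskIds.Count == 0 || batchSize <= 0) return result;

        // Shuffle task order so the leftover slots (batchSize % taskCount) do not
        // always go to the same tasks.
        for (int i = taskIds.Count - 1; i > 0; i--)
        {
            int j = rng.Next(i + 1);
            (taskIds[i], taskIds[j]) = (taskIds[j], taskIds[i]);
        }

        int baseQuota = batchSize / taskIds.Count;
        int remainder = batchSize % taskIds.Count;
        for (int t = 0; t < taskIds.Count; t++)
        {
            var pool = examplesByTask[taskIds[t]];
            int quota = baseQuota + (t < remainder ? 1 : 0);
            // Uniform sampling with replacement within the task.
            for (int i = 0; i < quota; i++)
                result.Add(pool[rng.Next(pool.Count)]);
        }
        return result;
    }
}
```

Balancing by task rather than by example keeps a task that happens to hold many stored examples from dominating the replayed gradient.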

SampleFromTask(int, int)

Samples examples from a specific task.

public List<DataPoint<T, TInput, TOutput>> SampleFromTask(int taskId, int count)

Parameters

taskId int

The task to sample from.

count int

Number of examples to sample.

Returns

List<DataPoint<T, TInput, TOutput>>

List of data points from the specified task.
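The documentation does not say whether SampleFromTask samples with or without replacement. If it is without replacement, a partial Fisher-Yates shuffle is the standard technique; the sketch below assumes that, and also assumes that requesting more items than the task holds simply returns all of them. WithoutReplacementSampler is a hypothetical name.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical sketch; not AiDotNet code.
public static class WithoutReplacementSampler
{
    // Returns min(count, pool.Count) items drawn uniformly at random without
    // replacement. The input list is not modified.
    public static List<TItem> Sample<TItem>(IReadOnlyList<TItem> pool, int count, Random rng)
    {
        var copy = new List<TItem>(pool);
        int k = Math.Min(Math.Max(count, 0), copy.Count);

        // Partial Fisher-Yates: after iteration i, copy[0..i] is a uniform
        // random sample (without replacement) of the original pool.
        for (int i = 0; i < k; i++)
        {
            int j = rng.Next(i, copy.Count); // j in [i, copy.Count)
            (copy[i], copy[j]) = (copy[j], copy[i]);
        }
        return copy.GetRange(0, k);
    }
}
```

Only k swaps are performed, so the cost is proportional to the number of items requested plus the copy of the pool.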

UpdatePriority(int, double)

Updates the priority of a sample (for priority-based replay).

public void UpdatePriority(int index, double priority)

Parameters

index int

The buffer index of the sample.

priority double

The new priority value (higher = more likely to be sampled).
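A common way to honor "higher = more likely to be sampled" is to draw each index with probability proportional to its priority. The sketch below assumes that interpretation; PrioritySampler and SampleIndex are hypothetical names, and although UpdatePriority mirrors the documented (int, double) shape, AiDotNet's own weighting may differ (some prioritized-replay variants, for example, raise priorities to an exponent first).

```csharp
using System;

// Hypothetical sketch of priority-proportional sampling; not AiDotNet code.
public sealed class PrioritySampler
{
    private readonly double[] _priorities;
    private readonly Random _rng;

    public PrioritySampler(int size, int seed)
    {
        if (size <= 0) throw new ArgumentOutOfRangeException(nameof(size));
        _priorities = new double[size];
        Array.Fill(_priorities, 1.0); // every slot starts equally likely
        _rng = new Random(seed);
    }

    // Mirrors the documented UpdatePriority(int, double) shape.
    public void UpdatePriority(int index, double priority)
    {
        if (index < 0 || index >= _priorities.Length)
            throw new ArgumentOutOfRangeException(nameof(index));
        // Rejects negative, NaN, and infinite priorities.
        if (!(priority >= 0) || double.IsInfinity(priority))
            throw new ArgumentOutOfRangeException(nameof(priority));
        _priorities[index] = priority;
    }

    // Draws index i with probability priority[i] / sum(priorities).
    public int SampleIndex()
    {
        double total = 0;
        foreach (double p in _priorities) total += p;
        if (total <= 0) return _rng.Next(_priorities.Length); // all zero: fall back to uniform

        double r = _rng.NextDouble() * total; // r in [0, total)
        double cumulative = 0;
        int lastPositive = -1;
        for (int i = 0; i < _priorities.Length; i++)
        {
            if (_priorities[i] <= 0) continue;
            cumulative += _priorities[i];
            lastPositive = i;
            if (r < cumulative) return i;
        }
        // Only reachable through floating-point rounding in the running sum.
        return lastPositive;
    }
}
```

The linear scan costs O(n) per draw; prioritized experience replay implementations typically use a sum tree to bring this down to O(log n).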