Interface IContinualLearningStrategy<T, TInput, TOutput>

Namespace
AiDotNet.ContinualLearning.Interfaces
Assembly
AiDotNet.dll

Strategy interface for continual learning algorithms.

public interface IContinualLearningStrategy<T, TInput, TOutput>

Type Parameters

T

The numeric type used for calculations.

TInput

The input data type.

TOutput

The output data type.

Remarks

For Beginners: Different continual learning methods use different strategies to prevent forgetting. This interface allows the trainer to work with any strategy.

Strategy Types:

  • Regularization-based: Add penalty terms to protect important weights (EWC, SI, MAS)
  • Replay-based: Store and replay old examples (Experience Replay, GEM)
  • Architecture-based: Use separate parameters for different tasks (Progressive Networks, PackNet)
  • Distillation-based: Use teacher model to preserve old knowledge (LwF)

Reference: De Lange et al., "A Continual Learning Survey: Defying Forgetting in Classification Tasks" (IEEE TPAMI, 2021)
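Whatever the strategy type, a trainer drives it through the same lifecycle: PrepareForTask before training on a task, ComputeRegularizationLoss and AdjustGradients inside the training loop, and FinalizeTask afterwards. The sketch below shows that loop in Python rather than C# so it is self-contained; NoOpStrategy and the model's backward_batches/apply_gradients methods are hypothetical stand-ins, not part of AiDotNet.

```python
class NoOpStrategy:
    """Hypothetical do-nothing strategy mirroring the interface's lifecycle."""

    name = "NoOp"

    def prepare_for_task(self, model, task_data):
        pass  # e.g. EWC would snapshot current parameters here

    def compute_regularization_loss(self, model):
        return 0.0  # added to the task loss each step

    def adjust_gradients(self, gradients):
        return gradients  # e.g. GEM would project gradients here

    def finalize_task(self, model):
        pass  # e.g. EWC would compute Fisher Information here


def train_task(model, task_data, strategy, epochs=1):
    """Drive any strategy through the continual-learning lifecycle."""
    strategy.prepare_for_task(model, task_data)
    for _ in range(epochs):
        for grads, task_loss in model.backward_batches(task_data):
            total_loss = task_loss + strategy.compute_regularization_loss(model)
            model.apply_gradients(strategy.adjust_gradients(grads))
    strategy.finalize_task(model)
```

Because every strategy exposes the same hooks, the trainer never needs to know whether forgetting is prevented by regularization, replay, or projection.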

Properties

MemoryUsageBytes

Gets the current memory usage of the strategy in bytes.

long MemoryUsageBytes { get; }

Property Value

long

ModifiesArchitecture

Gets whether this strategy modifies the model architecture.

bool ModifiesArchitecture { get; }

Property Value

bool

Name

Gets the name of the strategy.

string Name { get; }

Property Value

string

RequiresMemoryBuffer

Gets whether this strategy requires storing examples from previous tasks.

bool RequiresMemoryBuffer { get; }

Property Value

bool

Methods

AdjustGradients(Vector<T>)

Adjusts gradients to prevent forgetting.

Vector<T> AdjustGradients(Vector<T> gradients)

Parameters

gradients Vector<T>

The gradients from the current task loss.

Returns

Vector<T>

Adjusted gradients that respect previous task constraints.

Remarks

This modifies gradients before they're applied:

  • GEM: Project gradients so they do not increase loss on previous tasks
  • A-GEM: Project onto the average gradient of reference samples
  • OWM: Project gradients into an orthogonal subspace
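For A-GEM, the adjustment is a single projection: when the current gradient g conflicts with the reference gradient g_ref (negative dot product), subtract the conflicting component, giving g' = g - (g·g_ref / g_ref·g_ref) g_ref. A minimal, library-agnostic Python sketch of that rule (the interface's Vector<T> is replaced by plain lists for illustration):

```python
def agem_adjust_gradients(g, g_ref):
    """A-GEM-style projection: if g would increase loss on the reference
    samples (g . g_ref < 0), remove the component of g along g_ref so the
    adjusted gradient no longer interferes with previous tasks."""
    dot = sum(a * b for a, b in zip(g, g_ref))
    if dot >= 0:
        return list(g)  # no conflict; gradients pass through unchanged
    ref_sq = sum(b * b for b in g_ref)
    return [a - (dot / ref_sq) * b for a, b in zip(g, g_ref)]
```

After projection the adjusted gradient has a non-negative dot product with g_ref, which is exactly the "respect previous task constraints" guarantee described above.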

ComputeRegularizationLoss(IFullModel<T, TInput, TOutput>)

Computes the regularization loss to prevent forgetting.

T ComputeRegularizationLoss(IFullModel<T, TInput, TOutput> model)

Parameters

model IFullModel<T, TInput, TOutput>

The current model.

Returns

T

The regularization loss value.

Remarks

This is added to the task loss during training:

  • EWC: Sum of (current params - optimal params)^2 * Fisher Information
  • SI: Sum of (current params - optimal params)^2 * path integral importance
  • MAS: Sum of (current params - optimal params)^2 * gradient magnitude
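All three regularizers share the same quadratic form, (lambda / 2) * sum_i w_i * (theta_i - theta*_i)^2, and differ only in how the per-parameter importance w_i is estimated. A minimal Python sketch of the shared form (parameter vectors are plain lists here; lambda is a hypothetical strength knob, not part of the interface):

```python
def quadratic_penalty(params, optimal_params, importance, lam=1.0):
    """Shared EWC/SI/MAS regularizer: (lam / 2) * sum_i w_i * (p_i - o_i)^2.
    'importance' is Fisher Information for EWC, path-integral importance
    for SI, and gradient magnitude for MAS."""
    return 0.5 * lam * sum(
        w * (p - o) ** 2
        for p, o, w in zip(params, optimal_params, importance)
    )
```

Parameters with high importance are pulled strongly back toward their old optima, while unimportant parameters remain free to adapt to the new task.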

FinalizeTask(IFullModel<T, TInput, TOutput>)

Finalizes the task after training is complete.

void FinalizeTask(IFullModel<T, TInput, TOutput> model)

Parameters

model IFullModel<T, TInput, TOutput>

The trained model.

Remarks

This is called after training on a task completes. Strategies use this to:

  • EWC: Compute the Fisher Information Matrix
  • SI: Compute path integral importance
  • PackNet: Prune and freeze weights
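For EWC, the Fisher Information computed here is usually the diagonal empirical estimate: the mean of squared per-sample log-likelihood gradients over the finished task's data. A minimal Python sketch of that estimate (the per-sample gradients are assumed to have been collected during a pass over the task data):

```python
def diagonal_fisher(grad_log_likelihoods):
    """Empirical diagonal Fisher Information: for each parameter i,
    F_i = mean over samples of (d log p / d theta_i)^2. The result is
    the per-parameter importance EWC uses in its quadratic penalty."""
    n = len(grad_log_likelihoods)
    dim = len(grad_log_likelihoods[0])
    fisher = [0.0] * dim
    for g in grad_log_likelihoods:
        for i, gi in enumerate(g):
            fisher[i] += gi * gi / n
    return fisher
```

Storing only this diagonal (one value per parameter) is what keeps EWC's MemoryUsageBytes linear in model size rather than quadratic.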

GetIncompatibilityReason(IFullModel<T, TInput, TOutput>)

Gets a description of why the strategy is incompatible with a model.

string? GetIncompatibilityReason(IFullModel<T, TInput, TOutput> model)

Parameters

model IFullModel<T, TInput, TOutput>

The model to check.

Returns

string?

Description of incompatibility, or null if compatible.

GetMetrics()

Gets strategy-specific metrics for monitoring.

IReadOnlyDictionary<string, object> GetMetrics()

Returns

IReadOnlyDictionary<string, object>

Dictionary of metric name to value.

IsCompatibleWith(IFullModel<T, TInput, TOutput>)

Validates that the strategy is compatible with the given model.

bool IsCompatibleWith(IFullModel<T, TInput, TOutput> model)

Parameters

model IFullModel<T, TInput, TOutput>

The model to validate against.

Returns

bool

True if compatible, false otherwise.

Load(string)

Loads the strategy state from a file.

void Load(string path)

Parameters

path string

Path to load the state from.

PrepareForTask(IFullModel<T, TInput, TOutput>, IDataset<T, TInput, TOutput>)

Prepares the strategy for learning a new task.

void PrepareForTask(IFullModel<T, TInput, TOutput> model, IDataset<T, TInput, TOutput> taskData)

Parameters

model IFullModel<T, TInput, TOutput>

The model being trained.

taskData IDataset<T, TInput, TOutput>

The data for the new task.

Remarks

This is called at the start of training on a new task. Strategies use this to:

  • EWC: Store current parameters as the optimal parameters for previous tasks
  • LwF: Create a copy of the model to serve as the teacher
  • GEM: Compute reference gradients from stored examples

Reset()

Resets the strategy to its initial state.

void Reset()

Save(string)

Saves the strategy state to a file.

void Save(string path)

Parameters

path string

Path to save the state.