Interface IContinualLearner<T, TInput, TOutput>

Namespace
AiDotNet.ContinualLearning.Interfaces
Assembly
AiDotNet.dll

Interface for continual learning trainers that can learn multiple tasks sequentially.

public interface IContinualLearner<T, TInput, TOutput>

Type Parameters

T

The numeric type used for calculations.

TInput

The input data type.

TOutput

The output data type.

Remarks

For Beginners: Continual learning (also called lifelong learning) is the ability to learn new tasks over time without forgetting previously learned knowledge. Traditional neural networks suffer from "catastrophic forgetting": when trained on new data, they tend to lose what they learned from earlier data.

This interface provides methods to:

  • Learn new tasks while protecting old knowledge
  • Evaluate performance on all learned tasks
  • Save and load model state for resuming training
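A minimal sketch of that workflow, written only against the members documented on this page. The helper name `TrainSequentially` and the `checkpointPath` parameter are hypothetical. The namespaces of `IDataset<T, TInput, TOutput>` and `ContinualEvaluationResult<T>` are not shown here, so the matching `using` directives are left for you to add.

```csharp
using System.Collections.Generic;
using AiDotNet.ContinualLearning.Interfaces;
// Assumption: IDataset<,,> and ContinualEvaluationResult<> live in other
// AiDotNet namespaces that this page does not list; add those usings as needed.

public static class ContinualLearningWorkflow
{
    // Hypothetical helper: learns each task in order, evaluates all tasks
    // learned so far after each one, and saves a checkpoint for resuming.
    public static List<ContinualEvaluationResult<T>> TrainSequentially<T, TInput, TOutput>(
        IContinualLearner<T, TInput, TOutput> learner,
        IEnumerable<IDataset<T, TInput, TOutput>> tasks,
        string checkpointPath)
    {
        var evaluations = new List<ContinualEvaluationResult<T>>();
        foreach (var taskData in tasks)
        {
            learner.LearnTask(taskData);                  // learn the new task
            evaluations.Add(learner.EvaluateAllTasks());  // check old and new tasks
            learner.Save(checkpointPath);                 // persist state after each task
        }
        return evaluations;
    }
}
```

Evaluating after every task, rather than once at the end, is what makes forgetting visible as it happens.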

Common Implementations:

  • EWCTrainer: Uses Elastic Weight Consolidation to protect important weights
  • LwFTrainer: Uses Learning without Forgetting with knowledge distillation
  • GEMTrainer: Uses Gradient Episodic Memory to constrain gradients

References:

  • Parisi et al. "Continual Lifelong Learning with Neural Networks: A Review" (2019)
  • De Lange et al. "A Continual Learning Survey: Defying Forgetting in Classification Tasks" (2021)

Properties

BaseModel

Gets the underlying model being trained.

IFullModel<T, TInput, TOutput> BaseModel { get; }

Property Value

IFullModel<T, TInput, TOutput>

Config

Gets the configuration for the continual learner.

IContinualLearnerConfig<T> Config { get; }

Property Value

IContinualLearnerConfig<T>

IsTraining

Gets whether the learner is currently training.

bool IsTraining { get; }

Property Value

bool

MemoryUsageBytes

Gets the current memory usage of the learner in bytes.

long MemoryUsageBytes { get; }

Property Value

long

TasksLearned

Gets the number of tasks that have been learned.

int TasksLearned { get; }

Property Value

int

Methods

ComputeForgetting()

Computes the current forgetting metric for all tasks.

IReadOnlyDictionary<int, T> ComputeForgetting()

Returns

IReadOnlyDictionary<int, T>

Dictionary mapping task ID to forgetting amount.
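One way to consume the returned dictionary is to flag the tasks that have degraded the most. The sketch below assumes `T` is `double` and that a larger value means more forgetting; neither is stated on this page. `ForgettingReport` and `TasksAboveThreshold` are hypothetical names.

```csharp
using System.Collections.Generic;

public static class ForgettingReport
{
    // Returns the IDs of tasks whose forgetting exceeds the threshold,
    // ordered from most forgotten to least forgotten.
    // Assumption: larger values mean more forgetting.
    public static List<int> TasksAboveThreshold(
        IReadOnlyDictionary<int, double> forgetting, double threshold)
    {
        var flagged = new List<int>();
        foreach (var pair in forgetting)
        {
            if (pair.Value > threshold)
            {
                flagged.Add(pair.Key);
            }
        }
        // Sort descending by forgetting amount.
        flagged.Sort((a, b) => forgetting[b].CompareTo(forgetting[a]));
        return flagged;
    }
}
```

For a learner with `T` set to `double`, this could be called as `ForgettingReport.TasksAboveThreshold(learner.ComputeForgetting(), 0.05)`, where `0.05` is an arbitrary illustrative threshold.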

EvaluateAllTasks()

Evaluates the model on all learned tasks.

ContinualEvaluationResult<T> EvaluateAllTasks()

Returns

ContinualEvaluationResult<T>

Evaluation result across all learned tasks, including backward and forward transfer.

EvaluateTask(int, IDataset<T, TInput, TOutput>)

Evaluates the model on a specific task.

TaskEvaluationResult<T> EvaluateTask(int taskId, IDataset<T, TInput, TOutput> testData)

Parameters

taskId int

The task identifier (0-indexed).

testData IDataset<T, TInput, TOutput>

The test data for the task.

Returns

TaskEvaluationResult<T>

Evaluation result for the specified task.

GetAllHistory()

Gets all training history.

IReadOnlyList<ContinualLearningResult<T>> GetAllHistory()

Returns

IReadOnlyList<ContinualLearningResult<T>>

List of all training results.

GetTaskHistory(int)

Gets the training history for a specific task.

ContinualLearningResult<T>? GetTaskHistory(int taskId)

Parameters

taskId int

The task identifier.

Returns

ContinualLearningResult<T>

The training result for the task, or null if not available.
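Because the return value can be null, callers should check it before use. The sketch below lists the tasks that have no recorded history. It assumes task IDs run from 0 to `TasksLearned - 1`, which this page suggests (IDs are described as 0-indexed) but does not guarantee. `TasksWithoutHistory` is a hypothetical helper name.

```csharp
using System.Collections.Generic;
using AiDotNet.ContinualLearning.Interfaces;

public static class HistoryChecks
{
    // Returns the IDs of learned tasks for which GetTaskHistory yields null.
    // Assumption: task IDs are the contiguous range 0 .. TasksLearned - 1.
    public static List<int> TasksWithoutHistory<T, TInput, TOutput>(
        IContinualLearner<T, TInput, TOutput> learner)
    {
        var missing = new List<int>();
        for (int taskId = 0; taskId < learner.TasksLearned; taskId++)
        {
            if (learner.GetTaskHistory(taskId) is null)
            {
                missing.Add(taskId);
            }
        }
        return missing;
    }
}
```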

LearnTask(IDataset<T, TInput, TOutput>)

Learns a new task from the provided data.

ContinualLearningResult<T> LearnTask(IDataset<T, TInput, TOutput> taskData)

Parameters

taskData IDataset<T, TInput, TOutput>

The dataset for the new task.

Returns

ContinualLearningResult<T>

Result containing training metrics and performance information.

Remarks

This method trains the model on the new task while using the configured strategy to prevent forgetting of previously learned tasks.

LearnTask(IDataset<T, TInput, TOutput>, IDataset<T, TInput, TOutput>, int?)

Learns a new task with a validation set for early stopping and monitoring.

ContinualLearningResult<T> LearnTask(IDataset<T, TInput, TOutput> taskData, IDataset<T, TInput, TOutput> validationData, int? earlyStoppingPatience = null)

Parameters

taskData IDataset<T, TInput, TOutput>

The training dataset for the new task.

validationData IDataset<T, TInput, TOutput>

The validation dataset for the new task.

earlyStoppingPatience int?

Number of epochs without improvement before stopping.

Returns

ContinualLearningResult<T>

Result containing training metrics and performance information.
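To illustrate what a patience value means, the sketch below replays a list of per-epoch validation losses and reports how many epochs would run before stopping. It shows the general early-stopping idea, not this library's implementation. It assumes that lower loss is better and that only a strict decrease counts as improvement. `EarlyStoppingSketch` and `EpochsRun` are hypothetical names.

```csharp
using System;
using System.Collections.Generic;

public static class EarlyStoppingSketch
{
    // Returns how many epochs run before training stops, given the validation
    // loss observed at each epoch and a patience value.
    // Assumption: lower loss is better; ties do not count as improvement.
    public static int EpochsRun(IReadOnlyList<double> validationLosses, int patience)
    {
        if (patience <= 0)
        {
            throw new ArgumentOutOfRangeException(nameof(patience));
        }

        double best = double.PositiveInfinity;
        int epochsSinceImprovement = 0;

        for (int epoch = 0; epoch < validationLosses.Count; epoch++)
        {
            if (validationLosses[epoch] < best)
            {
                best = validationLosses[epoch];
                epochsSinceImprovement = 0;   // improvement resets the counter
            }
            else
            {
                epochsSinceImprovement++;
                if (epochsSinceImprovement >= patience)
                {
                    return epoch + 1;         // stop after this epoch
                }
            }
        }
        return validationLosses.Count;        // never triggered; all epochs ran
    }
}
```

With the overload above, the patience value itself is passed as a named argument, for example `learner.LearnTask(trainData, validationData, earlyStoppingPatience: 5)`, where `trainData` and `validationData` are placeholder variables.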

Load(string)

Loads the learner state from a file.

void Load(string path)

Parameters

path string

Path to load the state from.

Reset()

Resets the learner to its initial state.

void Reset()

Remarks

This clears all learned knowledge, including:

  • Model parameters (reset to initial values)
  • Strategy state
  • Memory buffer
  • Task history

Save(string)

Saves the learner state to a file.

void Save(string path)

Parameters

path string

Path to save the state.

Remarks

This saves all state needed to resume training, including:

  • Model parameters
  • Strategy state (e.g., Fisher Information for EWC)
  • Memory buffer contents
  • Training history
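A small checkpointing sketch built on `Save(string)` and `Load(string)`. Whether `Save` creates missing directories is not documented here, so the helper creates the target directory first as a precaution. `LearnerCheckpoint`, `SaveSafely`, and `TryResume` are hypothetical names.

```csharp
using System.IO;
using AiDotNet.ContinualLearning.Interfaces;

public static class LearnerCheckpoint
{
    // Ensures the target directory exists, then saves the learner state.
    public static void SaveSafely<T, TInput, TOutput>(
        IContinualLearner<T, TInput, TOutput> learner, string path)
    {
        string? directory = Path.GetDirectoryName(Path.GetFullPath(path));
        if (!string.IsNullOrEmpty(directory))
        {
            Directory.CreateDirectory(directory); // no-op if it already exists
        }
        learner.Save(path);
    }

    // Loads a previous checkpoint if one exists; returns false otherwise.
    public static bool TryResume<T, TInput, TOutput>(
        IContinualLearner<T, TInput, TOutput> learner, string path)
    {
        if (!File.Exists(path))
        {
            return false;
        }
        learner.Load(path);
        return true;
    }
}
```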

Events

EpochCompleted

Event raised when an epoch completes during training.

event EventHandler<EpochEventArgs<T>>? EpochCompleted

Event Type

EventHandler<EpochEventArgs<T>>

TaskCompleted

Event raised when a task finishes training.

event EventHandler<TaskCompletedEventArgs<T>>? TaskCompleted

Event Type

EventHandler<TaskCompletedEventArgs<T>>

TaskStarted

Event raised when a task starts training.

event EventHandler<TaskEventArgs>? TaskStarted

Event Type

EventHandler<TaskEventArgs>
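The three events can be used for progress logging. The sketch below subscribes a handler to each and returns a delegate that unsubscribes them again. The members of the event argument types are not shown on this page, so the handlers only report that an event fired. `LearnerLogging` and `AttachConsoleLogging` are hypothetical names, and the `using` directives for the event argument types are left for you to add.

```csharp
using System;
using AiDotNet.ContinualLearning.Interfaces;
// Assumption: TaskEventArgs, EpochEventArgs<> and TaskCompletedEventArgs<>
// live in AiDotNet namespaces not listed on this page.

public static class LearnerLogging
{
    // Subscribes console loggers to all three events and returns an action
    // that removes them, so the caller controls the subscription lifetime.
    public static Action AttachConsoleLogging<T, TInput, TOutput>(
        IContinualLearner<T, TInput, TOutput> learner)
    {
        EventHandler<TaskEventArgs> onStarted =
            (sender, e) => Console.WriteLine("Task started");
        EventHandler<EpochEventArgs<T>> onEpoch =
            (sender, e) => Console.WriteLine("Epoch completed");
        EventHandler<TaskCompletedEventArgs<T>> onCompleted =
            (sender, e) => Console.WriteLine("Task completed");

        learner.TaskStarted += onStarted;
        learner.EpochCompleted += onEpoch;
        learner.TaskCompleted += onCompleted;

        return () =>
        {
            learner.TaskStarted -= onStarted;
            learner.EpochCompleted -= onEpoch;
            learner.TaskCompleted -= onCompleted;
        };
    }
}
```

Keeping the handlers in local variables is what allows the returned action to unsubscribe the exact same delegate instances.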