Interface IContinualLearner<T, TInput, TOutput>
Namespace: AiDotNet.ContinualLearning.Interfaces
Assembly: AiDotNet.dll
Interface for continual learning trainers that can learn multiple tasks sequentially.
public interface IContinualLearner<T, TInput, TOutput>
Type Parameters
T: The numeric type used for calculations.
TInput: The input data type.
TOutput: The output data type.
Remarks
For Beginners: Continual learning (also called lifelong learning) is the ability to learn new tasks over time without forgetting previously learned knowledge. Neural networks trained on one task after another typically suffer from "catastrophic forgetting": when trained on new data, they overwrite much of what they learned before.
This interface provides methods to:
- Learn new tasks while protecting old knowledge
- Evaluate performance on all learned tasks
- Save and load model state for resuming training
Common Implementations:
- EWCTrainer: Uses Elastic Weight Consolidation to protect important weights
- LwFTrainer: Uses Learning without Forgetting with knowledge distillation
- GEMTrainer: Uses Gradient Episodic Memory to constrain gradients
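The EWC strategy named above adds a quadratic penalty to the training loss so that weights judged important for earlier tasks stay close to their previous values. A minimal sketch of the standard penalty, (lambda / 2) * sum_i F_i * (theta_i - thetaStar_i)^2, follows, where F_i is the diagonal Fisher Information estimate and thetaStar holds the parameters saved after the previous task. It works on plain `double[]` arrays rather than the generic `T`; the helper name `EwcPenalty` and the lambda value are illustrative assumptions, not part of the AiDotNet API.

```csharp
using System;

// Hypothetical helper (not part of AiDotNet): the EWC regularization term
// (lambda / 2) * sum_i F_i * (theta_i - thetaStar_i)^2, using double instead of the generic T.
double EwcPenalty(double[] theta, double[] thetaStar, double[] fisher, double lambda)
{
    if (theta.Length != thetaStar.Length || theta.Length != fisher.Length)
        throw new ArgumentException("Parameter, anchor, and Fisher arrays must have equal length.");

    double sum = 0.0;
    for (int i = 0; i < theta.Length; i++)
    {
        double diff = theta[i] - thetaStar[i];
        sum += fisher[i] * diff * diff; // weights with large F_i are penalized more for moving
    }
    return 0.5 * lambda * sum;
}

// Parameters saved after the previous task, their Fisher importance, and the current parameters.
double[] oldParams  = { 1.0, -0.5, 2.0 };
double[] fisherDiag = { 4.0,  0.1, 0.0 };
double[] newParams  = { 1.5, -0.5, 3.0 };

double penalty = EwcPenalty(newParams, oldParams, fisherDiag, lambda: 10.0);
Console.WriteLine(penalty); // 0.5 * 10 * (4 * 0.25 + 0 + 0) = 5
```

In a full trainer this penalty would be added to the new task's loss at every optimization step; how EWCTrainer weights or accumulates it across several tasks is not specified here.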
References:
- Parisi et al. "Continual Lifelong Learning with Neural Networks: A Review" (2019)
- De Lange et al. "A Continual Learning Survey" (2021)
Properties
BaseModel
Gets the underlying model being trained.
IFullModel<T, TInput, TOutput> BaseModel { get; }
Property Value
- IFullModel<T, TInput, TOutput>
Config
Gets the configuration for the continual learner.
IContinualLearnerConfig<T> Config { get; }
Property Value
- IContinualLearnerConfig<T>
IsTraining
Gets whether the learner is currently training.
bool IsTraining { get; }
Property Value
- bool
MemoryUsageBytes
Gets the current memory usage of the learner in bytes.
long MemoryUsageBytes { get; }
Property Value
- long
TasksLearned
Gets the number of tasks that have been learned.
int TasksLearned { get; }
Property Value
- int
Methods
ComputeForgetting()
Computes the current forgetting metric for all tasks.
IReadOnlyDictionary<int, T> ComputeForgetting()
Returns
- IReadOnlyDictionary<int, T>
Dictionary mapping task ID to forgetting amount.
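The reference does not say how forgetting is measured. One common definition in the continual learning literature takes, for each earlier task, the best accuracy observed on it at any earlier checkpoint minus its accuracy after the most recent task. The sketch below assumes an accuracy matrix `acc[i][j]` (accuracy on task j after training on task i) and uses `double` values; whether `ComputeForgetting` follows this exact formula is an assumption, and the helper name `Forgetting` is hypothetical.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical sketch: matrix[i][j] = accuracy on task j measured after finishing training on task i.
// Forgetting for task j (j < last) = max over earlier checkpoints of matrix[l][j] minus matrix[last][j].
Dictionary<int, double> Forgetting(double[][] matrix)
{
    int last = matrix.Length - 1;
    var result = new Dictionary<int, double>();
    for (int j = 0; j < last; j++) // the most recent task cannot have been forgotten yet
    {
        double best = double.NegativeInfinity;
        for (int l = j; l < last; l++) // task j is only meaningful once it has been learned
            best = Math.Max(best, matrix[l][j]);
        result[j] = best - matrix[last][j];
    }
    return result;
}

double[][] acc =
{
    new[] { 0.95, 0.00, 0.00 },
    new[] { 0.90, 0.92, 0.00 },
    new[] { 0.80, 0.91, 0.93 },
};

foreach (var (task, f) in Forgetting(acc))
    Console.WriteLine($"task {task}: forgetting {f:F2}");
```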
EvaluateAllTasks()
Evaluates the model on all learned tasks.
ContinualEvaluationResult<T> EvaluateAllTasks()
Returns
- ContinualEvaluationResult<T>
An evaluation result covering all learned tasks, including backward and forward transfer metrics.
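Backward and forward transfer are not defined in this reference. A widely used pair of definitions, introduced alongside Gradient Episodic Memory by Lopez-Paz and Ranzato (2017), works on an accuracy matrix `r[i][j]` (accuracy on task j after training on task i): backward transfer averages how accuracy on each earlier task changed between when it was learned and the end of training, and forward transfer averages how much earlier training helps a task before that task is trained, relative to a baseline accuracy such as that of a randomly initialized model. The following is a sketch under those definitions; the helper names and the baseline array are hypothetical, and `ContinualEvaluationResult<T>` may compute its fields differently.

```csharp
using System;

// r[i][j] = accuracy on task j after training on task i (0-indexed, t tasks).
// Backward transfer: mean over earlier tasks of (final accuracy - accuracy right after learning it).
double BackwardTransfer(double[][] r)
{
    int t = r.Length;
    if (t < 2) throw new ArgumentException("Need at least two tasks.");
    double sum = 0.0;
    for (int i = 0; i < t - 1; i++)
        sum += r[t - 1][i] - r[i][i];
    return sum / (t - 1);
}

// Forward transfer: mean over tasks i >= 1 of (accuracy on task i before training on it - baseline b[i]).
double ForwardTransfer(double[][] r, double[] b)
{
    int t = r.Length;
    if (t < 2) throw new ArgumentException("Need at least two tasks.");
    double sum = 0.0;
    for (int i = 1; i < t; i++)
        sum += r[i - 1][i] - b[i];
    return sum / (t - 1);
}

double[][] accMatrix =
{
    new[] { 0.90, 0.30, 0.20 },
    new[] { 0.85, 0.88, 0.40 },
    new[] { 0.80, 0.86, 0.91 },
};
double[] baseline = { 0.10, 0.10, 0.10 };

Console.WriteLine($"BWT = {BackwardTransfer(accMatrix):F3}");
Console.WriteLine($"FWT = {ForwardTransfer(accMatrix, baseline):F3}");
```

Negative backward transfer indicates forgetting; positive forward transfer indicates that earlier tasks made later ones easier to learn.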
EvaluateTask(int, IDataset<T, TInput, TOutput>)
Evaluates the model on a specific task.
TaskEvaluationResult<T> EvaluateTask(int taskId, IDataset<T, TInput, TOutput> testData)
Parameters
taskId (int): The task identifier (0-indexed).
testData (IDataset<T, TInput, TOutput>): The test data for the task.
Returns
- TaskEvaluationResult<T>
Evaluation result for the specified task.
GetAllHistory()
Gets the training history for all learned tasks.
IReadOnlyList<ContinualLearningResult<T>> GetAllHistory()
Returns
- IReadOnlyList<ContinualLearningResult<T>>
List of all training results.
GetTaskHistory(int)
Gets the training history for a specific task.
ContinualLearningResult<T>? GetTaskHistory(int taskId)
Parameters
taskId (int): The task identifier.
Returns
- ContinualLearningResult<T>?
The training result for the task, or null if not available.
LearnTask(IDataset<T, TInput, TOutput>)
Learns a new task from the provided data.
ContinualLearningResult<T> LearnTask(IDataset<T, TInput, TOutput> taskData)
Parameters
taskData (IDataset<T, TInput, TOutput>): The dataset for the new task.
Returns
- ContinualLearningResult<T>
Result containing training metrics and performance information.
Remarks
This method trains the model on the new task while using the configured strategy to prevent forgetting of previously learned tasks.
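How the configured strategy protects old tasks depends on the implementation. For the gradient-constraining approach listed among the common implementations (GEM), the key step can be sketched in its single-constraint form: if the new task's gradient g has a negative dot product with a reference gradient g_ref computed on stored examples from an earlier task, g is replaced by the closest gradient that no longer conflicts, g - (g·g_ref / g_ref·g_ref) g_ref. Full GEM solves a small quadratic program with one constraint per previous task; with a single constraint its solution reduces to this closed form (the same projection A-GEM applies using an averaged memory gradient). The helper names are hypothetical and this is not GEMTrainer's code.

```csharp
using System;
using System.Linq;

double Dot(double[] a, double[] b) => a.Zip(b, (x, y) => x * y).Sum();

// Projects gradient g onto the half-space { x : <x, gRef> >= 0 }, i.e. the closest point in L2 norm.
double[] ProjectGradient(double[] g, double[] gRef)
{
    double dot = Dot(g, gRef);
    if (dot >= 0.0) return (double[])g.Clone(); // no interference with the remembered task
    double refNormSq = Dot(gRef, gRef);
    if (refNormSq == 0.0) return (double[])g.Clone(); // degenerate reference gradient
    double scale = dot / refNormSq;
    return g.Select((gi, i) => gi - scale * gRef[i]).ToArray();
}

double[] newTaskGrad = { 1.0, -2.0 };
double[] memoryGrad  = { 0.0, 1.0 };
double[] projected = ProjectGradient(newTaskGrad, memoryGrad);
Console.WriteLine(string.Join(", ", projected)); // 1, 0
```

The projected gradient keeps the component that helps the new task while removing the part that would increase loss on the remembered examples, to first order.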
LearnTask(IDataset<T, TInput, TOutput>, IDataset<T, TInput, TOutput>, int?)
Learns a new task with a validation set for early stopping and monitoring.
ContinualLearningResult<T> LearnTask(IDataset<T, TInput, TOutput> taskData, IDataset<T, TInput, TOutput> validationData, int? earlyStoppingPatience = null)
Parameters
taskData (IDataset<T, TInput, TOutput>): The training dataset for the new task.
validationData (IDataset<T, TInput, TOutput>): The validation dataset for the new task.
earlyStoppingPatience (int?): Number of consecutive epochs without improvement on the validation set before training stops. Optional; defaults to null.
Returns
- ContinualLearningResult<T>
Result containing training metrics and performance information.
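The `earlyStoppingPatience` parameter is documented only as a number of epochs without improvement. A typical patience rule, sketched below under the assumptions that lower validation loss is better and that any strict decrease counts as an improvement, tracks the best loss seen so far and stops once that many consecutive epochs fail to beat it. The helper `StopEpoch` is hypothetical, and the sketch does not model what the learner does when the argument is null.

```csharp
using System;

// Returns the (0-based) epoch index at which training would stop, given per-epoch
// validation losses (lower is better) and a patience value.
int StopEpoch(double[] valLosses, int patience)
{
    double best = double.PositiveInfinity;
    int epochsWithoutImprovement = 0;
    for (int epoch = 0; epoch < valLosses.Length; epoch++)
    {
        if (valLosses[epoch] < best)
        {
            best = valLosses[epoch];      // new best: reset the counter
            epochsWithoutImprovement = 0;
        }
        else if (++epochsWithoutImprovement >= patience)
        {
            return epoch;                 // patience exhausted
        }
    }
    return valLosses.Length - 1;          // ran every epoch
}

double[] losses = { 0.90, 0.70, 0.65, 0.66, 0.67, 0.64, 0.68, 0.69 };
Console.WriteLine(StopEpoch(losses, patience: 2));
```

With patience 2 the example stops at epoch index 4; with patience 3 it survives the temporary plateau, reaches the better loss at index 5, and runs all eight epochs.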
Load(string)
Loads the learner state from a file.
void Load(string path)
Parameters
path (string): Path of the file to load the state from.
Reset()
Resets the learner to its initial state.
void Reset()
Remarks
This clears all learned knowledge, including:
- Model parameters (reset to initial values)
- Strategy state
- Memory buffer
- Task history
Save(string)
Saves the learner state to a file.
void Save(string path)
Parameters
path (string): Path of the file to save the state to.
Remarks
This saves all state needed to resume training, including:
- Model parameters
- Strategy state (e.g., Fisher Information for EWC)
- Memory buffer contents
- Training history
Events
EpochCompleted
Event raised when an epoch completes during training.
event EventHandler<EpochEventArgs<T>>? EpochCompleted
Event Type
- EventHandler<EpochEventArgs<T>>
TaskCompleted
Event raised when a task finishes training.
event EventHandler<TaskCompletedEventArgs<T>>? TaskCompleted
Event Type
- EventHandler<TaskCompletedEventArgs<T>>
TaskStarted
Event raised when a task starts training.
event EventHandler<TaskEventArgs>? TaskStarted
Event Type
- EventHandler<TaskEventArgs>