Interface IContinualLearningStrategy<T, TInput, TOutput>
Namespace: AiDotNet.ContinualLearning.Interfaces
Assembly: AiDotNet.dll
Strategy interface for continual learning algorithms.
public interface IContinualLearningStrategy<T, TInput, TOutput>
Type Parameters
T: The numeric type used for calculations.
TInput: The input data type.
TOutput: The output data type.
Remarks
For Beginners: Different continual learning methods use different strategies to prevent forgetting. This interface allows the trainer to work with any strategy.
Strategy Types:
- Regularization-based: Add penalty terms to protect important weights (EWC, SI, MAS)
- Replay-based: Store and replay old examples (Experience Replay, GEM)
- Architecture-based: Use separate parameters for different tasks (Progressive Networks, PackNet)
- Distillation-based: Use teacher model to preserve old knowledge (LwF)
Reference: De Lange et al. "A Continual Learning Survey: Defying Forgetting" (2021)
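To show how a trainer drives this interface across tasks, here is a minimal sketch using stand-in types. Model, Dataset, NaiveStrategy, and TrainerSketch are illustrative names, not AiDotNet API; the real interface works with IFullModel<T, TInput, TOutput>, IDataset<T, TInput, TOutput>, and Vector<T>.

```csharp
using System;
using System.Collections.Generic;

// Stand-ins for the real AiDotNet model and dataset types.
class Model { public double[] Params = { 0.0, 0.0 }; }
class Dataset { }

// A do-nothing strategy that still shows the required call sequence.
class NaiveStrategy
{
    public string Name => "Naive (no forgetting prevention)";

    public void PrepareForTask(Model model, Dataset taskData) { /* snapshot state here */ }
    public double ComputeRegularizationLoss(Model model) => 0.0;       // no penalty
    public double[] AdjustGradients(double[] gradients) => gradients;  // unchanged
    public void FinalizeTask(Model model) { /* compute importances here */ }
}

static class TrainerSketch
{
    // The order of calls is the contract: prepare, train with the
    // regularized loss and adjusted gradients, then finalize.
    public static int RunTasks(NaiveStrategy strategy, Model model, List<Dataset> tasks)
    {
        int completed = 0;
        foreach (var task in tasks)
        {
            strategy.PrepareForTask(model, task);                        // before training
            double taskLoss = 1.0;                                       // placeholder loss
            double totalLoss = taskLoss + strategy.ComputeRegularizationLoss(model);
            double[] grads = strategy.AdjustGradients(new[] { 0.1, -0.2 });
            // ... apply grads to model.Params via an optimizer step ...
            strategy.FinalizeTask(model);                                // after training
            completed++;
        }
        return completed;
    }
}
```

A real strategy fills in the four hooks; the trainer's loop stays the same regardless of which strategy is plugged in.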
Properties
MemoryUsageBytes
Gets the current memory usage of the strategy in bytes.
long MemoryUsageBytes { get; }
Property Value
- long
ModifiesArchitecture
Gets whether this strategy modifies the model architecture.
bool ModifiesArchitecture { get; }
Property Value
- bool
Name
Gets the name of the strategy.
string Name { get; }
Property Value
- string
RequiresMemoryBuffer
Gets whether this strategy requires storing examples from previous tasks.
bool RequiresMemoryBuffer { get; }
Property Value
- bool
Methods
AdjustGradients(Vector<T>)
Adjusts gradients to prevent forgetting.
Vector<T> AdjustGradients(Vector<T> gradients)
Parameters
gradients (Vector<T>): The gradients from the current task loss.
Returns
- Vector<T>
Adjusted gradients that respect previous task constraints.
Remarks
This modifies gradients before they're applied:
- GEM: Project gradients to not increase loss on previous tasks
- A-GEM: Project onto average gradient of reference samples
- OWM: Project gradients into orthogonal subspace
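As a concrete illustration, the A-GEM projection rule can be sketched on plain double arrays (AGem is an illustrative name; the real method operates on Vector<T>): when the current gradient conflicts with the reference gradient, the conflicting component is removed.

```csharp
using System;

static class AGem
{
    // A-GEM rule: if the current gradient g conflicts with the reference
    // gradient gRef (negative dot product), project g onto the half-space
    // that does not increase loss on the reference samples:
    //   g' = g - (g . gRef / gRef . gRef) * gRef
    public static double[] AdjustGradients(double[] g, double[] gRef)
    {
        double dot = 0.0, refNormSq = 0.0;
        for (int i = 0; i < g.Length; i++)
        {
            dot += g[i] * gRef[i];
            refNormSq += gRef[i] * gRef[i];
        }
        if (dot >= 0.0 || refNormSq == 0.0)
            return (double[])g.Clone(); // no conflict: keep gradient as-is

        double scale = dot / refNormSq;
        var adjusted = new double[g.Length];
        for (int i = 0; i < g.Length; i++)
            adjusted[i] = g[i] - scale * gRef[i]; // remove conflicting component
        return adjusted;
    }
}
```

After projection the adjusted gradient has a non-negative dot product with the reference gradient, which is exactly the "don't increase previous-task loss" constraint.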
ComputeRegularizationLoss(IFullModel<T, TInput, TOutput>)
Computes the regularization loss to prevent forgetting.
T ComputeRegularizationLoss(IFullModel<T, TInput, TOutput> model)
Parameters
model (IFullModel<T, TInput, TOutput>): The current model.
Returns
- T
The regularization loss value.
Remarks
This is added to the task loss during training:
- EWC: Sum of (current params - optimal params)^2 * Fisher Information
- SI: Sum of (current params - optimal params)^2 * importance
- MAS: Sum of (current params - optimal params)^2 * gradient magnitude
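The EWC variant of this penalty can be sketched as follows. EwcPenalty and the lambda weighting are illustrative; the real method takes an IFullModel<T, TInput, TOutput> and returns T.

```csharp
using System;

static class EwcPenalty
{
    // EWC-style quadratic penalty: (lambda / 2) * sum_i F_i * (theta_i - theta*_i)^2,
    // where F_i is the diagonal Fisher Information for parameter i and
    // theta*_i is the parameter value found optimal for the previous task.
    public static double Compute(double[] current, double[] optimal,
                                 double[] fisher, double lambda)
    {
        double penalty = 0.0;
        for (int i = 0; i < current.Length; i++)
        {
            double diff = current[i] - optimal[i];
            penalty += fisher[i] * diff * diff; // important weights cost more to move
        }
        return 0.5 * lambda * penalty;
    }
}
```

SI and MAS share the same quadratic form; only the per-parameter importance weights (here, fisher) are computed differently.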
FinalizeTask(IFullModel<T, TInput, TOutput>)
Finalizes the task after training is complete.
void FinalizeTask(IFullModel<T, TInput, TOutput> model)
Parameters
model (IFullModel<T, TInput, TOutput>): The trained model.
Remarks
This is called after training on a task completes. Strategies use this to:
- EWC: Compute Fisher Information Matrix
- SI: Compute path integral importance
- PackNet: Prune and freeze weights
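The EWC step of this hook, estimating a diagonal Fisher Information Matrix from per-sample gradients, can be sketched as follows (FisherEstimate is an illustrative name, not AiDotNet API):

```csharp
using System;

static class FisherEstimate
{
    // Diagonal Fisher Information estimated as the mean squared
    // per-sample gradient of the log-likelihood:
    //   F_i ~= (1/N) * sum_n g_{n,i}^2
    public static double[] Diagonal(double[][] perSampleGradients)
    {
        int n = perSampleGradients.Length;
        int dim = perSampleGradients[0].Length;
        var fisher = new double[dim];
        foreach (var g in perSampleGradients)
            for (int i = 0; i < dim; i++)
                fisher[i] += g[i] * g[i];   // accumulate squared gradients
        for (int i = 0; i < dim; i++)
            fisher[i] /= n;                  // average over samples
        return fisher;
    }
}
```

The resulting vector is what ComputeRegularizationLoss would later use as per-parameter importance weights.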
GetIncompatibilityReason(IFullModel<T, TInput, TOutput>)
Gets a description of why the strategy is incompatible with a model.
string? GetIncompatibilityReason(IFullModel<T, TInput, TOutput> model)
Parameters
model (IFullModel<T, TInput, TOutput>): The model to check.
Returns
- string?
Description of incompatibility, or null if compatible.
GetMetrics()
Gets strategy-specific metrics for monitoring.
IReadOnlyDictionary<string, object> GetMetrics()
Returns
- IReadOnlyDictionary<string, object>
Dictionary of metric name to value.
IsCompatibleWith(IFullModel<T, TInput, TOutput>)
Validates that the strategy is compatible with the given model.
bool IsCompatibleWith(IFullModel<T, TInput, TOutput> model)
Parameters
model (IFullModel<T, TInput, TOutput>): The model to validate against.
Returns
- bool
True if compatible, false otherwise.
Load(string)
Loads the strategy state from a file.
void Load(string path)
Parameters
path (string): Path to load the state from.
PrepareForTask(IFullModel<T, TInput, TOutput>, IDataset<T, TInput, TOutput>)
Prepares the strategy for learning a new task.
void PrepareForTask(IFullModel<T, TInput, TOutput> model, IDataset<T, TInput, TOutput> taskData)
Parameters
model (IFullModel<T, TInput, TOutput>): The model being trained.
taskData (IDataset<T, TInput, TOutput>): The data for the new task.
Remarks
This is called at the start of training on a new task. Strategies use this to:
- EWC: Store current parameters as optimal for previous tasks
- LwF: Create a copy of the model as the teacher
- GEM: Compute reference gradients from stored examples
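The snapshotting this hook performs can be sketched with a plain parameter array (TaskPreparation is an illustrative name; the real hook receives the IFullModel and IDataset):

```csharp
using System;

// Sketch of the bookkeeping two strategies do in PrepareForTask,
// using a plain double[] in place of the model's parameter vector.
class TaskPreparation
{
    public double[] OptimalParams = Array.Empty<double>(); // EWC: anchor point
    public double[] TeacherParams = Array.Empty<double>(); // LwF: frozen teacher copy

    public void PrepareForTask(double[] currentParams)
    {
        // Both snapshots must be deep copies: training on the new task
        // mutates the live parameter vector, and the snapshots must not move.
        OptimalParams = (double[])currentParams.Clone();
        TeacherParams = (double[])currentParams.Clone();
    }
}
```

The key design point is that the copies are taken before any new-task gradient step, so the strategy always regularizes toward (or distills from) the previous task's solution.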
Reset()
Resets the strategy to its initial state.
void Reset()
Save(string)
Saves the strategy state to a file.
void Save(string path)
Parameters
path (string): Path to save the state.