Class LearningWithoutForgetting<T>
- Namespace
- AiDotNet.ContinualLearning
- Assembly
- AiDotNet.dll
Implements Learning without Forgetting (LwF) for continual learning.
public class LearningWithoutForgetting<T> : IContinualLearningStrategy<T>
Type Parameters
T: The numeric type for calculations.
- Inheritance
- object → LearningWithoutForgetting<T>
- Implements
- IContinualLearningStrategy<T>
Remarks
For Beginners: Learning without Forgetting is like teaching a student to solve new problems while making sure they remember how to solve old ones. It does this by asking the student to match their old answers (before learning new material) even as they learn new things.
How it works:
- Before learning a new task, LwF records the network's predictions (soft targets) on the new task's inputs using the old model.
- During training on the new task, the loss function includes both: the regular task loss AND a distillation loss that encourages matching the old predictions.
- The distillation loss uses temperature-scaled softmax to capture the relationships between classes, not just the predicted class.
Reference: Li and Hoiem, "Learning without Forgetting" (2017). IEEE TPAMI.
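The steps above can be sketched as follows. This is a conceptual illustration of the math in Python, not AiDotNet's implementation; the function names (softmax, lwf_total_loss) and the plain NumPy arrays standing in for tensors are illustrative assumptions.

```python
import numpy as np

def softmax(z, temperature=1.0):
    # Temperature-scaled softmax along the last axis, shifted for numerical stability.
    z = np.asarray(z, dtype=float) / temperature
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def lwf_total_loss(new_logits, targets_onehot, old_logits, lam=1.0, T=2.0):
    """Task cross-entropy plus temperature-scaled distillation loss.

    old_logits are the frozen predictions the old model produced on the
    new task's inputs before training began (the recorded "soft targets").
    """
    # Regular task loss: cross-entropy against the new task's labels.
    p = softmax(new_logits)
    task_loss = -np.sum(targets_onehot * np.log(p + 1e-12), axis=-1).mean()
    # Distillation loss: KL divergence between softened old and new predictions.
    q_old = softmax(old_logits, T)
    q_new = softmax(new_logits, T)
    kl = np.sum(q_old * (np.log(q_old + 1e-12) - np.log(q_new + 1e-12)), axis=-1)
    return task_loss + lam * kl.mean()
```

When the new model's predictions match the recorded old predictions, the KL term vanishes and only the task loss remains; lambda trades off the two terms.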
Constructors
LearningWithoutForgetting(double, double)
Initializes a new instance of the LearningWithoutForgetting class.
public LearningWithoutForgetting(double lambda = 1, double temperature = 2)
Parameters
lambda (double): The weight of the distillation loss (default: 1.0).
temperature (double): Temperature for softmax distillation (default: 2.0).
Remarks
For Beginners:
- Lambda: How much to weight the "don't forget" loss compared to the "learn new task" loss. Higher values preserve more old knowledge.
- Temperature: Controls how "soft" the predictions are. Higher temperature makes the probability distribution smoother, which helps transfer more nuanced knowledge from the old model to the new one.
Properties
Lambda
Gets or sets the regularization strength parameter (lambda) for loss-based continual learning.
public double Lambda { get; set; }
Property Value
Remarks
For Beginners: Lambda controls how strongly the strategy prevents forgetting. A higher lambda means the network is more conservative about changing weights important for previous tasks, but this might make it harder to learn new tasks effectively.
For LwF, typical values are close to the default of 1.0. The much larger range of roughly 100 to 10000 applies to quadratic-penalty strategies such as EWC, depending on the complexity of tasks and how important it is to preserve old knowledge versus learning new knowledge.
TaskCount
Gets the number of tasks that have stored predictions.
public int TaskCount { get; }
Property Value
Temperature
Gets or sets the temperature for knowledge distillation.
public double Temperature { get; set; }
Property Value
Remarks
For Beginners: Temperature controls how "soft" the probability distribution becomes:
- T = 1: Normal softmax (sharp, peaked distribution)
- T = 2-5: Softer distribution that reveals class relationships
- T > 5: Very soft, almost uniform distribution
Typical values for LwF are 2-4.
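The effect of temperature can be seen directly (a self-contained Python sketch; the softmax helper is illustrative, not part of AiDotNet's API):

```python
import numpy as np

def softmax(z, temperature=1.0):
    # Divide logits by the temperature before the usual softmax.
    z = np.asarray(z, dtype=float) / temperature
    z = z - z.max()  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum()

logits = np.array([4.0, 2.0, 1.0])
sharp = softmax(logits, temperature=1.0)  # peaked: most mass on the top class
soft = softmax(logits, temperature=4.0)   # smoother: secondary classes get visible mass
```

The softer distribution carries information about how the old model ranked the non-top classes, which is exactly the "dark knowledge" distillation transfers.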
Methods
AfterTask(INeuralNetwork<T>, (Tensor<T> inputs, Tensor<T> targets), int)
Processes information after completing training on a task.
public void AfterTask(INeuralNetwork<T> network, (Tensor<T> inputs, Tensor<T> targets) taskData, int taskId)
Parameters
network (INeuralNetwork<T>): The neural network that was trained.
taskData ((Tensor<T> inputs, Tensor<T> targets)): Data from the completed task for computing importance measures.
taskId (int): The identifier for the completed task.
Remarks
For Beginners: This method is called after you finish training on a task. It allows the strategy to compute and store information about what the network learned, which will be used to protect this knowledge when learning future tasks.
For example, in Elastic Weight Consolidation (EWC), this computes the Fisher Information Matrix to identify which weights are most important for the completed task.
BeforeTask(INeuralNetwork<T>, int)
Prepares the strategy before starting to learn a new task.
public void BeforeTask(INeuralNetwork<T> network, int taskId)
Parameters
network (INeuralNetwork<T>): The neural network that will be trained.
taskId (int): The identifier for the upcoming task (0-indexed).
Remarks
For Beginners: This method is called before you start training on a new task. It allows the strategy to capture the network's current state or prepare any necessary data structures for protecting knowledge from previous tasks.
For example, in Learning without Forgetting (LwF), this might store the network's predictions on the new task's inputs before training begins, so we can later encourage the network to maintain similar predictions.
ComputeDistillationLoss(Tensor<T>, Tensor<T>)
Computes the distillation loss between current and old predictions.
public T ComputeDistillationLoss(Tensor<T> currentPredictions, Tensor<T> oldPredictions)
Parameters
currentPredictions (Tensor<T>): The network's current predictions.
oldPredictions (Tensor<T>): The network's predictions before learning new tasks.
Returns
- T
The distillation loss (KL divergence with temperature scaling).
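A temperature-scaled KL divergence of this kind can be sketched as follows (a conceptual Python version, not the library's code; whether an extra T² scaling factor is applied is an implementation choice not specified here):

```python
import numpy as np

def softmax(z, T):
    # Temperature-scaled softmax along the last axis.
    z = np.asarray(z, dtype=float) / T
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def distillation_loss(current_logits, old_logits, T=2.0):
    """Mean KL(old || current) over temperature-softened distributions."""
    q_old = softmax(old_logits, T)
    q_new = softmax(current_logits, T)
    kl = np.sum(q_old * (np.log(q_old + 1e-12) - np.log(q_new + 1e-12)), axis=-1)
    return kl.mean()
```

The loss is zero when the current predictions exactly reproduce the old ones and grows as they drift apart.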
ComputeLoss(INeuralNetwork<T>)
Computes the regularization loss to prevent forgetting previous tasks.
public T ComputeLoss(INeuralNetwork<T> network)
Parameters
network (INeuralNetwork<T>): The neural network being trained.
Returns
- T
The regularization loss value that should be added to the task loss.
Remarks
For Beginners: This method calculates an additional loss term that penalizes the network for deviating from its learned knowledge of previous tasks. You add this to your regular task loss during training:
var totalLoss = taskLoss + strategy.ComputeLoss(network);
For example, in EWC, this returns a penalty proportional to how much important weights have changed from their optimal values for previous tasks. Larger changes to important weights result in higher loss, discouraging the network from forgetting.
ModifyGradients(INeuralNetwork<T>, Vector<T>)
Modifies the gradients to prevent catastrophic forgetting.
public Vector<T> ModifyGradients(INeuralNetwork<T> network, Vector<T> gradients)
Parameters
network (INeuralNetwork<T>): The neural network being trained.
gradients (Vector<T>): The gradients from the current task loss.
Returns
- Vector<T>
Modified gradients that protect previous task knowledge.
Remarks
For Beginners: Some continual learning strategies work by modifying the gradients (the update directions for weights) rather than adding a loss term. This method takes the gradients computed from the current task and modifies them to avoid interfering with previously learned tasks.
For example, in Gradient Episodic Memory (GEM), if a gradient would hurt performance on stored examples from previous tasks, it's projected to the closest gradient that doesn't interfere with those examples.
If a strategy doesn't use gradient modification, this should return the gradients unchanged.
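The GEM-style projection mentioned above can be sketched for the single-constraint case (a conceptual Python illustration with plain arrays, not AiDotNet's Vector<T> code):

```python
import numpy as np

def gem_project(grad, ref_grad):
    """Project grad so it no longer conflicts with a reference gradient
    computed on stored examples from a previous task.

    If grad . ref_grad >= 0 the update does not hurt the old task and is
    left unchanged; otherwise the conflicting component is removed.
    """
    g = np.asarray(grad, dtype=float)
    r = np.asarray(ref_grad, dtype=float)
    dot = g @ r
    if dot >= 0:
        return g
    # Subtract the projection of g onto r, leaving g . r = 0.
    return g - (dot / (r @ r)) * r
```

After projection the returned gradient has a non-negative dot product with the reference gradient, so following it cannot increase the loss on the stored examples to first order.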
PrepareDistillation(INeuralNetwork<T>, Tensor<T>, int)
Prepares distillation by recording the old model's predictions on new task inputs.
public void PrepareDistillation(INeuralNetwork<T> network, Tensor<T> newTaskInputs, int taskId)
Parameters
network (INeuralNetwork<T>): The neural network before training on the new task.
newTaskInputs (Tensor<T>): The inputs for the new task.
taskId (int): The identifier for the new task.
Remarks
For Beginners: Call this method before training on a new task. It records what the network currently predicts for the new task's inputs. These predictions become the "soft targets" that the network tries to match during training, preventing it from forgetting its old behavior.
Reset()
Resets the strategy, clearing all stored task information.
public void Reset()
Remarks
For Beginners: This method clears all the information the strategy has accumulated about previous tasks. After calling this, the network will be free to learn new tasks without any constraints from previously learned tasks.
Use this when you want to start fresh or when you're done with a sequence of tasks and want to begin a new independent sequence.