Class ProgressiveNeuralNetworks<T>

Namespace
AiDotNet.ContinualLearning
Assembly
AiDotNet.dll

Implements Progressive Neural Networks for continual learning.

public class ProgressiveNeuralNetworks<T> : IContinualLearningStrategy<T>

Type Parameters

T

The numeric type for calculations.

Inheritance
object → ProgressiveNeuralNetworks<T>

Implements
IContinualLearningStrategy<T>

Remarks

For Beginners: Progressive Neural Networks prevent forgetting by freezing previously learned networks and adding new "columns" (networks) for each new task. The new columns can receive input from all previous columns through lateral connections, enabling knowledge transfer without forgetting.

How it works:

  1. Train a neural network column for Task 1.
  2. Freeze the Task 1 column completely.
  3. Add a new column for Task 2 with lateral connections from Task 1's hidden layers.
  4. Train only the new column (Task 1 column remains frozen).
  5. Repeat for each new task, adding lateral connections from all previous columns.
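
The sketch below shows how these steps map onto this class's API. It is a minimal outline, assuming an existing INeuralNetwork<T> instance, a list of (inputs, targets) task tuples, and a hypothetical TrainOnTask helper standing in for your own training loop:

var strategy = new ProgressiveNeuralNetworks<double>(useLateralConnections: true, lambda: 1.0);

for (int taskId = 0; taskId < tasks.Count; taskId++)
{
    var (inputs, targets) = tasks[taskId];

    // Steps 2-3: freeze all earlier columns and prepare a new column for this task.
    strategy.BeforeTask(network, taskId);

    // Step 4: train only the new column, adding strategy.ComputeLoss(network)
    // to the task loss on each step (TrainOnTask is a placeholder for your loop).
    TrainOnTask(network, inputs, targets, strategy);

    // Step 5 bookkeeping: store and freeze the column trained for this task.
    strategy.AfterTask(network, (inputs, targets), taskId);
}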

Architecture:

Task 1      Task 2      Task 3
Column      Column      Column
  │           │           │
[L1]──────>[L1]──────>[L1]    (Lateral connections)
  │           │           │
[L2]──────>[L2]──────>[L2]
  │           │           │
[Out]       [Out]       [Out]

Advantages:

  • Zero forgetting - previous columns are completely frozen.
  • Positive transfer - new tasks can leverage previous knowledge.
  • Clear task separation - each task has its own output.

Disadvantages:

  • Parameter growth with the number of tasks: each task adds a column, and lateral connections from every previous column make the lateral weights grow quadratically.
  • Memory usage increases accordingly with each task.

Reference: Rusu, A. A., et al. "Progressive Neural Networks." arXiv:1606.04671 (2016).

Constructors

ProgressiveNeuralNetworks(bool, double)

Initializes a new instance of the ProgressiveNeuralNetworks class.

public ProgressiveNeuralNetworks(bool useLateralConnections = true, double lambda = 1)

Parameters

useLateralConnections bool

Whether to use lateral connections between columns (default: true).

lambda double

Regularization strength for lateral connections (default: 1.0).

Remarks

For Beginners:

  • Lateral connections allow new tasks to use features learned by previous tasks.
  • Without lateral connections, this becomes simple multi-head training.
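
For example, based on the signature above:

// Default configuration: lateral connections enabled, lambda = 1.0.
var pnn = new ProgressiveNeuralNetworks<float>();

// Disabling lateral connections reduces this to simple multi-head training:
// each task gets an independent column with no knowledge transfer.
var multiHead = new ProgressiveNeuralNetworks<float>(useLateralConnections: false);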

Properties

ColumnCount

Gets the number of columns (tasks) in the progressive network.

public int ColumnCount { get; }

Property Value

int

Lambda

Gets or sets the regularization strength parameter (lambda) for loss-based continual learning.

public double Lambda { get; set; }

Property Value

double

Remarks

For Beginners: Lambda controls how strongly the strategy prevents forgetting. A higher lambda means the network is more conservative about changing weights important for previous tasks, but this might make it harder to learn new tasks effectively.

Typical values depend on the strategy: loss-penalty methods such as EWC often use values from 100 to 10,000, whereas this class defaults to 1.0, since previous columns are fully frozen and lambda only regularizes the lateral connections. Higher values preserve old knowledge more aggressively at the cost of plasticity on new tasks.
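
For example, to regularize the lateral connections more strongly:

strategy.Lambda = 10.0;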

UseLateralConnections

Gets whether lateral connections are enabled.

public bool UseLateralConnections { get; }

Property Value

bool

Methods

AfterTask(INeuralNetwork<T>, (Tensor<T> inputs, Tensor<T> targets), int)

Processes information after completing training on a task.

public void AfterTask(INeuralNetwork<T> network, (Tensor<T> inputs, Tensor<T> targets) taskData, int taskId)

Parameters

network INeuralNetwork<T>

The neural network that was trained.

taskData (Tensor<T> inputs, Tensor<T> targets)

Data from the completed task for computing importance measures.

taskId int

The identifier for the completed task.

Remarks

For Beginners: This method is called after you finish training on a task. It allows the strategy to compute and store information about what the network learned, which will be used to protect this knowledge when learning future tasks.

For example, in Elastic Weight Consolidation (EWC), this computes the Fisher Information Matrix to identify which weights are most important for the completed task.
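
For this class, a typical call after finishing a task looks like:

strategy.AfterTask(network, (inputs, targets), taskId);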

BeforeTask(INeuralNetwork<T>, int)

Prepares the strategy before starting to learn a new task.

public void BeforeTask(INeuralNetwork<T> network, int taskId)

Parameters

network INeuralNetwork<T>

The neural network that will be trained.

taskId int

The identifier for the upcoming task (0-indexed).

Remarks

For Beginners: This method is called before you start training on a new task. It allows the strategy to capture the network's current state or prepare any necessary data structures for protecting knowledge from previous tasks.

For example, in Learning without Forgetting (LwF), this might store the network's predictions on the new task's inputs before training begins, so we can later encourage the network to maintain similar predictions.
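
For this class, call it once before training each task's column:

strategy.BeforeTask(network, taskId);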

ComputeLateralInput(List<Tensor<T>>, List<Tensor<T>>)

Computes the lateral connection activation from previous columns.

public Tensor<T> ComputeLateralInput(List<Tensor<T>> previousActivations, List<Tensor<T>> lateralWeights)

Parameters

previousActivations List<Tensor<T>>

Activations from previous columns at a given layer.

lateralWeights List<Tensor<T>>

Lateral connection weights.

Returns

Tensor<T>

Combined lateral input for the current column.

Remarks

For Beginners: Lateral connections take the hidden activations from all previous columns and combine them (weighted sum) as additional input to the current column's layers. This allows knowledge transfer.
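
Conceptually, the combined lateral input is the sum over previous columns k of W_k · h_k. The sketch below illustrates that weighted sum with plain arrays standing in for the library's Tensor<T> type (illustration only, not the class's internal implementation):

using System.Collections.Generic;

static double[] LateralInput(List<double[]> previousActivations, List<double[,]> lateralWeights)
{
    int outSize = lateralWeights[0].GetLength(0);
    var result = new double[outSize];

    for (int k = 0; k < previousActivations.Count; k++)       // one term per previous column
        for (int i = 0; i < outSize; i++)
            for (int j = 0; j < previousActivations[k].Length; j++)
                result[i] += lateralWeights[k][i, j] * previousActivations[k][j];   // W_k · h_k

    return result;
}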

ComputeLoss(INeuralNetwork<T>)

Computes the regularization loss to prevent forgetting previous tasks.

public T ComputeLoss(INeuralNetwork<T> network)

Parameters

network INeuralNetwork<T>

The neural network being trained.

Returns

T

The regularization loss value that should be added to the task loss.

Remarks

For Beginners: This method calculates an additional loss term that penalizes the network for deviating from its learned knowledge of previous tasks. You add this to your regular task loss during training:

var totalLoss = taskLoss + strategy.ComputeLoss(network);

For example, in EWC, this returns a penalty proportional to how much important weights have changed from their optimal values for previous tasks. Larger changes to important weights result in higher loss, discouraging the network from forgetting.
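
For reference, the EWC penalty mentioned above has the form

lossEWC = (lambda / 2) · Σᵢ Fᵢ · (θᵢ − θ*ᵢ)²

where Fᵢ is the Fisher information for parameter i (how important it is to previous tasks) and θ*ᵢ is that parameter's value when the previous task finished.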

EstimateMemoryUsage()

Estimates the total memory usage of the progressive network.

public long EstimateMemoryUsage()

Returns

long

Estimated memory in bytes (assuming 4 bytes per parameter for float).
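
For example, a network with three columns of 1,000,000 parameters each would be estimated at 3 × 1,000,000 × 4 = 12,000,000 bytes (about 11.4 MiB).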

GetColumnParameters(int)

Gets parameters for a specific task's column.

public Vector<T>? GetColumnParameters(int taskId)

Parameters

taskId int

The task ID.

Returns

Vector<T>

The frozen parameters for that task, or null if not found.

GetNetworkStats()

Gets statistics about the progressive network structure.

public Dictionary<string, object> GetNetworkStats()

Returns

Dictionary<string, object>

Dictionary with network statistics.

ModifyGradients(INeuralNetwork<T>, Vector<T>)

Modifies the gradient to prevent catastrophic forgetting.

public Vector<T> ModifyGradients(INeuralNetwork<T> network, Vector<T> gradients)

Parameters

network INeuralNetwork<T>

The neural network being trained.

gradients Vector<T>

The gradients from the current task loss.

Returns

Vector<T>

Modified gradients that protect previous task knowledge.

Remarks

For Beginners: Some continual learning strategies work by modifying the gradients (the update directions for weights) rather than adding a loss term. This method takes the gradients computed from the current task and modifies them to avoid interfering with previously learned tasks.

For example, in Gradient Episodic Memory (GEM), if a gradient would hurt performance on stored examples from previous tasks, it's projected to the closest gradient that doesn't interfere with those examples.

If a strategy doesn't use gradient modification, this should return the gradients unchanged.
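
As an illustration of the projection idea (in the style of A-GEM, a simplified variant of GEM; not necessarily what this class does internally):

// If gradient g conflicts with reference gradient gRef computed on stored
// past-task examples (negative dot product), remove the interfering component:
// g := g − (g·gRef / gRef·gRef) · gRef
static double[] ProjectGradient(double[] g, double[] gRef)
{
    double dot = 0.0, refSq = 0.0;
    for (int i = 0; i < g.Length; i++)
    {
        dot += g[i] * gRef[i];
        refSq += gRef[i] * gRef[i];
    }

    if (dot >= 0.0 || refSq == 0.0)
        return g;                          // no interference: keep g unchanged

    double scale = dot / refSq;
    var projected = new double[g.Length];
    for (int i = 0; i < g.Length; i++)
        projected[i] = g[i] - scale * gRef[i];
    return projected;
}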

Reset()

Resets the strategy, clearing all stored task information.

public void Reset()

Remarks

For Beginners: This method clears all the information the strategy has accumulated about previous tasks. After calling this, the network will be free to learn new tasks without any constraints from previously learned tasks.

Use this when you want to start fresh or when you're done with a sequence of tasks and want to begin a new independent sequence.