Class TrainingMemoryManager<T>

Namespace
AiDotNet.Training.Memory
Assembly
AiDotNet.dll

Manages memory optimization during neural network training, including gradient checkpointing, activation pooling, and model sharding.

public class TrainingMemoryManager<T> : IDisposable

Type Parameters

T

The numeric type used for calculations.

Inheritance
object
TrainingMemoryManager<T>

Implements
IDisposable

Remarks

For Beginners: This manager helps you train larger neural networks by:

  1. Gradient Checkpointing: Saves memory by discarding intermediate activations and recomputing them during the backward pass
  2. Activation Pooling: Reuses tensor memory to reduce allocations and garbage-collection pressure
  3. Model Sharding: Distributes layers across multiple GPUs

Example usage:

var config = TrainingMemoryConfig.MemoryEfficient();
var memoryManager = new TrainingMemoryManager<float>(config, network.Layers);

// During training: feed each layer's output into the next,
// checkpointing only the layers the manager selects.
var current = input;
int layerIndex = 0;
foreach (var layer in layers)
{
    if (memoryManager.ShouldCheckpoint(layerIndex))
    {
        current = memoryManager.ForwardWithCheckpoint(layer, current, layerIndex);
    }
    else
    {
        current = layer.Forward(current);
    }

    layerIndex++;
}

Constructors

TrainingMemoryManager(TrainingMemoryConfig?, IEnumerable<ILayer<T>>?)

Initializes a new instance of the TrainingMemoryManager.

public TrainingMemoryManager(TrainingMemoryConfig? config = null, IEnumerable<ILayer<T>>? layers = null)

Parameters

config TrainingMemoryConfig

Memory configuration.

layers IEnumerable<ILayer<T>>

Optional layers for model sharding.
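
For example, a minimal construction sketch. It reuses the MemoryEfficient() preset and the network.Layers collection from the remarks example; the fallback behavior for omitted arguments is an assumption based on their null defaults.

// Memory-efficient preset plus the network's layers, so sharding can
// distribute them if the configuration enables it.
var config = TrainingMemoryConfig.MemoryEfficient();
var memoryManager = new TrainingMemoryManager<float>(config, network.Layers);

// Both parameters default to null; presumably this falls back to a default
// configuration and skips sharding setup.
var defaultManager = new TrainingMemoryManager<float>();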

Properties

Config

Gets the memory configuration.

public TrainingMemoryConfig Config { get; }

Property Value

TrainingMemoryConfig

IsCheckpointingEnabled

Gets whether gradient checkpointing is enabled.

public bool IsCheckpointingEnabled { get; }

Property Value

bool

IsPoolingEnabled

Gets whether activation pooling is enabled.

public bool IsPoolingEnabled { get; }

Property Value

bool

IsShardingEnabled

Gets whether model sharding is enabled.

public bool IsShardingEnabled { get; }

Property Value

bool

PoolStats

Gets pool statistics if activation pooling is enabled.

public ActivationPoolStats? PoolStats { get; }

Property Value

ActivationPoolStats
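
A small sketch of reading the pooling-related properties; since the members of ActivationPoolStats are not documented in this section, the example only null-checks and prints the value.

if (memoryManager.IsPoolingEnabled && memoryManager.PoolStats is { } stats)
{
    // PoolStats is only expected to be non-null while activation pooling is enabled.
    Console.WriteLine($"Activation pool stats: {stats}");
}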

Methods

BackwardSequence(IReadOnlyList<ILayer<T>>, Tensor<T>)

Performs a backward pass through multiple layers, recomputing checkpointed activations as needed.

public Tensor<T> BackwardSequence(IReadOnlyList<ILayer<T>> layers, Tensor<T> outputGradient)

Parameters

layers IReadOnlyList<ILayer<T>>

Layers in forward order.

outputGradient Tensor<T>

Final output gradient.

Returns

Tensor<T>

Gradient with respect to initial input.
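
A minimal sketch of a training step that pairs ForwardSequence with BackwardSequence, assuming network.Layers satisfies both parameter types. The network, input, and targets identifiers follow the remarks example, and ComputeLossGradient is a hypothetical helper standing in for the loss computation.

// Forward pass with checkpointing across the whole layer stack.
var output = memoryManager.ForwardSequence(network.Layers, input);

// Compute the loss gradient with respect to the output (placeholder helper).
Tensor<float> lossGradient = ComputeLossGradient(output, targets);

// Backward pass; checkpointed activations are recomputed as needed.
var inputGradient = memoryManager.BackwardSequence(network.Layers, lossGradient);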

BackwardWithRecompute(ILayer<T>, Tensor<T>, int)

Performs a backward pass through a single layer, recomputing activations from checkpoints.

public Tensor<T> BackwardWithRecompute(ILayer<T> layer, Tensor<T> outputGradient, int layerIndex)

Parameters

layer ILayer<T>

The layer to backpropagate through.

outputGradient Tensor<T>

Gradient from the next layer.

layerIndex int

Index of this layer.

Returns

Tensor<T>

Gradient with respect to input.
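
A sketch of the corresponding manual backward loop, walking the layers of the remarks example (treated here as an indexable list) in reverse and reusing the lossGradient placeholder from the sketch above. Whether every layer or only checkpointed layers should go through this call is not specified here, so this is an assumed usage pattern.

// Walk the layers in reverse, letting the manager recompute any activations
// that were not kept in memory during the forward pass.
var gradient = lossGradient;
for (int layerIndex = layers.Count - 1; layerIndex >= 0; layerIndex--)
{
    gradient = memoryManager.BackwardWithRecompute(layers[layerIndex], gradient, layerIndex);
}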

ClearCheckpoints()

Clears all stored checkpoints to free memory.

public void ClearCheckpoints()

ComputeCheckpointIndices(int, IReadOnlyList<string>?)

Determines which layers should be checkpointed based on the configuration.

public void ComputeCheckpointIndices(int totalLayers, IReadOnlyList<string>? layerTypes = null)

Parameters

totalLayers int

Total number of layers in the network.

layerTypes IReadOnlyList<string>

Optional list of layer type names for smart checkpointing.
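
A short sketch of precomputing the checkpoint indices before training; deriving the layer type names from the runtime types is an assumption about what "smart checkpointing" expects.

// Collect a type name per layer so the manager can make smarter checkpoint choices.
var layerTypes = new List<string>();
foreach (var layer in network.Layers)
{
    layerTypes.Add(layer.GetType().Name);
}

memoryManager.ComputeCheckpointIndices(layerTypes.Count, layerTypes);
// Individual decisions can then be queried with ShouldCheckpoint(layerIndex).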

Dispose()

Disposes resources used by the memory manager.

public void Dispose()
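
Since the class implements IDisposable, the usual using pattern applies:

using (var memoryManager = new TrainingMemoryManager<float>(config))
{
    // ... training loop ...
} // Dispose() runs here and releases the manager's resources.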

EstimateMemorySavings(long, int, int)

Estimates memory savings from the current configuration.

public MemorySavingsEstimate EstimateMemorySavings(long modelParameters, int batchSize, int sequenceLength = 512)

Parameters

modelParameters long

Total number of model parameters.

batchSize int

Training batch size.

sequenceLength int

Sequence length (for transformers).

Returns

MemorySavingsEstimate

Estimated memory savings information.
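
A sketch of estimating savings before committing to a configuration; the parameter values are illustrative, and since the members of MemorySavingsEstimate are not documented in this section, the result is simply printed.

// Roughly 125M parameters, batch size 32, transformer sequence length 1024.
var estimate = memoryManager.EstimateMemorySavings(
    modelParameters: 125_000_000,
    batchSize: 32,
    sequenceLength: 1024);

Console.WriteLine(estimate);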

ForwardSequence(IEnumerable<ILayer<T>>, Tensor<T>)

Performs a forward pass through multiple layers with checkpointing.

public Tensor<T> ForwardSequence(IEnumerable<ILayer<T>> layers, Tensor<T> input)

Parameters

layers IEnumerable<ILayer<T>>

Sequence of layers to execute.

input Tensor<T>

Initial input tensor.

Returns

Tensor<T>

Output tensor from the final layer.
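
This wraps the per-layer ShouldCheckpoint / ForwardWithCheckpoint loop from the remarks into a single call:

// One call runs the whole stack, checkpointing the layers the manager selects.
var output = memoryManager.ForwardSequence(network.Layers, input);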

ForwardWithCheckpoint(ILayer<T>, Tensor<T>, int)

Performs a forward pass with checkpointing for a single layer.

public Tensor<T> ForwardWithCheckpoint(ILayer<T> layer, Tensor<T> input, int layerIndex)

Parameters

layer ILayer<T>

The layer to execute.

input Tensor<T>

Input tensor.

layerIndex int

Index of this layer (for checkpoint storage).

Returns

Tensor<T>

Output tensor from the layer.

GetPoolMemoryUsage()

Gets current memory usage from the activation pool.

public long GetPoolMemoryUsage()

Returns

long

Current memory usage of the activation pool.

RentTensor(int[])

Rents a tensor from the activation pool.

public Tensor<T> RentTensor(int[] shape)

Parameters

shape int[]

Desired tensor shape.

Returns

Tensor<T>

A tensor (may contain uninitialized data).

ReturnTensor(Tensor<T>)

Returns a tensor to the activation pool.

public void ReturnTensor(Tensor<T> tensor)

Parameters

tensor Tensor<T>

Tensor to return.
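
A sketch of the rent/return pattern; the shape and the work done with the tensor are illustrative.

// Rent a scratch tensor from the pool instead of allocating a new one.
var scratch = memoryManager.RentTensor(new[] { 32, 1024 });
try
{
    // ... use scratch as temporary storage; its initial contents are undefined ...
}
finally
{
    // Returning the tensor lets later RentTensor calls reuse its memory.
    memoryManager.ReturnTensor(scratch);
}

Console.WriteLine($"Pool memory in use: {memoryManager.GetPoolMemoryUsage()}");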

ShardedBackward(Tensor<T>)

Performs a backward pass through the sharded model.

public Tensor<T> ShardedBackward(Tensor<T> outputGradient)

Parameters

outputGradient Tensor<T>

Gradient from loss.

Returns

Tensor<T>

Gradient with respect to input.

ShardedForward(Tensor<T>)

Performs a forward pass through the sharded model.

public Tensor<T> ShardedForward(Tensor<T> input)

Parameters

input Tensor<T>

Input tensor.

Returns

Tensor<T>

Output tensor.
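
A minimal sketch of a training step when sharding is enabled; the shards are assumed to come from the layers passed to the constructor, and ComputeLossGradient is the same hypothetical helper used above.

if (memoryManager.IsShardingEnabled)
{
    // Forward and backward passes run across the configured shards.
    var output = memoryManager.ShardedForward(input);

    Tensor<float> lossGradient = ComputeLossGradient(output, targets);
    var inputGradient = memoryManager.ShardedBackward(lossGradient);
}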

ShouldCheckpoint(int)

Determines if a specific layer should be checkpointed.

public bool ShouldCheckpoint(int layerIndex)

Parameters

layerIndex int

Index of the layer.

Returns

bool

True if the layer should be checkpointed.