Class TrainingMemoryManager<T>
Manages memory optimization during neural network training including gradient checkpointing, activation pooling, and model sharding.
public class TrainingMemoryManager<T> : IDisposable
Type Parameters
T
The numeric type used for calculations.
- Inheritance
object → TrainingMemoryManager<T>
- Implements
IDisposable
Remarks
For Beginners: This manager helps you train larger neural networks by:
- Gradient Checkpointing: Saves memory by recomputing activations during the backward pass instead of storing them all
- Activation Pooling: Reuses tensor memory to reduce garbage collection pressure
- Model Sharding: Distributes layers across multiple GPUs
Example usage:
var config = TrainingMemoryConfig.MemoryEfficient();
var memoryManager = new TrainingMemoryManager<float>(config, network.Layers);
// During training
var layers = network.Layers.ToList();
var output = input;
for (int layerIndex = 0; layerIndex < layers.Count; layerIndex++)
{
    var layer = layers[layerIndex];
    if (memoryManager.ShouldCheckpoint(layerIndex))
    {
        output = memoryManager.ForwardWithCheckpoint(layer, output, layerIndex);
    }
    else
    {
        output = layer.Forward(output);
    }
}
Constructors
TrainingMemoryManager(TrainingMemoryConfig?, IEnumerable<ILayer<T>>?)
Initializes a new instance of the TrainingMemoryManager.
public TrainingMemoryManager(TrainingMemoryConfig? config = null, IEnumerable<ILayer<T>>? layers = null)
Parameters
config TrainingMemoryConfig
Memory configuration.
layers IEnumerable<ILayer<T>>
Optional layers for model sharding.
Properties
Config
Gets the memory configuration.
public TrainingMemoryConfig Config { get; }
Property Value
- TrainingMemoryConfig
IsCheckpointingEnabled
Gets whether gradient checkpointing is enabled.
public bool IsCheckpointingEnabled { get; }
Property Value
- bool
IsPoolingEnabled
Gets whether activation pooling is enabled.
public bool IsPoolingEnabled { get; }
Property Value
- bool
IsShardingEnabled
Gets whether model sharding is enabled.
public bool IsShardingEnabled { get; }
Property Value
- bool
PoolStats
Gets pool statistics if activation pooling is enabled.
public ActivationPoolStats? PoolStats { get; }
Property Value
- ActivationPoolStats
Methods
BackwardSequence(IReadOnlyList<ILayer<T>>, Tensor<T>)
Performs a backward pass through multiple layers with recomputation.
public Tensor<T> BackwardSequence(IReadOnlyList<ILayer<T>> layers, Tensor<T> outputGradient)
Parameters
layers IReadOnlyList<ILayer<T>>
Layers in forward order.
outputGradient Tensor<T>
Final output gradient.
Returns
- Tensor<T>
Gradient with respect to initial input.
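Example usage (a minimal sketch; network, input, and lossGradient are hypothetical names assumed to exist):
var layers = network.Layers.ToList();
// Forward pass stores activations only at the configured checkpoint layers.
var output = memoryManager.ForwardSequence(layers, input);
// ... compute the loss and its gradient (lossGradient) ...
// Backward pass recomputes non-checkpointed activations as needed.
var inputGradient = memoryManager.BackwardSequence(layers, lossGradient);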
BackwardWithRecompute(ILayer<T>, Tensor<T>, int)
Performs a backward pass, recomputing activations from checkpoints.
public Tensor<T> BackwardWithRecompute(ILayer<T> layer, Tensor<T> outputGradient, int layerIndex)
Parameters
layer ILayer<T>
The layer to backpropagate through.
outputGradient Tensor<T>
Gradient from the next layer.
layerIndex int
Index of this layer.
Returns
- Tensor<T>
Gradient with respect to input.
ClearCheckpoints()
Clears all stored checkpoints to free memory.
public void ClearCheckpoints()
ComputeCheckpointIndices(int, IReadOnlyList<string>?)
Determines which layers should be checkpointed based on configuration.
public void ComputeCheckpointIndices(int totalLayers, IReadOnlyList<string>? layerTypes = null)
Parameters
totalLayers int
Total number of layers in the network.
layerTypes IReadOnlyList<string>
Optional list of layer type names for smart checkpointing.
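Example usage (a sketch assuming network.Layers is available as in the Remarks example):
var layers = network.Layers.ToList();
// Deriving type names via GetType().Name is an assumption for illustration,
// not a documented requirement of smart checkpointing.
var layerTypes = layers.Select(l => l.GetType().Name).ToList();
memoryManager.ComputeCheckpointIndices(layers.Count, layerTypes);
// Afterwards, ShouldCheckpoint(i) reflects the computed indices.
for (int i = 0; i < layers.Count; i++)
{
    Console.WriteLine($"Layer {i}: checkpoint = {memoryManager.ShouldCheckpoint(i)}");
}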
Dispose()
Disposes resources used by the memory manager.
public void Dispose()
EstimateMemorySavings(long, int, int)
Estimates memory savings from current configuration.
public MemorySavingsEstimate EstimateMemorySavings(long modelParameters, int batchSize, int sequenceLength = 512)
Parameters
modelParameters long
Total number of model parameters.
batchSize int
Training batch size.
sequenceLength int
Sequence length (for transformers).
Returns
- MemorySavingsEstimate
Estimated memory savings information.
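Example usage (a sketch; the figures are illustrative, roughly a 125M-parameter model):
var estimate = memoryManager.EstimateMemorySavings(
    modelParameters: 125_000_000,
    batchSize: 32,
    sequenceLength: 512);
// The members of MemorySavingsEstimate are not listed on this page,
// so this sketch only prints the returned object.
Console.WriteLine(estimate);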
ForwardSequence(IEnumerable<ILayer<T>>, Tensor<T>)
Performs a forward pass through multiple layers with checkpointing.
public Tensor<T> ForwardSequence(IEnumerable<ILayer<T>> layers, Tensor<T> input)
Parameters
layers IEnumerable<ILayer<T>>
Sequence of layers to execute.
input Tensor<T>
Initial input tensor.
Returns
- Tensor<T>
Output tensor from the final layer.
ForwardWithCheckpoint(ILayer<T>, Tensor<T>, int)
Performs a forward pass with checkpointing for a single layer.
public Tensor<T> ForwardWithCheckpoint(ILayer<T> layer, Tensor<T> input, int layerIndex)
Parameters
layer ILayer<T>
The layer to execute.
input Tensor<T>
Input tensor.
layerIndex int
Index of this layer (for checkpoint storage).
Returns
- Tensor<T>
Output tensor from the layer.
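Example usage (a sketch pairing ForwardWithCheckpoint with BackwardWithRecompute; layers, input, and lossGradient are hypothetical names assumed to exist):
var current = input;
for (int i = 0; i < layers.Count; i++)
{
    current = memoryManager.ShouldCheckpoint(i)
        ? memoryManager.ForwardWithCheckpoint(layers[i], current, i)
        : layers[i].Forward(current);
}
// Backward pass in reverse order; this sketch assumes BackwardWithRecompute
// handles both checkpointed and non-checkpointed layers.
var grad = lossGradient;
for (int i = layers.Count - 1; i >= 0; i--)
{
    grad = memoryManager.BackwardWithRecompute(layers[i], grad, i);
}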
GetPoolMemoryUsage()
Gets current memory usage from the activation pool.
public long GetPoolMemoryUsage()
Returns
- long
Current memory usage of the activation pool.
RentTensor(int[])
Rents a tensor from the activation pool.
public Tensor<T> RentTensor(int[] shape)
Parameters
shapeint[]Desired tensor shape.
Returns
- Tensor<T>
A tensor (may contain uninitialized data).
ReturnTensor(Tensor<T>)
Returns a tensor to the activation pool.
public void ReturnTensor(Tensor<T> tensor)
Parameters
tensorTensor<T>Tensor to return.
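Example usage (a sketch of the rent/return pattern for scratch tensors):
// Rent a scratch tensor instead of allocating a new one each iteration.
var scratch = memoryManager.RentTensor(new[] { 32, 1024 });
// The tensor may contain uninitialized data, so write before reading.
// ... use the tensor ...
// Return it so later rentals of a compatible shape can reuse the memory.
memoryManager.ReturnTensor(scratch);
Console.WriteLine($"Pool memory usage: {memoryManager.GetPoolMemoryUsage()}");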
ShardedBackward(Tensor<T>)
Performs a backward pass through the sharded model.
public Tensor<T> ShardedBackward(Tensor<T> outputGradient)
Parameters
outputGradientTensor<T>Gradient from loss.
Returns
- Tensor<T>
Gradient with respect to input.
ShardedForward(Tensor<T>)
Performs a forward pass through the sharded model.
public Tensor<T> ShardedForward(Tensor<T> input)
Parameters
inputTensor<T>Input tensor.
Returns
- Tensor<T>
Output tensor.
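Example usage (a sketch; sharding requires layers passed to the constructor and a configuration with sharding enabled; input and lossGradient are hypothetical names assumed to exist):
if (memoryManager.IsShardingEnabled)
{
    // Forward pass distributed across the configured shards.
    var output = memoryManager.ShardedForward(input);
    // ... compute the loss and its gradient ...
    var inputGradient = memoryManager.ShardedBackward(lossGradient);
}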
ShouldCheckpoint(int)
Determines if a specific layer should be checkpointed.
public bool ShouldCheckpoint(int layerIndex)
Parameters
layerIndexintIndex of the layer.
Returns
- bool
True if the layer should be checkpointed.