Class GradientCheckpointing<T>
Provides gradient checkpointing functionality for memory-efficient training.
public static class GradientCheckpointing<T>
Type Parameters
T - The numeric data type used for computations (for example, float or double).
Inheritance
object → GradientCheckpointing<T>
Remarks
Gradient checkpointing (also known as activation checkpointing or memory checkpointing) is a technique that trades computation time for memory by not storing all intermediate activations during the forward pass. Instead, it recomputes them during the backward pass.
For Beginners: When training large neural networks, storing all intermediate results (activations) can use a lot of memory. Gradient checkpointing saves memory by:
- Only storing activations at certain "checkpoints"
- During backpropagation, recomputing the activations between checkpoints
This uses less memory but takes more time (roughly 30% more computation). It's essential for training very large models that wouldn't otherwise fit in GPU memory.
This implementation follows patterns from PyTorch's torch.utils.checkpoint and TensorFlow's tf.recompute_grad.
Methods
Checkpoint(Func<ComputationNode<T>>, IEnumerable<ComputationNode<T>>)
Executes a function with gradient checkpointing.
public static ComputationNode<T> Checkpoint(Func<ComputationNode<T>> function, IEnumerable<ComputationNode<T>> inputs)
Parameters
function (Func<ComputationNode<T>>): The function to execute with checkpointing.
inputs (IEnumerable<ComputationNode<T>>): The input nodes to the function.
Returns
- ComputationNode<T>
The output node from the function.
Remarks
The function will be executed during the forward pass, but its intermediate activations will not be saved. During the backward pass, the function will be re-executed to recompute the needed activations.
For Beginners: Wrap parts of your model in this function to save memory:
// Without checkpointing (uses more memory):
var output = layer1.Forward(input);
output = layer2.Forward(output);

// With checkpointing (uses less memory):
var output = GradientCheckpointing<float>.Checkpoint(
    () => {
        var x = layer1.Forward(input);
        return layer2.Forward(x);
    },
    new[] { input }
);
CheckpointMultiOutput(Func<IReadOnlyList<ComputationNode<T>>>, IEnumerable<ComputationNode<T>>)
Executes a function with gradient checkpointing, supporting multiple outputs.
public static IReadOnlyList<ComputationNode<T>> CheckpointMultiOutput(Func<IReadOnlyList<ComputationNode<T>>> function, IEnumerable<ComputationNode<T>> inputs)
Parameters
function (Func<IReadOnlyList<ComputationNode<T>>>): The function to execute with checkpointing.
inputs (IEnumerable<ComputationNode<T>>): The input nodes to the function.
Returns
- IReadOnlyList<ComputationNode<T>>
The output nodes from the function.
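As a sketch of usage (the encoder and classifier layer objects here are illustrative placeholders, following the same pattern as the Checkpoint example above):
// Checkpointing a block that produces two outputs:
var outputs = GradientCheckpointing<float>.CheckpointMultiOutput(
    () => {
        var hidden = encoder.Forward(input);
        var logits = classifier.Forward(hidden);
        return new[] { hidden, logits };
    },
    new[] { input }
);
var hiddenOut = outputs[0];
var logitsOut = outputs[1];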
EstimateMemorySavings(int, long, int)
Estimates memory savings from using gradient checkpointing.
public static (long WithoutCheckpoint, long WithCheckpoint, double SavingsPercent) EstimateMemorySavings(int numLayers, long activationSize, int segmentSize = 2)
Parameters
numLayers (int): Number of layers in the model.
activationSize (long): Size of activations per layer, in bytes.
segmentSize (int): Number of layers per checkpoint segment. Default: 2
Returns
- (long WithoutCheckpoint, long WithCheckpoint, double SavingsPercent)
A tuple of (memory without checkpointing, memory with checkpointing, savings percentage).
Remarks
For Beginners: This helps you estimate how much memory you'll save:
var (without, with, savings) = GradientCheckpointing<float>.EstimateMemorySavings(
    numLayers: 24,
    activationSize: 100_000_000, // 100 MB per layer
    segmentSize: 4
);
Console.WriteLine($"Saves {savings:P1} memory");
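The exact figures depend on this method's implementation, but a common back-of-envelope model (an assumption, not necessarily this library's formula) stores one activation per segment boundary plus one segment's worth of activations during recomputation:
// Back-of-envelope memory model (assumption, not this library's exact formula):
//   without = numLayers * activationSize                    (store everything)
//   with    = (numLayers / segmentSize) * activationSize    (checkpoints)
//           + segmentSize * activationSize                  (recomputed segment)
long numLayers = 24, activationSize = 100_000_000, segmentSize = 4;
long without = numLayers * activationSize;                            // 2.4 GB
long with = (numLayers / segmentSize + segmentSize) * activationSize; // 1.0 GB
Console.WriteLine($"Savings: {1.0 - (double)with / without:P1}");

Under this model, larger segments store fewer checkpoints but recompute more per segment; the sweet spot is near segmentSize ≈ √numLayers.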
SequentialCheckpoint(IReadOnlyList<Func<ComputationNode<T>, ComputationNode<T>>>, ComputationNode<T>, int)
Creates a sequential checkpoint that divides a sequence of layers into segments.
public static ComputationNode<T> SequentialCheckpoint(IReadOnlyList<Func<ComputationNode<T>, ComputationNode<T>>> layers, ComputationNode<T> input, int segmentSize = 2)
Parameters
layers (IReadOnlyList<Func<ComputationNode<T>, ComputationNode<T>>>): The sequence of layer functions to checkpoint.
input (ComputationNode<T>): The input to the first layer.
segmentSize (int): Number of layers per checkpoint segment. Default: 2
Returns
- ComputationNode<T>
The output from the final layer.
Remarks
This is a convenience method for checkpointing sequential models. It automatically divides the layers into segments and applies checkpointing to each segment.
For Beginners: For models with many sequential layers (like ResNet or Transformers), this automatically applies checkpointing efficiently:
var layers = new List<Func<ComputationNode<float>, ComputationNode<float>>>
{
    x => layer1.Forward(x),
    x => layer2.Forward(x),
    x => layer3.Forward(x),
    x => layer4.Forward(x)
};

// Checkpoint every 2 layers
var output = GradientCheckpointing<float>.SequentialCheckpoint(layers, input, segmentSize: 2);