
Class GradientCheckpointing<T>

Namespace
AiDotNet.Autodiff
Assembly
AiDotNet.dll

Provides gradient checkpointing functionality for memory-efficient training.

public static class GradientCheckpointing<T>

Type Parameters

T
Inheritance
GradientCheckpointing<T>

Remarks

Gradient checkpointing (also known as activation checkpointing or memory checkpointing) is a technique that trades computation time for memory. Instead of storing every intermediate activation during the forward pass, it keeps only a subset and recomputes the rest during the backward pass.

For Beginners: When training large neural networks, storing all intermediate results (activations) can use a lot of memory. Gradient checkpointing saves memory by:

  1. Only storing activations at certain "checkpoints"
  2. During backpropagation, recomputing the activations between checkpoints

This uses less memory but takes more time, because the layers between checkpoints are effectively run forward twice (roughly 30% more computation overall). It is often what makes it possible to train very large models that would not otherwise fit in GPU memory.
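To make the trade-off concrete, here is a minimal sketch in plain C# that applies segment checkpointing to a chain of scalar functions instead of tensors. The names `ScalarLayer`, `CheckpointSketch`, `FullMemory`, and `Checkpointed` are hypothetical and are not part of AiDotNet. `FullMemory` keeps every activation; `Checkpointed` keeps only the input of each segment and recomputes a segment's activations just before back-propagating through it, so both should return the same output and gradient.

```csharp
using System;
using System.Collections.Generic;

// One layer of a scalar chain: its forward function and its local derivative.
// Hypothetical type for illustration only; not part of AiDotNet.
public sealed record ScalarLayer(Func<double, double> Forward, Func<double, double> Derivative);

public static class CheckpointSketch
{
    // Baseline: keep every activation from the forward pass, then apply the
    // chain rule from the last layer back to the first.
    public static (double Output, double Grad, int Stored) FullMemory(
        IReadOnlyList<ScalarLayer> layers, double x)
    {
        var acts = new List<double> { x };                 // acts[i] = input to layer i
        foreach (var layer in layers)
            acts.Add(layer.Forward(acts[^1]));

        double grad = 1.0;                                 // d(output)/d(output)
        for (int i = layers.Count - 1; i >= 0; i--)
            grad *= layers[i].Derivative(acts[i]);
        return (acts[^1], grad, acts.Count);
    }

    // Checkpointed: keep only the input of each segment during the forward pass,
    // then recompute a segment's activations right before back-propagating through it.
    public static (double Output, double Grad, int Stored) Checkpointed(
        IReadOnlyList<ScalarLayer> layers, double x, int segmentSize)
    {
        if (segmentSize < 1) throw new ArgumentOutOfRangeException(nameof(segmentSize));

        var checkpoints = new List<double>();              // checkpoints[s] = input to segment s
        double h = x;
        for (int i = 0; i < layers.Count; i++)
        {
            if (i % segmentSize == 0) checkpoints.Add(h);
            h = layers[i].Forward(h);                      // other intermediates are not kept
        }

        double grad = 1.0;
        for (int s = checkpoints.Count - 1; s >= 0; s--)
        {
            int start = s * segmentSize;
            int end = Math.Min(start + segmentSize, layers.Count);

            // Recompute the inputs of layers start..end-1 from the checkpoint.
            var acts = new List<double> { checkpoints[s] };
            for (int i = start; i < end - 1; i++)
                acts.Add(layers[i].Forward(acts[^1]));

            for (int i = end - 1; i >= start; i--)
                grad *= layers[i].Derivative(acts[i - start]);
        }
        return (h, grad, checkpoints.Count);
    }
}
```

For a five-layer chain with a segment size of 2, `FullMemory` keeps 6 values between the forward and backward passes while `Checkpointed` keeps 3 (the inputs to layers 0, 2, and 4); in exchange, layers 0 and 2 are evaluated a second time during the backward pass. A tensor-based implementation would typically also free each recomputed segment once its gradients are computed, but the bookkeeping follows the same pattern.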

This implementation follows patterns from PyTorch's torch.utils.checkpoint and TensorFlow's tf.recompute_grad.

Methods

Checkpoint(Func<ComputationNode<T>>, IEnumerable<ComputationNode<T>>)

Executes a function with gradient checkpointing.

public static ComputationNode<T> Checkpoint(Func<ComputationNode<T>> function, IEnumerable<ComputationNode<T>> inputs)

Parameters

function Func<ComputationNode<T>>

The function to execute with checkpointing.

inputs IEnumerable<ComputationNode<T>>

The input nodes to the function.

Returns

ComputationNode<T>

The output node from the function.

Remarks

The function will be executed during the forward pass, but its intermediate activations will not be saved. During the backward pass, the function will be re-executed to recompute the needed activations.

For Beginners: Wrap parts of your model in this function to save memory:

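// layer1, layer2 and input (a ComputationNode<float>) are assumed to be defined elsewhere.

// Without checkpointing (uses more memory):
var output = layer1.Forward(input);
output = layer2.Forward(output);

// With checkpointing (uses less memory); the lambda is re-run during backpropagation:
var checkpointedOutput = GradientCheckpointing<float>.Checkpoint(
    () => {
        var x = layer1.Forward(input);
        return layer2.Forward(x);
    },
    new[] { input }
);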
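The sketch below is a hypothetical, scalar-only model of what a call like `Checkpoint` does; it is not the AiDotNet implementation, and `CheckpointedCall` and `CheckpointedCallDemo` are invented names. The wrapped function returns its output together with a closure that maps an upstream gradient to an input gradient. That closure captures the intermediate value, so discarding it after the forward pass is what saves memory, and re-running the function in `Backward` is what costs time.

```csharp
using System;

// Hypothetical scalar model of Checkpoint (not the AiDotNet implementation).
// The wrapped function returns its output plus a closure that turns an upstream
// gradient into an input gradient; the closure captures the intermediates.
public sealed class CheckpointedCall
{
    private readonly Func<double, (double Output, Func<double, double> Backward)> _function;
    private readonly double _input;                  // the only value retained for backward

    public double Output { get; }

    public CheckpointedCall(
        Func<double, (double Output, Func<double, double> Backward)> function, double input)
    {
        _function = function;
        _input = input;
        Output = function(input).Output;             // forward: the closure (and intermediates) is discarded
    }

    // Backward: re-run the function to rebuild the intermediates, then use them.
    public double Backward(double upstreamGrad)
    {
        var (_, backward) = _function(_input);
        return backward(upstreamGrad);
    }
}

public static class CheckpointedCallDemo
{
    // f(x) = sin(x^2), with u = x^2 as the intermediate activation.
    public static (int Calls, double Output, double Grad) Run(double x)
    {
        int calls = 0;
        var node = new CheckpointedCall(v =>
        {
            calls++;
            double u = v * v;                        // intermediate, recomputed on each call
            return (Math.Sin(u), g => g * Math.Cos(u) * 2 * v);
        }, x);
        double grad = node.Backward(1.0);
        return (calls, node.Output, grad);
    }
}
```

`CheckpointedCallDemo.Run` reports that the wrapped function ran twice: once in the forward pass and once during `Backward`. Since the function passed to `Checkpoint` is likewise re-executed during the backward pass, it is generally safest to keep it deterministic and free of side effects; for instance, a function that draws fresh random numbers on each call could recompute different activations from the ones used in the forward pass.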

CheckpointMultiOutput(Func<IReadOnlyList<ComputationNode<T>>>, IEnumerable<ComputationNode<T>>)

Executes a function with gradient checkpointing, supporting multiple outputs.

public static IReadOnlyList<ComputationNode<T>> CheckpointMultiOutput(Func<IReadOnlyList<ComputationNode<T>>> function, IEnumerable<ComputationNode<T>> inputs)

Parameters

function Func<IReadOnlyList<ComputationNode<T>>>

The function to execute with checkpointing.

inputs IEnumerable<ComputationNode<T>>

The input nodes to the function.

Returns

IReadOnlyList<ComputationNode<T>>

The output nodes from the function.
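`CheckpointMultiOutput` has no further remarks here, so the following is only a hypothetical scalar sketch of how a multi-output checkpoint can work; `MultiOutputFunction`, `CheckpointedMultiCall`, and `CheckpointedMultiCallDemo` are invented for illustration. The function is re-executed once during the backward pass, that single recomputation serves every output, and the gradient with respect to the input is the sum of the contributions flowing back from each output.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical multi-output variant (not the AiDotNet API): the function returns
// several outputs plus a closure that takes one upstream gradient per output.
public delegate (IReadOnlyList<double> Outputs, Func<IReadOnlyList<double>, double> Backward)
    MultiOutputFunction(double input);

public sealed class CheckpointedMultiCall
{
    private readonly MultiOutputFunction _function;
    private readonly double _input;

    public IReadOnlyList<double> Outputs { get; }

    public CheckpointedMultiCall(MultiOutputFunction function, double input)
    {
        _function = function;
        _input = input;
        Outputs = function(input).Outputs;           // forward only; intermediates are dropped
    }

    // A single re-execution rebuilds the intermediates shared by all outputs;
    // the closure then combines the upstream gradients of every output.
    public double Backward(IReadOnlyList<double> upstreamGrads)
    {
        if (upstreamGrads.Count != Outputs.Count)
            throw new ArgumentException("Expected one upstream gradient per output.", nameof(upstreamGrads));
        var (_, backward) = _function(_input);
        return backward(upstreamGrads);
    }
}

public static class CheckpointedMultiCallDemo
{
    // Two outputs share h = x + 1: y0 = h^2 and y1 = 3h.
    // dL/dx = g0 * 2h + g1 * 3 (contributions from both outputs are summed).
    public static (int Calls, IReadOnlyList<double> Outputs, double Grad) Run(double x, double g0, double g1)
    {
        int calls = 0;
        var node = new CheckpointedMultiCall(v =>
        {
            calls++;
            double h = v + 1.0;                      // shared intermediate
            return (new[] { h * h, 3.0 * h }, g => g[0] * 2.0 * h + g[1] * 3.0);
        }, x);
        double grad = node.Backward(new[] { g0, g1 });
        return (calls, node.Outputs, grad);
    }
}
```

With `x = 2` and an upstream gradient of 1 for both outputs, both outputs equal 9 and the input gradient is 2 * 3 + 3 = 9, obtained from one recomputation rather than one per output.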

EstimateMemorySavings(int, long, int)

Estimates memory savings from using gradient checkpointing.

public static (long WithoutCheckpoint, long WithCheckpoint, double SavingsPercent) EstimateMemorySavings(int numLayers, long activationSize, int segmentSize = 2)

Parameters

numLayers int

Number of layers in the model.

activationSize long

Size of activations per layer in bytes.

segmentSize int

Number of layers per checkpoint segment. Default: 2

Returns

(long WithoutCheckpoint, long WithCheckpoint, double SavingsPercent)

A tuple of (memory without checkpointing, memory with checkpointing, savings percentage).

Remarks

For Beginners: This helps you estimate how much memory you'll save:

var (bytesWithout, bytesWith, savings) = GradientCheckpointing<float>.EstimateMemorySavings(
    numLayers: 24,
    activationSize: 100_000_000,  // 100 MB per layer
    segmentSize: 4
);
Console.WriteLine($"Saves {savings:P1} memory");
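The page does not state the formula behind `EstimateMemorySavings`, so the sketch below uses a common rough model as an explicit assumption: without checkpointing, one activation per layer stays in memory; with checkpointing, one activation per segment boundary is kept, plus the activations of the single segment being recomputed. `MemoryModel.Estimate` is a hypothetical stand-in whose results may differ from the library's, and it returns the savings as a fraction between 0 and 1, which is the kind of value the `P` format specifier expects (it multiplies by 100).

```csharp
using System;

// Hypothetical memory model (an assumption; EstimateMemorySavings may use a
// different formula). Without checkpointing, one activation per layer is kept.
// With checkpointing, one activation per segment boundary is kept, plus the
// activations of the single segment being recomputed during the backward pass.
public static class MemoryModel
{
    public static (long WithoutCheckpoint, long WithCheckpoint, double SavingsFraction) Estimate(
        int numLayers, long activationSize, int segmentSize = 2)
    {
        if (numLayers < 1) throw new ArgumentOutOfRangeException(nameof(numLayers));
        if (segmentSize < 1) throw new ArgumentOutOfRangeException(nameof(segmentSize));

        long without = numLayers * activationSize;
        int segments = (numLayers + segmentSize - 1) / segmentSize;   // ceiling division
        int live = segments + Math.Min(segmentSize, numLayers);       // checkpoints + one segment
        long withCkpt = Math.Min(live * activationSize, without);     // capped at the baseline
        double savings = 1.0 - (double)withCkpt / without;            // fraction in [0, 1]
        return (without, withCkpt, savings);
    }
}
```

For the example above (24 layers, 100,000,000 bytes per layer, segments of 4), this model gives 2,400,000,000 bytes without checkpointing and (6 + 4) × 100,000,000 = 1,000,000,000 bytes with it, a saving of about 58.3%. Because the stored count is roughly numLayers / segmentSize + segmentSize, the model is smallest when segmentSize is close to the square root of numLayers.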

SequentialCheckpoint(IReadOnlyList<Func<ComputationNode<T>, ComputationNode<T>>>, ComputationNode<T>, int)

Creates a sequential checkpoint that divides a sequence of layers into segments.

public static ComputationNode<T> SequentialCheckpoint(IReadOnlyList<Func<ComputationNode<T>, ComputationNode<T>>> layers, ComputationNode<T> input, int segmentSize = 2)

Parameters

layers IReadOnlyList<Func<ComputationNode<T>, ComputationNode<T>>>

The sequence of layer functions to checkpoint.

input ComputationNode<T>

The input to the first layer.

segmentSize int

Number of layers per checkpoint segment. Default: 2

Returns

ComputationNode<T>

The output from the final layer.

Remarks

This is a convenience method for checkpointing sequential models. It automatically divides the layers into segments and applies checkpointing to each segment.

For Beginners: For models built from many sequential layers (such as the stacked blocks of a ResNet or Transformer), this splits the layers into segments and checkpoints each segment for you:

var layers = new List<Func<ComputationNode<float>, ComputationNode<float>>>
{
    x => layer1.Forward(x),
    x => layer2.Forward(x),
    x => layer3.Forward(x),
    x => layer4.Forward(x)
};

// Checkpoint every 2 layers
var output = GradientCheckpointing<float>.SequentialCheckpoint(layers, input, segmentSize: 2);
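The segmentation step can be pictured with the hypothetical sketch below, which uses `Func<double, double>` in place of `Func<ComputationNode<T>, ComputationNode<T>>`. `SequentialSketch`, `BuildSegments`, and `Run` are invented names. The sketch only groups consecutive layers into chunks of `segmentSize` and composes each chunk into one function; in `SequentialCheckpoint` each such segment would additionally run under a checkpoint, which is omitted here.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical sketch of the segmentation behind SequentialCheckpoint:
// consecutive layers are grouped into chunks of segmentSize, and each chunk is
// composed into a single function (the unit that would be checkpointed).
public static class SequentialSketch
{
    public static List<Func<double, double>> BuildSegments(
        IReadOnlyList<Func<double, double>> layers, int segmentSize = 2)
    {
        if (segmentSize < 1) throw new ArgumentOutOfRangeException(nameof(segmentSize));

        var segments = new List<Func<double, double>>();
        for (int start = 0; start < layers.Count; start += segmentSize)
        {
            // Copy this chunk so the closure does not depend on the loop variable.
            var chunk = layers.Skip(start).Take(segmentSize).ToArray();
            segments.Add(x =>
            {
                foreach (var layer in chunk) x = layer(x);
                return x;
            });
        }
        return segments;
    }

    public static double Run(IReadOnlyList<Func<double, double>> layers, double input, int segmentSize = 2)
    {
        double h = input;
        foreach (var segment in BuildSegments(layers, segmentSize))
            h = segment(h);   // the real method would run each segment under a checkpoint
        return h;
    }
}
```

With five layers and `segmentSize: 2`, `BuildSegments` returns three segments of 2, 2, and 1 layers. Grouping does not change the result: for the layers x + 1, 2x, x - 3, x * x, and x / 2, running from an input of 1.0 gives 0.5 whether the segment size is 2 or 3. How AiDotNet itself handles a final, shorter segment is not specified on this page.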