Class CheckpointingExtensions
Provides extension methods for gradient checkpointing on computation nodes.
public static class CheckpointingExtensions
Inheritance
- object
- CheckpointingExtensions
Methods
WithCheckpoint<T>(ComputationNode<T>, Func<ComputationNode<T>, ComputationNode<T>>)
Wraps a computation with gradient checkpointing.
public static ComputationNode<T> WithCheckpoint<T>(this ComputationNode<T> input, Func<ComputationNode<T>, ComputationNode<T>> function)
Parameters
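- input (ComputationNode<T>): The input node passed to function.
- function (Func<ComputationNode<T>, ComputationNode<T>>): The computation to checkpoint. Its intermediate activations are not kept after the forward pass; they are recomputed from input during the backward pass.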
Returns
- ComputationNode<T>
The result of applying function to input, computed under a gradient checkpoint.
Type Parameters
T: The numeric type (for example, float).
Remarks
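For Beginners: This method is a shorter way to call GradientCheckpointing<T>.Checkpoint. The two forms below are equivalent:
// Assume `input` is a ComputationNode<float> and `layer` maps a ComputationNode<float> to a ComputationNode<float>.
// Instead of: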
var output = GradientCheckpointing<float>.Checkpoint(() => layer(input), new[] { input });
// You can write:
var output = input.WithCheckpoint(x => layer(x));
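To make the memory/compute trade-off concrete, here is a minimal scalar sketch of what checkpointing a single function means in general. It does not use ComputationNode<T> or GradientCheckpointing<T>; the `layer` list of scalar ops and the `CheckpointForward` and `CheckpointBackward` helpers are hypothetical stand-ins. The forward pass keeps only the input, and the backward pass recomputes the inner activations from it before applying the chain rule. It illustrates the technique, not this library's implementation.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for `layer`: a chain of scalar ops, each paired with its derivative.
var layer = new List<(Func<double, double> F, Func<double, double> Df)>
{
    (x => 2.0 * x,      _ => 2.0),
    (x => x * x,        x => 2.0 * x),
    (x => Math.Sin(x),  x => Math.Cos(x)),
};

int forwardEvals = 0; // counts every forward evaluation of an op, including recomputation

// Forward pass under a checkpoint: compute the output, but do not keep intermediates.
double CheckpointForward(double input)
{
    double x = input;
    foreach (var op in layer) { forwardEvals++; x = op.F(x); } // intermediates are dropped
    return x;
}

// Backward pass: recompute the inner activations from the saved input,
// then apply the chain rule from the last op back to the first.
double CheckpointBackward(double savedInput, double upstreamGrad)
{
    var acts = new List<double> { savedInput }; // acts[i] = input of op i
    for (int i = 0; i < layer.Count - 1; i++) { forwardEvals++; acts.Add(layer[i].F(acts[^1])); }
    double grad = upstreamGrad;
    for (int i = layer.Count - 1; i >= 0; i--) grad *= layer[i].Df(acts[i]);
    return grad;
}

double x0 = 0.5;
double output = CheckpointForward(x0); // only x0 must survive until the backward pass
double inputGrad = CheckpointBackward(x0, upstreamGrad: 1.0);
Console.WriteLine($"output={output}, d(output)/d(input)={inputGrad}, forwardEvals={forwardEvals}");
```

The composed function is sin(4x²), so at x = 0.5 the gradient is 8x·cos(4x²) = 4·cos(1). The price of not storing intermediates is two extra op evaluations during the backward pass.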
WithSequentialCheckpoint<T>(ComputationNode<T>, IReadOnlyList<Func<ComputationNode<T>, ComputationNode<T>>>, int)
Applies a sequence of functions in order, grouping consecutive functions into checkpoint segments of segmentSize functions each.
public static ComputationNode<T> WithSequentialCheckpoint<T>(this ComputationNode<T> input, IReadOnlyList<Func<ComputationNode<T>, ComputationNode<T>>> functions, int segmentSize = 2)
Parameters
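- input (ComputationNode<T>): The input node passed to the first function.
- functions (IReadOnlyList<Func<ComputationNode<T>, ComputationNode<T>>>): The functions to apply, in order; each function receives the output of the previous one.
- segmentSize (int): The number of functions per checkpoint segment. Defaults to 2.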
Returns
- ComputationNode<T>
The output of the last function in the sequence.
Type Parameters
T: The numeric type (for example, float).
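Segmented checkpointing can be illustrated with a minimal scalar sketch. It assumes consecutive segments of segmentSize functions (the last segment may be shorter) and does not use ComputationNode<T>; the `functions` list of scalar ops and the `Forward`, `Backward` and `ReferenceGrad` helpers are hypothetical. The forward pass saves only each segment's input; the backward pass visits segments last to first, recomputes the activations inside each one from its checkpoint, and applies the chain rule. This shows the general technique, not necessarily how WithSequentialCheckpoint is implemented.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-ins for `functions`: scalar ops, each paired with its derivative.
var functions = new List<(Func<double, double> F, Func<double, double> Df)>
{
    (x => x + 1.0,      _ => 1.0),
    (x => 0.5 * x,      _ => 0.5),
    (x => x * x,        x => 2.0 * x),
    (x => Math.Sin(x),  x => Math.Cos(x)),
    (x => Math.Tanh(x), x => 1.0 - Math.Tanh(x) * Math.Tanh(x)),
};

int forwardEvals = 0; // counts every forward evaluation, including recomputation

// Forward pass: apply every function but save only the input of each segment.
(double Output, List<double> Checkpoints) Forward(double input, int segmentSize)
{
    if (segmentSize < 1) throw new ArgumentOutOfRangeException(nameof(segmentSize));
    var checkpoints = new List<double>();
    double x = input;
    for (int i = 0; i < functions.Count; i++)
    {
        if (i % segmentSize == 0) checkpoints.Add(x); // start of a new segment
        forwardEvals++;
        x = functions[i].F(x);
    }
    return (x, checkpoints);
}

// Backward pass: visit segments last-to-first, recompute the activations inside
// each segment from its checkpoint, then apply the chain rule through it.
double Backward(List<double> checkpoints, int segmentSize)
{
    double grad = 1.0; // d(output)/d(output)
    for (int s = checkpoints.Count - 1; s >= 0; s--)
    {
        int start = s * segmentSize;
        int end = Math.Min(start + segmentSize, functions.Count);
        var acts = new List<double> { checkpoints[s] }; // acts[k] = input of function start + k
        for (int i = start; i < end - 1; i++)
        {
            forwardEvals++;
            acts.Add(functions[i].F(acts[^1]));
        }
        for (int i = end - 1; i >= start; i--) grad *= functions[i].Df(acts[i - start]);
    }
    return grad;
}

// Reference gradient without checkpointing, for comparison.
double ReferenceGrad(double input)
{
    double x = input, grad = 1.0;
    foreach (var fn in functions) { grad *= fn.Df(x); x = fn.F(x); }
    return grad;
}

var (result, saved) = Forward(0.3, segmentSize: 2);
double gradient = Backward(saved, segmentSize: 2);
Console.WriteLine($"output={result}, grad={gradient}, checkpoints={saved.Count}, forwardEvals={forwardEvals}");
```

With five functions and segmentSize = 2, only three values are saved between the passes, and the backward pass recomputes two activations. A larger segmentSize saves fewer checkpoints but recomputes more activations inside each segment.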