
Class CheckpointingExtensions

Namespace: AiDotNet.Autodiff
Assembly: AiDotNet.dll

Provides extension methods for gradient checkpointing on computation nodes.

public static class CheckpointingExtensions
Inheritance
object → CheckpointingExtensions

Methods

WithCheckpoint<T>(ComputationNode<T>, Func<ComputationNode<T>, ComputationNode<T>>)

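Wraps a computation with gradient checkpointing. The computation's intermediate values are not kept in memory after the forward pass. They are recomputed from the input during backpropagation, which trades extra computation for lower memory use.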

public static ComputationNode<T> WithCheckpoint<T>(this ComputationNode<T> input, Func<ComputationNode<T>, ComputationNode<T>> function)

Parameters

input ComputationNode<T>

The input node.

function Func<ComputationNode<T>, ComputationNode<T>>

The computation to checkpoint. It receives the input node and returns the output node.

Returns

ComputationNode<T>

The output node produced by function, with gradient checkpointing applied.

Type Parameters

T

The numeric type.

Remarks

For Beginners: WithCheckpoint is a shorter way to write a call to GradientCheckpointing<T>.Checkpoint:

// Instead of:
var output = GradientCheckpointing<float>.Checkpoint(() => layer(input), new[] { input });

// You can write:
var output = input.WithCheckpoint(x => layer(x));

WithSequentialCheckpoint<T>(ComputationNode<T>, IReadOnlyList<Func<ComputationNode<T>, ComputationNode<T>>>, int)

Applies a sequence of functions to the input in order, with gradient checkpointing applied to segments of segmentSize functions.

public static ComputationNode<T> WithSequentialCheckpoint<T>(this ComputationNode<T> input, IReadOnlyList<Func<ComputationNode<T>, ComputationNode<T>>> functions, int segmentSize = 2)

Parameters

input ComputationNode<T>

The input node.

functions IReadOnlyList<Func<ComputationNode<T>, ComputationNode<T>>>

The functions to apply, in order. Each function receives the output of the previous one.

segmentSize int

The number of functions grouped into each checkpoint segment. Defaults to 2.

Returns

ComputationNode<T>

The output of the last function in the sequence.

Type Parameters

T

The numeric type.
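The following C# sketch illustrates how segmentSize might group a list of functions into checkpoint segments. It assumes the list is split into consecutive chunks of segmentSize functions, with a shorter final chunk when the count does not divide evenly; this documentation does not specify the exact partitioning. SplitIntoSegments and RunSegment are hypothetical helpers, not part of AiDotNet, and plain double values stand in for ComputationNode<T>, so no gradients or recomputation actually take place.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper (not AiDotNet's implementation): splits the functions
// into consecutive segments of at most segmentSize functions each.
List<List<Func<double, double>>> SplitIntoSegments(
    IReadOnlyList<Func<double, double>> functions, int segmentSize)
{
    if (segmentSize < 1)
        throw new ArgumentOutOfRangeException(nameof(segmentSize));

    var segments = new List<List<Func<double, double>>>();
    for (int start = 0; start < functions.Count; start += segmentSize)
    {
        var segment = new List<Func<double, double>>();
        int end = Math.Min(start + segmentSize, functions.Count);
        for (int i = start; i < end; i++)
            segment.Add(functions[i]);
        segments.Add(segment);
    }
    return segments;
}

// Hypothetical helper: runs one segment's functions in order. With real
// checkpointing, only the segment's input would be stored, and the values
// computed inside the segment would be recomputed during backpropagation.
double RunSegment(List<Func<double, double>> segment, double x)
{
    foreach (var f in segment)
        x = f(x);
    return x;
}

// Five stand-in "layers" operating on plain doubles.
var layers = new List<Func<double, double>>
{
    x => x + 1, x => x * 2, x => x - 3, x => x * x, x => x / 2
};

var segs = SplitIntoSegments(layers, segmentSize: 2);

double value = 1.0;
var checkpoints = new List<double>(); // input of each segment, kept in memory
foreach (var seg in segs)
{
    checkpoints.Add(value);
    value = RunSegment(seg, value);
}

Console.WriteLine($"segments={segs.Count}, output={value}");
```

With five functions and segmentSize set to 2, the sketch produces three segments and keeps only the three segment inputs (1, 4 and 1) rather than every intermediate result. Under this scheme, larger segments keep fewer boundary values but must hold more intermediate results at once when a segment is recomputed during the backward pass.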