Class LayerCheckpointState<T>

Namespace
AiDotNet.Diffusion.Memory
Assembly
AiDotNet.dll

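Holds the state for layer-based gradient checkpointing: activations saved as checkpoints during the forward pass, activations recomputed from those checkpoints, the layers involved, and the final output.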

public class LayerCheckpointState<T>

Type Parameters

T

The numeric type of the tensor elements.

Inheritance
LayerCheckpointState<T>
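The following is a minimal sketch, under stated assumptions, of how a forward pass might populate this kind of state. It does not use AiDotNet. `SketchCheckpointState` and `CheckpointedForward` are hypothetical stand-ins: `double[]` plays the role of `Tensor<T>`, `Func<double[], double[]>` plays the role of `ILayer<T>`, and a checkpoint at index `i` is assumed to hold the input to layer `i`. Only every `interval`-th input is kept, so far fewer tensors are stored than in a pass that retains every intermediate activation.

```csharp
using System;
using System.Collections.Generic;
#nullable enable

// Hypothetical stand-in for LayerCheckpointState<T>, not the AiDotNet implementation.
// double[] plays the role of Tensor<T>; Func<double[], double[]> plays the role of ILayer<T>.
public sealed class SketchCheckpointState
{
    private readonly Dictionary<int, double[]> _checkpoints = new();
    private readonly Dictionary<int, double[]> _activations = new();

    public IReadOnlyList<Func<double[], double[]>>? Layers { get; set; }
    public double[]? FinalOutput { get; set; }

    public void SaveCheckpoint(int layerIndex, double[] activation) => _checkpoints[layerIndex] = activation;
    public double[]? GetCheckpoint(int layerIndex) =>
        _checkpoints.TryGetValue(layerIndex, out var t) ? t : null;
    public bool HasCheckpoint(int layerIndex) => _checkpoints.ContainsKey(layerIndex);

    public void SaveActivation(int layerIndex, double[] activation) => _activations[layerIndex] = activation;
    public double[]? GetActivation(int layerIndex) =>
        _activations.TryGetValue(layerIndex, out var t) ? t : null;
    public bool HasActivation(int layerIndex) => _activations.ContainsKey(layerIndex);

    // Drops every stored reference so the tensors can be garbage-collected.
    public void Clear()
    {
        _checkpoints.Clear();
        _activations.Clear();
        Layers = null;
        FinalOutput = null;
    }
}

public static class CheckpointedForward
{
    // Runs the layers in order, keeping only the input of layers 0, interval, 2*interval, ...
    public static SketchCheckpointState Run(
        IReadOnlyList<Func<double[], double[]>> layers, double[] input, int interval)
    {
        if (interval < 1) throw new ArgumentOutOfRangeException(nameof(interval));
        var state = new SketchCheckpointState { Layers = layers };
        var current = input;
        for (int i = 0; i < layers.Count; i++)
        {
            if (i % interval == 0)
                state.SaveCheckpoint(i, current); // assumed convention: checkpoint i = input to layer i
            current = layers[i](current);         // other intermediate results are not retained
        }
        state.FinalOutput = current;
        return state;
    }
}
```

With three layers and `interval = 2`, only the inputs to layers 0 and 2 are stored; the input to layer 1 would have to be recomputed from checkpoint 0 when it is needed again.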

Properties

FinalOutput

The final output of the forward pass.

public Tensor<T>? FinalOutput { get; set; }

Property Value

Tensor<T>

Layers

The layers used in the forward pass.

public IReadOnlyList<ILayer<T>>? Layers { get; set; }

Property Value

IReadOnlyList<ILayer<T>>

Methods

Clear()

Clears all stored state to free memory.

public void Clear()

GetActivation(int)

Gets the recomputed activation for the specified layer, or null if none has been saved.

public Tensor<T>? GetActivation(int layerIndex)

Parameters

layerIndex int

Returns

Tensor<T>

GetCheckpoint(int)

Gets the checkpoint saved for the specified layer, or null if none has been saved.

public Tensor<T>? GetCheckpoint(int layerIndex)

Parameters

layerIndex int

Returns

Tensor<T>

HasActivation(int)

Checks whether a recomputed activation exists for the specified layer.

public bool HasActivation(int layerIndex)

Parameters

layerIndex int

Returns

bool

HasCheckpoint(int)

Checks whether a checkpoint exists for the specified layer.

public bool HasCheckpoint(int layerIndex)

Parameters

layerIndex int

Returns

bool

SaveActivation(int, Tensor<T>)

Saves a recomputed activation for the specified layer.

public void SaveActivation(int layerIndex, Tensor<T> activation)

Parameters

layerIndex int
activation Tensor<T>
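When an activation that was not kept is needed again, it has to be recomputed from the nearest earlier checkpoint, and saving the recomputed values lets later lookups reuse them. The following is a minimal sketch of that get-or-recompute step. `RecomputeSketchState` and `Recompute.InputOf` are hypothetical names, `double[]` stands in for `Tensor<T>`, and both a checkpoint and an activation at index `i` are assumed to hold the input to layer `i`. This is an illustration of the technique, not necessarily how AiDotNet performs recomputation.

```csharp
using System;
using System.Collections.Generic;
#nullable enable

// Hypothetical, trimmed-down stand-in for the checkpoint/activation storage of
// LayerCheckpointState<T> (double[] plays the role of Tensor<T>).
public sealed class RecomputeSketchState
{
    private readonly Dictionary<int, double[]> _checkpoints = new();
    private readonly Dictionary<int, double[]> _activations = new();

    public void SaveCheckpoint(int layerIndex, double[] activation) => _checkpoints[layerIndex] = activation;
    public double[]? GetCheckpoint(int layerIndex) =>
        _checkpoints.TryGetValue(layerIndex, out var t) ? t : null;
    public void SaveActivation(int layerIndex, double[] activation) => _activations[layerIndex] = activation;
    public double[]? GetActivation(int layerIndex) =>
        _activations.TryGetValue(layerIndex, out var t) ? t : null;
    public bool HasActivation(int layerIndex) => _activations.ContainsKey(layerIndex);
}

public static class Recompute
{
    // Returns the input to layer `layerIndex`, recomputing it from the nearest
    // checkpoint at or before that index unless it was already recomputed.
    public static double[] InputOf(
        RecomputeSketchState state, IReadOnlyList<Func<double[], double[]>> layers, int layerIndex)
    {
        if (state.HasActivation(layerIndex))
            return state.GetActivation(layerIndex)!;

        // Walk back to the closest checkpoint at or before layerIndex.
        int start = layerIndex;
        double[]? current = null;
        while (start >= 0 && (current = state.GetCheckpoint(start)) == null)
            start--;
        if (current == null)
            throw new InvalidOperationException($"No checkpoint at or before layer {layerIndex}.");

        // Replay the segment, caching each recomputed input via SaveActivation.
        for (int i = start; i < layerIndex; i++)
        {
            current = layers[i](current);
            state.SaveActivation(i + 1, current);
        }
        return current;
    }
}
```

Recomputation always starts from the closest checkpoint at or before the requested index, so the cost of a single lookup is bounded by the distance between checkpoints.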

SaveCheckpoint(int, Tensor<T>)

Saves a checkpoint at the given layer index.

public void SaveCheckpoint(int layerIndex, Tensor<T> activation)

Parameters

layerIndex int
activation Tensor<T>
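How often checkpoints are saved sets the memory/compute trade-off. Under a common simplified cost model (an assumption here, not something this class computes), keeping a checkpoint every k of n layers stores about ceil(n / k) checkpoints plus up to k recomputed activations for the segment being processed, so the peak is roughly n / k + k, which is smallest near k = sqrt(n). The hypothetical `CheckpointInterval` helper below evaluates that model.

```csharp
using System;

// Hypothetical helper evaluating a simplified cost model of layer checkpointing:
// peak stored tensors = ceil(n / k) checkpoints + at most k activations of one segment.
public static class CheckpointInterval
{
    public static int PeakStoredTensors(int layerCount, int interval)
    {
        if (layerCount < 1 || interval < 1)
            throw new ArgumentOutOfRangeException(layerCount < 1 ? nameof(layerCount) : nameof(interval));
        int checkpoints = (layerCount + interval - 1) / interval; // ceil(n / k)
        int segment = Math.Min(interval, layerCount);             // activations recomputed at once
        return checkpoints + segment;
    }

    // Tries every interval from 1 to n and returns the first one with the smallest modeled peak.
    public static int BestInterval(int layerCount)
    {
        int best = 1;
        for (int k = 2; k <= layerCount; k++)
            if (PeakStoredTensors(layerCount, k) < PeakStoredTensors(layerCount, best))
                best = k;
        return best;
    }
}
```

For 16 layers the model gives a peak of 8 stored tensors at k = 4, compared with about 16 if every activation were kept.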