Class LayerCheckpointState<T>
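State for layer-based gradient checkpointing: the checkpoints saved during the forward pass, the activations recomputed from them, the layers that were run, and the final output.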
public class LayerCheckpointState<T>
Type Parameters
T
The numeric type.
Inheritance
object
LayerCheckpointState<T>
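The page lists the members but not how the state is stored. A minimal sketch of an equivalent container, assuming two dictionaries keyed by layer index and using `float[]` as a stand-in for `Tensor<T>`. `SketchCheckpointState` is a hypothetical name, `Layers` is omitted, and this is not the library's implementation:

```csharp
#nullable enable
using System.Collections.Generic;

// Hypothetical stand-in for LayerCheckpointState<T>; float[] replaces Tensor<T>.
// It only mirrors the documented members and is not the library's implementation.
public class SketchCheckpointState
{
    // Activations kept during the forward pass (the checkpoints).
    private readonly Dictionary<int, float[]> _checkpoints = new Dictionary<int, float[]>();

    // Activations rebuilt from a checkpoint during the backward pass.
    private readonly Dictionary<int, float[]> _activations = new Dictionary<int, float[]>();

    public float[]? FinalOutput { get; set; }

    public void SaveCheckpoint(int layerIndex, float[] activation) => _checkpoints[layerIndex] = activation;

    public void SaveActivation(int layerIndex, float[] activation) => _activations[layerIndex] = activation;

    public bool HasCheckpoint(int layerIndex) => _checkpoints.ContainsKey(layerIndex);

    public bool HasActivation(int layerIndex) => _activations.ContainsKey(layerIndex);

    // Returns null instead of throwing when nothing is stored for the index.
    public float[]? GetCheckpoint(int layerIndex) =>
        _checkpoints.TryGetValue(layerIndex, out var t) ? t : null;

    public float[]? GetActivation(int layerIndex) =>
        _activations.TryGetValue(layerIndex, out var t) ? t : null;

    // Drops every stored reference so the arrays can be garbage-collected.
    public void Clear()
    {
        _checkpoints.Clear();
        _activations.Clear();
        FinalOutput = null;
    }
}
```

Returning null from the Get methods, rather than throwing, matches the nullable `Tensor<T>?` return types documented below.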
Properties
FinalOutput
The final output of the forward pass.
public Tensor<T>? FinalOutput { get; set; }
Property Value
- Tensor<T>
Layers
The layers used in the forward pass.
public IReadOnlyList<ILayer<T>>? Layers { get; set; }
Property Value
- IReadOnlyList<ILayer<T>>
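A sketch of how a forward pass might fill this kind of state, under stated assumptions: `Func<float[], float[]>` stands in for `ILayer<T>`, `float[]` for `Tensor<T>`, a plain `Dictionary<int, float[]>` for the checkpoint storage, and a checkpoint at index i is assumed to hold the input to layer i. `CheckpointedForward.Run` and its `interval` parameter are hypothetical:

```csharp
using System;
using System.Collections.Generic;

// Hypothetical sketch of a checkpointed forward pass. Func<float[], float[]> stands in
// for ILayer<T>, float[] for Tensor<T>, and a Dictionary for the checkpoint storage.
public static class CheckpointedForward
{
    // Runs every layer in order but keeps only the inputs to every `interval`-th layer.
    // checkpoints[i] holds the input to layer i, so layer i can be re-run from it later.
    public static float[] Run(
        IReadOnlyList<Func<float[], float[]>> layers,
        float[] input,
        int interval,
        Dictionary<int, float[]> checkpoints)
    {
        if (interval < 1) throw new ArgumentOutOfRangeException(nameof(interval));

        float[] current = input;
        for (int i = 0; i < layers.Count; i++)
        {
            if (i % interval == 0)
            {
                checkpoints[i] = current; // store the input to layer i
            }
            current = layers[i](current); // outputs between checkpoints are not kept
        }
        return current; // plays the role of FinalOutput
    }
}
```

With five layers and `interval` = 2, checkpoints land at indices 0, 2, and 4; every other intermediate output is discarded and has to be recomputed during the backward pass.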
Methods
Clear()
Clears all stored state to free memory.
public void Clear()
GetActivation(int)
Gets the recomputed activation for the given layer index, if one has been saved.
public Tensor<T>? GetActivation(int layerIndex)
Parameters
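layerIndex int
The index of the layer whose activation is requested.
Returns
- Tensor<T>
The recomputed activation, or null if none has been saved for layerIndex.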
GetCheckpoint(int)
Gets the checkpoint for the given layer index, if one has been saved.
public Tensor<T>? GetCheckpoint(int layerIndex)
Parameters
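layerIndex int
The index of the layer whose checkpoint is requested.
Returns
- Tensor<T>
The checkpoint, or null if none has been saved for layerIndex.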
HasActivation(int)
Checks whether a recomputed activation has been saved for the given layer index.
public bool HasActivation(int layerIndex)
Parameters
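layerIndex int
The index of the layer to check.
Returns
- bool
true if a recomputed activation has been saved for layerIndex; otherwise, false.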
HasCheckpoint(int)
Checks whether a checkpoint has been saved for the given layer index.
public bool HasCheckpoint(int layerIndex)
Parameters
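layerIndex int
The index of the layer to check.
Returns
- bool
true if a checkpoint has been saved for layerIndex; otherwise, false.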
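The Has/Get pairs suggest a lookup order for the backward pass: reuse a recomputed activation if one exists, otherwise replay layers from the nearest earlier checkpoint and save what was rebuilt. A hypothetical sketch of that order, with stand-ins (`float[]` for `Tensor<T>`, `Func<float[], float[]>` for `ILayer<T>`, dictionaries for storage) and the assumption that a checkpoint at index i holds the input to layer i. `Recompute.InputTo` is an illustrative name, not a library method:

```csharp
using System;
using System.Collections.Generic;

// Hypothetical sketch of recomputation during the backward pass, with Func<float[], float[]>
// standing in for ILayer<T> and float[] for Tensor<T>. Assumes checkpoints[i] holds the
// input to layer i and that some checkpoint exists at or before the requested index.
public static class Recompute
{
    // Returns the input to layer `target`, recomputing it from the nearest earlier checkpoint.
    public static float[] InputTo(
        int target,
        IReadOnlyList<Func<float[], float[]>> layers,
        Dictionary<int, float[]> checkpoints,
        Dictionary<int, float[]> activations)
    {
        // Mirrors HasActivation/GetActivation: reuse an activation recomputed earlier.
        if (activations.TryGetValue(target, out var cached)) return cached;

        // Mirrors HasCheckpoint/GetCheckpoint: walk back to the closest stored checkpoint.
        int start = target;
        while (!checkpoints.ContainsKey(start))
        {
            start--;
            if (start < 0) throw new InvalidOperationException("No checkpoint at or before layer " + target);
        }

        // Replay the layers between the checkpoint and the target, caching each result
        // the way SaveActivation would.
        float[] current = checkpoints[start];
        for (int i = start; i < target; i++)
        {
            current = layers[i](current);
            activations[i + 1] = current;
        }
        return current;
    }
}
```

Caching every replayed activation trades some memory back for speed: later requests inside the same segment become dictionary lookups instead of another replay.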
SaveActivation(int, Tensor<T>)
Saves a recomputed activation.
public void SaveActivation(int layerIndex, Tensor<T> activation)
Parameters
layerIndex int
The index of the layer the activation belongs to.
activation Tensor<T>
The recomputed activation to store.
SaveCheckpoint(int, Tensor<T>)
Saves a checkpoint at the given layer index.
public void SaveCheckpoint(int layerIndex, Tensor<T> activation)
Parameters
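layerIndex int
The index of the layer at which to save the checkpoint.
activation Tensor<T>
The activation to store as the checkpoint.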
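The page does not say which layer indices receive checkpoints. One common heuristic spaces them about sqrt(n) layers apart for n layers, which keeps roughly sqrt(n) checkpoints plus at most about sqrt(n) recomputed activations per segment in memory. A hypothetical helper computing those indices (`CheckpointPlan.Indices` is illustrative and not confirmed as this class's policy):

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper: picks the layer indices at which SaveCheckpoint would be called,
// spacing them ceil(sqrt(n)) layers apart. Not confirmed as the library's policy.
public static class CheckpointPlan
{
    public static List<int> Indices(int layerCount)
    {
        if (layerCount <= 0) return new List<int>();

        int interval = (int)Math.Ceiling(Math.Sqrt(layerCount));
        var indices = new List<int>();
        for (int i = 0; i < layerCount; i += interval)
        {
            indices.Add(i);
        }
        return indices;
    }
}
```

For 16 layers this yields indices 0, 4, 8, and 12; for 10 layers the interval rounds up to 4 and yields 0, 4, and 8.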