Class DiffusionMemoryManager<T>
Memory management utilities for diffusion models including gradient checkpointing, activation pooling, and model sharding integration.
public class DiffusionMemoryManager<T>
Type Parameters
T — The numeric type used for calculations.
Inheritance
object → DiffusionMemoryManager<T>
Remarks
This class provides memory-efficient training utilities specifically designed for large diffusion models (UNet, VAE, etc.) that may not fit in GPU memory during training.
For Beginners: Training large models uses a lot of memory because we need to store:
1. The model parameters
2. Intermediate activations (the outputs from each layer during the forward pass)
3. Gradients for each parameter
This class helps reduce memory usage through several techniques:
Gradient Checkpointing:
- Instead of storing all activations, only store "checkpoints"
- During backward pass, recompute activations between checkpoints
- Trades ~30% more compute time for ~50% less memory
Activation Pooling:
- Reuse tensor memory instead of allocating new tensors
- Reduces GC pressure and memory fragmentation
Model Sharding:
- Split large models across multiple GPUs
- Each GPU only holds part of the model
Constructors
DiffusionMemoryManager(DiffusionMemoryConfig?, IEnumerable<ILayer<T>>?)
Initializes a new instance of the DiffusionMemoryManager class.
public DiffusionMemoryManager(DiffusionMemoryConfig? config = null, IEnumerable<ILayer<T>>? layers = null)
Parameters
config DiffusionMemoryConfig — Memory configuration options.
layers IEnumerable<ILayer<T>> — Optional layers for model sharding.
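A minimal construction sketch. The `unet.Layers` collection is a hypothetical source of layers, not a documented member; a default-constructed DiffusionMemoryConfig is assumed to be valid.

```csharp
// Configuration; see DiffusionMemoryConfig for the available options.
var config = new DiffusionMemoryConfig();

// Without layers: checkpointing and pooling only.
var manager = new DiffusionMemoryManager<float>(config);

// With layers: also enables model sharding across devices.
// `unet.Layers` is a hypothetical IEnumerable<ILayer<float>>.
var shardedManager = new DiffusionMemoryManager<float>(config, unet.Layers);
```

Both parameters default to null, so `new DiffusionMemoryManager<float>()` is also valid and uses default configuration without sharding.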
Properties
CheckpointingEnabled
Whether gradient checkpointing is enabled.
public bool CheckpointingEnabled { get; }
Property Value
bool
Config
Memory configuration.
public DiffusionMemoryConfig Config { get; }
Property Value
DiffusionMemoryConfig
PoolingEnabled
Whether activation pooling is enabled.
public bool PoolingEnabled { get; }
Property Value
bool
ShardingEnabled
Whether model sharding is active.
public bool ShardingEnabled { get; }
Property Value
bool
Methods
BackwardWithCheckpointing(Tensor<T>, LayerCheckpointState<T>)
Performs backward pass with checkpointing, recomputing activations as needed.
public Tensor<T> BackwardWithCheckpointing(Tensor<T> outputGradient, LayerCheckpointState<T> state)
Parameters
outputGradient Tensor<T> — Gradient from the subsequent layer.
state LayerCheckpointState<T> — Checkpoint state from the forward pass.
Returns
- Tensor<T>
Gradient with respect to input.
Checkpoint(Func<ComputationNode<T>>, IEnumerable<ComputationNode<T>>)
Wraps a function with gradient checkpointing for memory-efficient training.
public ComputationNode<T> Checkpoint(Func<ComputationNode<T>> function, IEnumerable<ComputationNode<T>> inputs)
Parameters
function Func<ComputationNode<T>> — The function to execute with checkpointing.
inputs IEnumerable<ComputationNode<T>> — The input computation nodes.
Returns
- ComputationNode<T>
The checkpointed output node.
Remarks
For Beginners: Use this to wrap expensive computations (like attention blocks).
// Without checkpointing (stores all activations):
var output = attentionBlock.Forward(input);

// With checkpointing (recomputes during backward):
var checkpointedOutput = memoryManager.Checkpoint(
    () => attentionBlock.Forward(inputNode),
    new[] { inputNode });
CheckpointSequence(IReadOnlyList<Func<ComputationNode<T>, ComputationNode<T>>>, ComputationNode<T>)
Applies checkpointing to a sequence of layer functions.
public ComputationNode<T> CheckpointSequence(IReadOnlyList<Func<ComputationNode<T>, ComputationNode<T>>> layers, ComputationNode<T> input)
Parameters
layers IReadOnlyList<Func<ComputationNode<T>, ComputationNode<T>>> — The sequence of layer forward functions.
input ComputationNode<T> — The input node.
Returns
- ComputationNode<T>
The output after all layers.
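A sketch of checkpointing a sequence of blocks. The block objects (`downBlock1`, `downBlock2`, `midBlock`) are hypothetical; any object exposing a `Forward(ComputationNode<T>)` method works.

```csharp
// Wrap each block's forward pass as a function; CheckpointSequence
// decides which intermediate activations to keep as checkpoints
// and recomputes the rest during the backward pass.
var layerFuncs = new List<Func<ComputationNode<float>, ComputationNode<float>>>
{
    x => downBlock1.Forward(x),
    x => downBlock2.Forward(x),
    x => midBlock.Forward(x)
};

ComputationNode<float> output = manager.CheckpointSequence(layerFuncs, inputNode);
```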
EstimateMemory(int, long)
Estimates memory savings from current configuration.
public MemoryEstimate EstimateMemory(int numLayers, long activationSizeBytes)
Parameters
numLayers int — Number of layers in the model.
activationSizeBytes long — Size of activations per layer, in bytes.
Returns
- MemoryEstimate
Estimated memory usage information.
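A sketch of estimating savings before training. The layer count and activation size are illustrative numbers, and the members of MemoryEstimate are not shown here because they are not documented on this page.

```csharp
// Estimate memory for a hypothetical 40-layer model with
// ~256 MB of activations per layer.
MemoryEstimate estimate = manager.EstimateMemory(
    numLayers: 40,
    activationSizeBytes: 256L * 1024 * 1024);
```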
ForwardWithCheckpointing(IReadOnlyList<ILayer<T>>, Tensor<T>)
Executes a forward pass through layers with optional checkpointing.
public (Tensor<T> Output, LayerCheckpointState<T> State) ForwardWithCheckpointing(IReadOnlyList<ILayer<T>> layers, Tensor<T> input)
Parameters
layers IReadOnlyList<ILayer<T>> — The layers to execute.
input Tensor<T> — Input tensor.
Returns
- (Tensor<T> Output, LayerCheckpointState<T> State)
A tuple of the output tensor after all layers and the checkpoint state needed by BackwardWithCheckpointing.
Remarks
This is the tensor-based equivalent for layers that don't use the autodiff system. It provides checkpointing by storing only checkpoint activations and recomputing intermediate ones during backward pass.
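A sketch pairing the two tensor-based checkpointing calls. The loss computation between them is elided, and `lossGradient` is a hypothetical tensor holding the gradient of the loss with respect to the output.

```csharp
// Forward pass: only checkpoint activations are stored in `state`.
var (output, state) = manager.ForwardWithCheckpointing(layers, input);

// ... compute the loss and its gradient `lossGradient` w.r.t. `output` ...

// Backward pass: activations between checkpoints are recomputed on demand.
Tensor<float> inputGradient = manager.BackwardWithCheckpointing(lossGradient, state);
```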
GetDeviceMemoryUsage()
Gets memory usage per device.
public IReadOnlyDictionary<int, long>? GetDeviceMemoryUsage()
Returns
- IReadOnlyDictionary<int, long>?
Memory usage in bytes keyed by device index, or null if sharding is not active.
GetPoolStats()
Gets pooling statistics if available.
public ActivationPoolStats? GetPoolStats()
Returns
- ActivationPoolStats?
Pooling statistics, or null if pooling is disabled.
RentTensor(int[])
Rents a tensor from the activation pool.
public Tensor<T> RentTensor(int[] shape)
Parameters
shape int[] — Desired tensor shape.
Returns
- Tensor<T>
A tensor from the pool (or newly allocated if pool unavailable).
ReturnTensor(Tensor<T>)
Returns a tensor to the activation pool for reuse.
public void ReturnTensor(Tensor<T> tensor)
Parameters
tensor Tensor<T> — The tensor to return to the pool.
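A sketch of the rent/return pattern. The try/finally is a suggested convention (so the tensor is returned even if an exception is thrown), not a requirement of the API.

```csharp
// Rent a scratch tensor instead of allocating a fresh one each step;
// this reduces GC pressure during training loops.
Tensor<float> scratch = manager.RentTensor(new[] { 1, 4, 64, 64 });
try
{
    // ... use `scratch` as temporary storage ...
}
finally
{
    // Return it so the pool can reuse the same memory for the next caller.
    manager.ReturnTensor(scratch);
}
```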
ShardedBackward(Tensor<T>)
Performs backward pass through sharded model.
public Tensor<T> ShardedBackward(Tensor<T> outputGradient)
Parameters
outputGradient Tensor<T> — Gradient from the subsequent layer.
Returns
- Tensor<T>
Gradient with respect to input.
ShardedForward(Tensor<T>)
Performs forward pass through sharded model.
public Tensor<T> ShardedForward(Tensor<T> input)
Parameters
input Tensor<T> — Input tensor.
Returns
- Tensor<T>
Output tensor.
ShardedForward(Tensor<T>, Tensor<T>?)
Performs forward pass through sharded model with context.
public Tensor<T> ShardedForward(Tensor<T> input, Tensor<T>? context)
Parameters
input Tensor<T> — Input tensor.
context Tensor<T>? — Optional context tensor (e.g., timestep embedding).
Returns
- Tensor<T>
Output tensor.
ShardedUpdateParameters(T)
Updates parameters across all shards.
public void ShardedUpdateParameters(T learningRate)
Parameters
learningRate T — The learning rate.
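A sketch of one sharded training step combining the three sharded methods. `noisyLatents`, `timestepEmbedding`, `lossGradient`, and `lr` are hypothetical values; the loss computation is elided.

```csharp
// Forward through the sharded model, passing a timestep embedding as context.
Tensor<float> prediction = manager.ShardedForward(noisyLatents, timestepEmbedding);

// ... compute the loss and its gradient `lossGradient` w.r.t. `prediction` ...

// Backward through the shards, then update parameters on every shard.
Tensor<float> inputGrad = manager.ShardedBackward(lossGradient);
manager.ShardedUpdateParameters(lr);
```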