Table of Contents

Class DiffusionMemoryManager<T>

Namespace
AiDotNet.Diffusion.Memory
Assembly
AiDotNet.dll

Memory management utilities for diffusion models including gradient checkpointing, activation pooling, and model sharding integration.

public class DiffusionMemoryManager<T>

Type Parameters

T

The numeric type used for calculations.

Inheritance
object → DiffusionMemoryManager<T>

Remarks

This class provides memory-efficient training utilities specifically designed for large diffusion models (UNet, VAE, etc.) that may not fit in GPU memory during training.

For Beginners: Training large models uses a lot of memory because we need to store:

  1. The model parameters
  2. Intermediate activations (outputs from each layer during the forward pass)
  3. Gradients for each parameter

This class helps reduce memory usage through several techniques:

Gradient Checkpointing:

  • Instead of storing all activations, only store "checkpoints"
  • During backward pass, recompute activations between checkpoints
  • Trades ~30% more compute time for ~50% less memory

Activation Pooling:

  • Reuse tensor memory instead of allocating new tensors
  • Reduces GC pressure and memory fragmentation

Model Sharding:

  • Split large models across multiple GPUs
  • Each GPU only holds part of the model

Constructors

DiffusionMemoryManager(DiffusionMemoryConfig?, IEnumerable<ILayer<T>>?)

Initializes a new instance of the DiffusionMemoryManager class.

public DiffusionMemoryManager(DiffusionMemoryConfig? config = null, IEnumerable<ILayer<T>>? layers = null)

Parameters

config DiffusionMemoryConfig

Memory configuration options.

layers IEnumerable<ILayer<T>>

Optional layers for model sharding.
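
A minimal construction sketch. The DiffusionMemoryConfig property names used here (EnableGradientCheckpointing, EnableActivationPooling) are illustrative assumptions, not confirmed by this page; see the DiffusionMemoryConfig documentation for the actual options.

```csharp
using AiDotNet.Diffusion.Memory;

// Hypothetical configuration; these property names are assumptions.
var config = new DiffusionMemoryConfig
{
    EnableGradientCheckpointing = true,
    EnableActivationPooling = true
};

// Layers are optional; pass them only when using model sharding.
var memoryManager = new DiffusionMemoryManager<float>(config);

Console.WriteLine(memoryManager.CheckpointingEnabled);
Console.WriteLine(memoryManager.PoolingEnabled);
```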

Properties

CheckpointingEnabled

Whether gradient checkpointing is enabled.

public bool CheckpointingEnabled { get; }

Property Value

bool

Config

Memory configuration.

public DiffusionMemoryConfig Config { get; }

Property Value

DiffusionMemoryConfig

PoolingEnabled

Whether activation pooling is enabled.

public bool PoolingEnabled { get; }

Property Value

bool

ShardingEnabled

Whether model sharding is active.

public bool ShardingEnabled { get; }

Property Value

bool

Methods

BackwardWithCheckpointing(Tensor<T>, LayerCheckpointState<T>)

Performs backward pass with checkpointing, recomputing activations as needed.

public Tensor<T> BackwardWithCheckpointing(Tensor<T> outputGradient, LayerCheckpointState<T> state)

Parameters

outputGradient Tensor<T>

Gradient from subsequent layer.

state LayerCheckpointState<T>

Checkpoint state from forward pass.

Returns

Tensor<T>

Gradient with respect to input.

Checkpoint(Func<ComputationNode<T>>, IEnumerable<ComputationNode<T>>)

Wraps a function with gradient checkpointing for memory-efficient training.

public ComputationNode<T> Checkpoint(Func<ComputationNode<T>> function, IEnumerable<ComputationNode<T>> inputs)

Parameters

function Func<ComputationNode<T>>

The function to execute with checkpointing.

inputs IEnumerable<ComputationNode<T>>

The input computation nodes.

Returns

ComputationNode<T>

The checkpointed output node.

Remarks

For Beginners: Use this to wrap expensive computations (like attention blocks).

// Without checkpointing (stores all activations):
var output = attentionBlock.Forward(input);

// With checkpointing (recomputes during backward):
var output = memoryManager.Checkpoint(
    () => attentionBlock.Forward(inputNode),
    new[] { inputNode }
);

CheckpointSequence(IReadOnlyList<Func<ComputationNode<T>, ComputationNode<T>>>, ComputationNode<T>)

Applies checkpointing to a sequence of layer functions.

public ComputationNode<T> CheckpointSequence(IReadOnlyList<Func<ComputationNode<T>, ComputationNode<T>>> layers, ComputationNode<T> input)

Parameters

layers IReadOnlyList<Func<ComputationNode<T>, ComputationNode<T>>>

The sequence of layer forward functions.

input ComputationNode<T>

The input node.

Returns

ComputationNode<T>

The output after all layers.
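
A sketch of checkpointing a stack of blocks in one call. Here block1 through block3 and inputNode are hypothetical placeholders for your own layers and input node.

```csharp
// Checkpoint a sequence of blocks; each entry wraps one layer's
// forward function so the manager can recompute it during backward.
var layerFuncs = new List<Func<ComputationNode<float>, ComputationNode<float>>>
{
    x => block1.Forward(x),
    x => block2.Forward(x),
    x => block3.Forward(x)
};

ComputationNode<float> output = memoryManager.CheckpointSequence(layerFuncs, inputNode);
```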

EstimateMemory(int, long)

Estimates memory savings from current configuration.

public MemoryEstimate EstimateMemory(int numLayers, long activationSizeBytes)

Parameters

numLayers int

Number of layers in the model.

activationSizeBytes long

Size of activations per layer in bytes.

Returns

MemoryEstimate

Estimated memory usage information.
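
For example, to estimate savings for a hypothetical 32-layer UNet with roughly 64 MB of activations per layer:

```csharp
// The member names of MemoryEstimate are documented on its own page;
// this sketch only shows how to call the estimator.
MemoryEstimate estimate = memoryManager.EstimateMemory(
    numLayers: 32,
    activationSizeBytes: 64L * 1024 * 1024);
```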

ForwardWithCheckpointing(IReadOnlyList<ILayer<T>>, Tensor<T>)

Executes a forward pass through layers with optional checkpointing.

public (Tensor<T> Output, LayerCheckpointState<T> State) ForwardWithCheckpointing(IReadOnlyList<ILayer<T>> layers, Tensor<T> input)

Parameters

layers IReadOnlyList<ILayer<T>>

The layers to execute.

input Tensor<T>

Input tensor.

Returns

(Tensor<T> Output, LayerCheckpointState<T> State)

A tuple containing the output tensor after all layers and the checkpoint state to pass to BackwardWithCheckpointing.

Remarks

This is the tensor-based equivalent for layers that don't use the autodiff system. It provides checkpointing by storing only checkpoint activations and recomputing intermediate ones during the backward pass.
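
A sketch of a full checkpointed step pairing this method with BackwardWithCheckpointing. Here layers, input, and lossGradient are placeholders for your own model layers, input batch, and loss gradient.

```csharp
// Forward pass stores only checkpoint activations in `state`.
var (output, state) = memoryManager.ForwardWithCheckpointing(layers, input);

// lossGradient would come from your loss function; it is a
// placeholder here. Backward recomputes the skipped activations.
Tensor<float> inputGradient =
    memoryManager.BackwardWithCheckpointing(lossGradient, state);
```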

GetDeviceMemoryUsage()

Gets memory usage per device.

public IReadOnlyDictionary<int, long>? GetDeviceMemoryUsage()

Returns

IReadOnlyDictionary<int, long>

Memory usage in bytes keyed by device index, or null when sharding is not active.

GetPoolStats()

Gets pooling statistics if available.

public ActivationPoolStats? GetPoolStats()

Returns

ActivationPoolStats

Pooling statistics, or null when activation pooling is not enabled.
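
A diagnostics sketch combining this method with GetDeviceMemoryUsage. Both return null when the corresponding feature is inactive, so the null checks below are required; the members of ActivationPoolStats are documented on its own page, so this example only checks availability.

```csharp
ActivationPoolStats? stats = memoryManager.GetPoolStats();
if (stats != null)
{
    // Inspect pooling behavior here (hit rate, pooled bytes, etc.).
}

IReadOnlyDictionary<int, long>? usage = memoryManager.GetDeviceMemoryUsage();
if (usage != null)
{
    foreach (var (deviceId, bytes) in usage)
        Console.WriteLine($"Device {deviceId}: {bytes / (1024 * 1024)} MB");
}
```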

RentTensor(int[])

Rents a tensor from the activation pool.

public Tensor<T> RentTensor(int[] shape)

Parameters

shape int[]

Desired tensor shape.

Returns

Tensor<T>

A tensor from the pool (or newly allocated if pool unavailable).

ReturnTensor(Tensor<T>)

Returns a tensor to the activation pool for reuse.

public void ReturnTensor(Tensor<T> tensor)

Parameters

tensor Tensor<T>

The tensor to return.
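
A rent/return sketch. Wrapping the work in try/finally guarantees the tensor goes back to the pool even if the intervening code throws; the shape is an arbitrary example.

```csharp
// Rent a scratch buffer from the pool (or a fresh allocation
// if the pool is unavailable).
Tensor<float> scratch = memoryManager.RentTensor(new[] { 4, 64, 32, 32 });
try
{
    // ... use scratch as a temporary activation buffer ...
}
finally
{
    // Return it so later Rent calls can reuse the memory.
    memoryManager.ReturnTensor(scratch);
}
```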

ShardedBackward(Tensor<T>)

Performs backward pass through sharded model.

public Tensor<T> ShardedBackward(Tensor<T> outputGradient)

Parameters

outputGradient Tensor<T>

Gradient from subsequent layer.

Returns

Tensor<T>

Gradient with respect to input.

ShardedForward(Tensor<T>)

Performs forward pass through sharded model.

public Tensor<T> ShardedForward(Tensor<T> input)

Parameters

input Tensor<T>

Input tensor.

Returns

Tensor<T>

Output tensor.

ShardedForward(Tensor<T>, Tensor<T>?)

Performs forward pass through sharded model with context.

public Tensor<T> ShardedForward(Tensor<T> input, Tensor<T>? context)

Parameters

input Tensor<T>

Input tensor.

context Tensor<T>

Context tensor (e.g., timestep embedding).

Returns

Tensor<T>

Output tensor.

ShardedUpdateParameters(T)

Updates parameters across all shards.

public void ShardedUpdateParameters(T learningRate)

Parameters

learningRate T

Learning rate.
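
A sketch of one sharded training step combining ShardedForward, ShardedBackward, and ShardedUpdateParameters. Here noisyInput, timestepEmbedding, and lossGradient are placeholders for your own batch, conditioning, and loss gradient.

```csharp
// Forward through the shards, passing a context tensor
// (e.g. a timestep embedding) as conditioning.
Tensor<float> prediction = memoryManager.ShardedForward(noisyInput, timestepEmbedding);

// Propagate the loss gradient back through the shards in reverse.
Tensor<float> inputGradient = memoryManager.ShardedBackward(lossGradient);

// Apply the same learning rate on every shard.
memoryManager.ShardedUpdateParameters(1e-4f);
```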