Class DiffusionAttention<T>

Namespace: AiDotNet.Diffusion.Attention
Assembly: AiDotNet.dll

Memory-efficient attention layer for diffusion models using Flash Attention.

public class DiffusionAttention<T> : LayerBase<T>, ILayer<T>, IJitCompilable<T>, IDiagnosticsProvider, IWeightLoadable<T>, IDisposable

Type Parameters

T

The numeric type used for calculations.

Inheritance
object → LayerBase<T> → DiffusionAttention<T>

Implements
ILayer<T>, IJitCompilable<T>, IDiagnosticsProvider, IWeightLoadable<T>, IDisposable

Remarks

This attention layer automatically uses Flash Attention when the sequence length exceeds a threshold, providing significant memory and performance benefits for high-resolution image generation.

For Beginners: Attention in diffusion models is computationally expensive.

For a 512x512 image at 8x downsampling:

  • Sequence length = (512/8)^2 = 4096 tokens
  • Standard attention: 4096 x 4096 ≈ 16.8 million attention weights!

This class automatically uses Flash Attention for long sequences:

  • Under 256 tokens: Standard attention (faster for short sequences)
  • 256+ tokens: Flash Attention (memory-efficient, scales better)

Usage:

var attention = new DiffusionAttention<float>(
    channels: 320,
    numHeads: 8,
    spatialSize: 64);

var output = attention.Forward(input);
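
The switch between the two paths can be pictured as a simple threshold check on the flattened sequence length. The sketch below is illustrative only; the actual decision happens inside Forward():

// Illustrative decision rule, not the actual implementation:
int spatialSize = 64;
int flashAttentionThreshold = 256;

int sequenceLength = spatialSize * spatialSize;            // 64 * 64 = 4096 tokens
bool useFlash = sequenceLength >= flashAttentionThreshold; // true: Flash Attention path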

Constructors

DiffusionAttention(int, int, int, int, bool)

Initializes a new diffusion attention layer.

public DiffusionAttention(int channels, int numHeads = 8, int spatialSize = 64, int flashAttentionThreshold = 256, bool useCausalMask = false)

Parameters

channels int

Number of input/output channels.

numHeads int

Number of attention heads.

spatialSize int

Spatial size (height = width) of input feature maps.

flashAttentionThreshold int

Sequence length threshold for using Flash Attention (default: 256).

useCausalMask bool

Whether to use causal masking (default: false for images).

Remarks

For Beginners: Configuration tips (a configuration example follows this list):

  • channels: Should match your UNet block channels (e.g., 320, 640, 1280)
  • numHeads: 8 is typical; channels must be divisible by numHeads
  • spatialSize: 64 for 512px images at 8x downsampling
  • flashAttentionThreshold: lower values trigger Flash Attention more often, reducing peak memory
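
For example, a deeper UNet block for a 512px image might be configured as below; the specific values are illustrative, not prescriptive:

var blockAttention = new DiffusionAttention<float>(
    channels: 640,                // must be divisible by numHeads (640 / 8 = 80 dims per head)
    numHeads: 8,
    spatialSize: 32,              // 512px image at 16x downsampling
    flashAttentionThreshold: 256, // 32 * 32 = 1024 tokens >= 256, so Flash Attention is used
    useCausalMask: false);        // images have no causal ordering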

Properties

Channels

Gets the number of channels.

public int Channels { get; }

Property Value

int

FlashAttentionEnabled

Gets whether Flash Attention is enabled.

public bool FlashAttentionEnabled { get; }

Property Value

bool

NumHeads

Gets the number of attention heads.

public int NumHeads { get; }

Property Value

int

SupportsJitCompilation

Gets whether this layer supports JIT compilation.

public override bool SupportsJitCompilation { get; }

Property Value

bool

True if the layer can be JIT compiled, false otherwise.

Remarks

This property indicates whether the layer has implemented ExportComputationGraph() and can benefit from JIT compilation. All layers MUST implement this property.

For Beginners: JIT compilation can make inference 5-10x faster by converting the layer's operations into optimized native code.

Layers should return false if they:

  • Have not yet implemented a working ExportComputationGraph()
  • Use dynamic operations that change based on input data
  • Are too simple to benefit from JIT compilation

When false, the layer will use the standard Forward() method instead.
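
In practice this property acts as a guard before attempting compilation. A minimal sketch; the fallback path uses only members documented on this page:

if (attention.SupportsJitCompilation)
{
    // Export the computation graph and hand it to the JIT pipeline
    // (see ExportComputationGraph below).
}
else
{
    // Standard eager execution path.
    var output = attention.Forward(input);
}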

SupportsTraining

Gets whether this layer supports training.

public override bool SupportsTraining { get; }

Property Value

bool

Methods

Backward(Tensor<T>)

Performs the backward pass through the attention layer.

public override Tensor<T> Backward(Tensor<T> outputGradient)

Parameters

outputGradient Tensor<T>

Gradient from the next layer.

Returns

Tensor<T>

Gradient to pass to the previous layer.
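
Backward slots between Forward and UpdateParameters in an ordinary training step. A minimal sketch, where ComputeLossGradient is a hypothetical application-specific helper:

var output = attention.Forward(input);                  // forward pass (caches activations)
var lossGradient = ComputeLossGradient(output, target); // hypothetical: dLoss/dOutput
var inputGradient = attention.Backward(lossGradient);   // gradients flowing to the previous layer
attention.UpdateParameters(0.0001f);                    // apply the gradients computed by Backward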

ExportComputationGraph(List<ComputationNode<T>>)

Exports the layer's computation graph for JIT compilation.

public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)

Parameters

inputNodes List<ComputationNode<T>>

List to populate with input computation nodes.

Returns

ComputationNode<T>

The output computation node representing the layer's operation.

Remarks

This method constructs a computation graph representation of the layer's forward pass that can be JIT compiled for faster inference. All layers MUST implement this method to support JIT compilation.

For Beginners: JIT (Just-In-Time) compilation converts the layer's operations into optimized native code for 5-10x faster inference.

To support JIT compilation, a layer must:

  1. Implement this method to export its computation graph
  2. Set SupportsJitCompilation to true
  3. Use ComputationNode and TensorOperations to build the graph

All layers are required to implement this method, even if they set SupportsJitCompilation = false.
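
A minimal sketch of the export call; what consumes the returned node (the JIT compiler itself) is beyond the scope of this page:

var inputNodes = new List<ComputationNode<float>>();
ComputationNode<float> outputNode = attention.ExportComputationGraph(inputNodes);
// inputNodes is now populated with the graph's inputs;
// pass both the inputs and outputNode to the JIT compiler.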

Forward(Tensor<T>)

Performs the forward pass through the attention layer.

public override Tensor<T> Forward(Tensor<T> input)

Parameters

input Tensor<T>

Input tensor of shape [batch, channels, height, width].

Returns

Tensor<T>

Output tensor of the same shape.

Remarks

The input is reshaped from image format [B, C, H, W] to sequence format [B, H*W, C] for attention computation, then reshaped back to image format.
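
In shape terms (the tensor construction below assumes a shape-array constructor and is for illustration only; the actual Tensor<T> API may differ):

var input = new Tensor<float>(new[] { 1, 320, 64, 64 }); // [batch, channels, height, width]
var output = attention.Forward(input);                   // same shape: [1, 320, 64, 64]
// Internally: [1, 320, 64, 64] -> [1, 4096, 320] -> attention -> [1, 320, 64, 64]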

GetDiagnostics()

Gets diagnostic information about the layer.

public override Dictionary<string, string> GetDiagnostics()

Returns

Dictionary<string, string>

A dictionary of diagnostic names and values for this layer.

GetParameters()

Gets all layer parameters as a single vector.

public override Vector<T> GetParameters()

Returns

Vector<T>

All layer parameters flattened into a single vector.

ResetState()

Resets the layer's internal state.

public override void ResetState()

SetParameters(Vector<T>)

Sets all layer parameters from a single vector.

public override void SetParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

Vector containing the values for all layer parameters.

UpdateParameters(T)

Updates parameters using computed gradients.

public override void UpdateParameters(T learningRate)

Parameters

learningRate T

The learning rate used to scale the gradient update.
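
Together, GetParameters, SetParameters, and UpdateParameters support a simple checkpoint-and-restore pattern. A minimal sketch:

var snapshot = attention.GetParameters(); // flatten all weights into one vector
attention.UpdateParameters(0.0001f);      // ...train for a while...
attention.SetParameters(snapshot);        // roll the layer back to the snapshot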