Class DiffusionAttention<T>
Memory-efficient attention layer for diffusion models using Flash Attention.
public class DiffusionAttention<T> : LayerBase<T>, ILayer<T>, IJitCompilable<T>, IDiagnosticsProvider, IWeightLoadable<T>, IDisposable
Type Parameters
- T: The numeric type used for calculations.
Inheritance
- LayerBase<T> → DiffusionAttention<T>
Implements
- ILayer<T>
- IJitCompilable<T>
- IDiagnosticsProvider
- IWeightLoadable<T>
- IDisposable
Remarks
This attention layer automatically uses Flash Attention when the sequence length exceeds a threshold, providing significant memory and performance benefits for high-resolution image generation.
For Beginners: Attention in diffusion models is computationally expensive.
For a 512x512 image at 8x downsampling:
- Sequence length = (512/8)^2 = 4096 tokens
- Standard attention: 4096 x 4096 ≈ 16.8 million attention weights per head!
This class automatically uses Flash Attention for long sequences (a sketch of this dispatch follows the list):
- Under 256 tokens: Standard attention (faster for short sequences)
- 256+ tokens: Flash Attention (memory-efficient, scales better)
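To make the dispatch concrete, here is a minimal sketch of the threshold check; sequenceLength, StandardAttention, and FlashAttention are illustrative names, not the layer's actual internals:
// Illustrative sketch of the internal dispatch, not the real implementation.
int sequenceLength = height * width;        // tokens after flattening the feature map
var attnOutput = sequenceLength >= flashAttentionThreshold
    ? FlashAttention(query, key, value)     // tiled; never materializes the [L, L] matrix
    : StandardAttention(query, key, value); // materializes the full [L, L] weights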
Usage:
var attention = new DiffusionAttention<float>(
    channels: 320,
    numHeads: 8,
    spatialSize: 64);
var output = attention.Forward(input);
Constructors
DiffusionAttention(int, int, int, int, bool)
Initializes a new diffusion attention layer.
public DiffusionAttention(int channels, int numHeads = 8, int spatialSize = 64, int flashAttentionThreshold = 256, bool useCausalMask = false)
Parameters
- channels (int): Number of input/output channels.
- numHeads (int): Number of attention heads.
- spatialSize (int): Spatial size (height = width) of the input feature maps.
- flashAttentionThreshold (int): Sequence length threshold for using Flash Attention (default: 256).
- useCausalMask (bool): Whether to use causal masking (default: false for images).
Remarks
For Beginners: Configuration tips (an example follows this list):
- channels: Should match your UNet block channels (e.g., 320, 640, 1280)
- numHeads: 8 is typical; channels must be divisible by numHeads
- spatialSize: 64 for 512px images at 8x downsampling
- flashAttentionThreshold: Lower = more Flash Attention usage = less memory
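For example, deriving spatialSize from a target resolution (the 8x downsampling factor is taken from the remarks above):
int imageSize = 512;
int downsample = 8;                       // latent-space downsampling factor
var attn = new DiffusionAttention<float>(
    channels: 320,
    numHeads: 8,                          // 320 / 8 = 40 channels per head
    spatialSize: imageSize / downsample); // 64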
Properties
Channels
Gets the number of channels.
public int Channels { get; }
Property Value
- int
FlashAttentionEnabled
Gets whether Flash Attention is enabled.
public bool FlashAttentionEnabled { get; }
Property Value
- bool
NumHeads
Gets the number of attention heads.
public int NumHeads { get; }
Property Value
- int
SupportsJitCompilation
Gets whether this layer supports JIT compilation.
public override bool SupportsJitCompilation { get; }
Property Value
- bool
True if the layer can be JIT compiled, false otherwise.
Remarks
This property indicates whether the layer has implemented ExportComputationGraph() and can benefit from JIT compilation. All layers MUST implement this property.
For Beginners: JIT compilation can make inference 5-10x faster by converting the layer's operations into optimized native code.
Layers should return false if they:
- Have not yet implemented a working ExportComputationGraph()
- Use dynamic operations that change based on input data
- Are too simple to benefit from JIT compilation
When false, the layer will use the standard Forward() method instead.
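A caller might branch on this flag as follows; RunJitCompiled is a hypothetical helper standing in for whatever JIT backend you use (see ExportComputationGraph below):
// Fall back to eager execution when the layer cannot be JIT compiled.
var result = attention.SupportsJitCompilation
    ? RunJitCompiled(attention, input)  // hypothetical JIT path
    : attention.Forward(input);         // standard eager path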
SupportsTraining
Whether this layer supports training.
public override bool SupportsTraining { get; }
Property Value
- bool
Methods
Backward(Tensor<T>)
Performs the backward pass through the attention layer.
public override Tensor<T> Backward(Tensor<T> outputGradient)
Parameters
- outputGradient (Tensor<T>): Gradient from the next layer.
Returns
- Tensor<T>
Gradient to pass to the previous layer.
ExportComputationGraph(List<ComputationNode<T>>)
Exports the layer's computation graph for JIT compilation.
public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
Parameters
- inputNodes (List<ComputationNode<T>>): List to populate with input computation nodes.
Returns
- ComputationNode<T>
The output computation node representing the layer's operation.
Remarks
This method constructs a computation graph representation of the layer's forward pass that can be JIT compiled for faster inference. All layers MUST implement this method to support JIT compilation.
For Beginners: JIT (Just-In-Time) compilation converts the layer's operations into optimized native code for 5-10x faster inference.
To support JIT compilation, a layer must:
- Implement this method to export its computation graph
- Set SupportsJitCompilation to true
- Use ComputationNode and TensorOperations to build the graph
All layers are required to implement this method, even if they set SupportsJitCompilation = false.
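A minimal sketch of exporting the graph; compiling and executing the returned node depends on your JIT backend, which is not shown here:
var inputNodes = new List<ComputationNode<float>>();
ComputationNode<float> outputNode = attention.ExportComputationGraph(inputNodes);
// inputNodes now holds the graph's input placeholders;
// outputNode is the root of the attention computation, ready for a JIT backend.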
Forward(Tensor<T>)
Performs the forward pass through the attention layer.
public override Tensor<T> Forward(Tensor<T> input)
Parameters
- input (Tensor<T>): Input tensor of shape [batch, channels, height, width].
Returns
- Tensor<T>
Output tensor of the same shape.
Remarks
The input is reshaped from image format [B, C, H, W] to sequence format [B, H*W, C] for attention computation, then reshaped back to image format.
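A shape-level example of that round trip, assuming Tensor<T> can be constructed from a shape array (an assumption for illustration):
// Input: [batch=2, channels=320, height=64, width=64]
var input = new Tensor<float>(new[] { 2, 320, 64, 64 });
var output = attention.Forward(input);
// Internally: [2, 320, 64, 64] -> [2, 4096, 320] -> attention -> [2, 320, 64, 64]
// output has the same shape as input.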
GetDiagnostics()
Gets diagnostic information about the layer.
public override Dictionary<string, string> GetDiagnostics()
Returns
- Dictionary<string, string>
A dictionary of diagnostic keys and values.
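For example, the diagnostics can be dumped to the console:
foreach (var entry in attention.GetDiagnostics())
{
    Console.WriteLine($"{entry.Key}: {entry.Value}");
}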
GetParameters()
Gets all layer parameters as a single vector.
public override Vector<T> GetParameters()
Returns
- Vector<T>
All layer parameters as a single vector.
ResetState()
Resets the layer's internal state.
public override void ResetState()
SetParameters(Vector<T>)
Sets all layer parameters from a single vector.
public override void SetParameters(Vector<T> parameters)
Parameters
- parameters (Vector<T>): The parameter values to set for this layer.
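Together with GetParameters(), this enables a simple snapshot/restore round trip, assuming the vector layout is stable between the two calls:
Vector<float> snapshot = attention.GetParameters(); // save current weights
// ... train or perturb the layer ...
attention.SetParameters(snapshot);                  // restore the saved weights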
UpdateParameters(T)
Updates parameters using computed gradients.
public override void UpdateParameters(T learningRate)
Parameters
- learningRate (T): The learning rate used for the parameter update.
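Putting the training methods together, one manual optimization step might look like this (computation of the loss gradient is elided):
var output = attention.Forward(input);
// outputGradient = dLoss/dOutput, supplied by the surrounding network or loss.
var inputGradient = attention.Backward(outputGradient);
attention.UpdateParameters(0.001f); // apply the accumulated gradients with lr = 1e-3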