
Class DiffusionCrossAttention<T>

Namespace: AiDotNet.Diffusion.Attention
Assembly: AiDotNet.dll

Cross-attention layer for diffusion models with Flash Attention optimization.

public class DiffusionCrossAttention<T> : LayerBase<T>, ILayer<T>, IJitCompilable<T>, IDiagnosticsProvider, IWeightLoadable<T>, IDisposable

Type Parameters

T

The numeric type used for calculations.

Inheritance
LayerBase<T> → DiffusionCrossAttention<T>

Implements
ILayer<T>, IJitCompilable<T>, IDiagnosticsProvider, IWeightLoadable<T>, IDisposable

Remarks

Cross-attention allows the model to attend to conditioning information (like text embeddings) when generating images. This is how text-to-image models like Stable Diffusion work.

For Beginners: Cross-attention is how the model "looks at" text while generating images.

  • Query (Q): Comes from the image features
  • Key (K) and Value (V): Come from text embeddings
  • Output: Image features enriched with text information

This enables the model to generate images that match the text description.
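The mechanics can be sketched as single-head scaled dot-product attention. The illustrative snippet below uses plain arrays rather than the library's Tensor<T> type, and omits the learned Q/K/V projections and multi-head splitting that the real layer performs:

```csharp
using System;
using System.Linq;

// Minimal single-head scaled dot-product cross-attention over plain arrays.
// q: image features [numQueries, dim]; k, v: text features [numTokens, dim].
static double[][] CrossAttention(double[][] q, double[][] k, double[][] v)
{
    int dim = q[0].Length;
    var output = new double[q.Length][];
    for (int i = 0; i < q.Length; i++)
    {
        // Scaled dot-product score of this query against every key.
        var scores = new double[k.Length];
        for (int j = 0; j < k.Length; j++)
        {
            double dot = 0.0;
            for (int d = 0; d < dim; d++) dot += q[i][d] * k[j][d];
            scores[j] = dot / Math.Sqrt(dim);
        }

        // Softmax over the keys (numerically stabilized).
        double max = scores.Max();
        double[] weights = scores.Select(s => Math.Exp(s - max)).ToArray();
        double sum = weights.Sum();

        // Output = attention-weighted sum of the values.
        output[i] = new double[v[0].Length];
        for (int j = 0; j < v.Length; j++)
            for (int d = 0; d < v[0].Length; d++)
                output[i][d] += weights[j] / sum * v[j][d];
    }
    return output;
}
```

Each output row is a blend of the text value vectors, weighted by how strongly the corresponding image query matches each text key.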

Constructors

DiffusionCrossAttention(int, int, int, int)

Initializes a new diffusion cross-attention layer.

public DiffusionCrossAttention(int queryDim, int contextDim, int numHeads = 8, int spatialSize = 64)

Parameters

queryDim int

Dimension of query (spatial channels).

contextDim int

Dimension of context (text embedding).

numHeads int

Number of attention heads.

spatialSize int

Spatial size of input feature maps.
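As an illustration, a Stable-Diffusion-style block at 64×64 resolution might be constructed as follows. The specific dimensions are assumptions for the example, not requirements:

```csharp
// Query dim matches the spatial feature channels at this resolution;
// context dim matches the text encoder's embedding size (e.g. 768 for CLIP).
var crossAttention = new DiffusionCrossAttention<float>(
    queryDim: 320,      // spatial feature channels
    contextDim: 768,    // text embedding dimension
    numHeads: 8,        // default is 8
    spatialSize: 64);   // feature-map height/width, default is 64
```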

Properties

ContextDim

Gets the context dimension.

public int ContextDim { get; }

Property Value

int

QueryDim

Gets the query dimension.

public int QueryDim { get; }

Property Value

int

SupportsJitCompilation

Gets whether this layer supports JIT compilation.

public override bool SupportsJitCompilation { get; }

Property Value

bool

True if the layer can be JIT compiled, false otherwise.

Remarks

This property indicates whether the layer has implemented ExportComputationGraph() and can benefit from JIT compilation. All layers MUST implement this property.

For Beginners: JIT compilation can make inference 5-10x faster by converting the layer's operations into optimized native code.

Layers should return false if they:

  • Have not yet implemented a working ExportComputationGraph()
  • Use dynamic operations that change based on input data
  • Are too simple to benefit from JIT compilation

When false, the layer will use the standard Forward() method instead.

SupportsTraining

Gets a value indicating whether this layer supports training.

public override bool SupportsTraining { get; }

Property Value

bool

true if the layer has trainable parameters and supports backpropagation; otherwise, false.

Remarks

This property indicates whether the layer can be trained through backpropagation. Layers with trainable parameters such as weights and biases typically return true, while layers that only perform fixed transformations (like pooling or activation layers) typically return false.

For Beginners: This property tells you if the layer can learn from data.

A value of true means:

  • The layer has parameters that can be adjusted during training
  • It will improve its performance as it sees more data
  • It participates in the learning process

A value of false means:

  • The layer doesn't have any adjustable parameters
  • It performs the same operation regardless of training
  • It doesn't need to learn (but may still be useful)

Methods

Backward(Tensor<T>)

Performs the backward pass through cross-attention.

public override Tensor<T> Backward(Tensor<T> outputGradient)

Parameters

outputGradient Tensor<T>

The gradient of the loss with respect to this layer's output.

Returns

Tensor<T>

The gradient of the loss with respect to this layer's input.

ExportComputationGraph(List<ComputationNode<T>>)

Exports the layer's computation graph for JIT compilation.

public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)

Parameters

inputNodes List<ComputationNode<T>>

List to populate with input computation nodes.

Returns

ComputationNode<T>

The output computation node representing the layer's operation.

Remarks

This method constructs a computation graph representation of the layer's forward pass that can be JIT compiled for faster inference. All layers MUST implement this method to support JIT compilation.

For Beginners: JIT (Just-In-Time) compilation converts the layer's operations into optimized native code for 5-10x faster inference.

To support JIT compilation, a layer must:

  1. Implement this method to export its computation graph
  2. Set SupportsJitCompilation to true
  3. Use ComputationNode and TensorOperations to build the graph

All layers are required to implement this method, even if they set SupportsJitCompilation = false.
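A hedged sketch of how a caller might use these members together; the compilation step itself is outside this layer's API and is not shown:

```csharp
// Only export the graph when the layer opts in to JIT compilation.
var inputNodes = new List<ComputationNode<float>>();
if (layer.SupportsJitCompilation)
{
    // Populates inputNodes and returns the node producing the layer's output.
    ComputationNode<float> outputNode = layer.ExportComputationGraph(inputNodes);
    // outputNode can now be handed to the JIT compiler for native-code generation.
}
```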

Forward(Tensor<T>)

Performs the forward pass through cross-attention.

public override Tensor<T> Forward(Tensor<T> input)

Parameters

input Tensor<T>

Input tensor (query) of shape [batch, channels, height, width].

Returns

Tensor<T>

Output tensor of the same shape.

ForwardWithContext(Tensor<T>, Tensor<T>?)

Performs the forward pass with context (conditioning).

public Tensor<T> ForwardWithContext(Tensor<T> input, Tensor<T>? context)

Parameters

input Tensor<T>

Input tensor (query) of shape [batch, channels, height, width].

context Tensor<T>?

Context tensor (key/value) of shape [batch, contextLen, contextDim]. May be null, per the nullable parameter in the signature.

Returns

Tensor<T>

Output tensor of the same shape as input.
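A hedged usage sketch for text-conditioned generation. The shapes are illustrative, and GetImageFeatures/GetTextEmbeddings are hypothetical helpers standing in for your own data pipeline, since tensor construction depends on the Tensor<T> API:

```csharp
// Assumed shapes: batch 1, 320 channels, 64x64 feature map, 77 text tokens
// of dimension 768 (CLIP-like). Tensor construction is elided.
Tensor<float> imageFeatures = GetImageFeatures();   // [1, 320, 64, 64]
Tensor<float> textEmbeddings = GetTextEmbeddings(); // [1, 77, 768]

// Enrich the image features with the text information.
Tensor<float> conditioned = layer.ForwardWithContext(imageFeatures, textEmbeddings);
// 'conditioned' has the same [1, 320, 64, 64] shape as 'imageFeatures'.
```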

GetParameters()

Gets all layer parameters.

public override Vector<T> GetParameters()

Returns

Vector<T>

A vector containing all of the layer's trainable parameters.

ResetState()

Resets the layer's internal state.

public override void ResetState()

SetParameters(Vector<T>)

Sets all layer parameters.

public override void SetParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

A vector containing the parameter values to assign to the layer.

UpdateParameters(T)

Updates the layer's trainable parameters using the given learning rate.

public override void UpdateParameters(T learningRate)

Parameters

learningRate T

The learning rate that scales the parameter update.
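Taken together with Forward and Backward, this supports a minimal training step. The sketch below uses only members documented on this page; 'input' and the loss-gradient computation are assumed to come from the surrounding training loop:

```csharp
// One hedged training step. 'lossGradient' (dLoss/dOutput) is assumed
// to be computed by the training loop from 'output' and the target.
Tensor<float> output = layer.Forward(input);
// ... compute lossGradient from 'output' and the target ...
Tensor<float> inputGradient = layer.Backward(lossGradient);
layer.UpdateParameters(0.001f); // apply updates with learning rate 0.001
```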