Class FlashAttentionLayer<T>

Namespace: AiDotNet.NeuralNetworks.Attention
Assembly: AiDotNet.dll

A multi-head attention layer using the Flash Attention algorithm for memory-efficient computation.

public class FlashAttentionLayer<T> : LayerBase<T>, ILayer<T>, IJitCompilable<T>, IDiagnosticsProvider, IWeightLoadable<T>, IDisposable

Type Parameters

T

The numeric type for computations (typically float or double).

Inheritance
object → LayerBase<T> → FlashAttentionLayer<T>

Implements
ILayer<T>, IJitCompilable<T>, IDiagnosticsProvider, IWeightLoadable<T>, IDisposable

Remarks

FlashAttentionLayer provides the same functionality as MultiHeadAttentionLayer but uses the Flash Attention algorithm, which is 2-4x faster and uses significantly less memory. It can be used as a drop-in replacement in transformer architectures.

For Beginners: This is like MultiHeadAttentionLayer but faster and more memory-efficient.

Flash Attention is a breakthrough algorithm that makes transformers much faster:

  • Standard attention: O(N^2) memory, slow for long sequences
  • Flash Attention: O(N) memory, 2-4x faster

Use this layer when:

  • Training with long sequences (1024+ tokens)
  • Training large models with limited GPU memory
  • You need faster training/inference

The output is mathematically identical to standard attention; only the computation differs.
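
As a drop-in replacement, usage mirrors MultiHeadAttentionLayer. Below is a minimal sketch; the Tensor<float> constructor taking a shape array is an assumption about the AiDotNet tensor API, so adapt it to your version.

using AiDotNet.NeuralNetworks.Attention;

// Construct the layer; see the constructor documentation below.
var layer = new FlashAttentionLayer<float>(
    sequenceLength: 1024,
    embeddingDimension: 768,
    headCount: 12);

// Input shape: [batch, sequenceLength, embeddingDimension].
// The shape-array constructor here is an assumed API.
var input = new Tensor<float>(new[] { 8, 1024, 768 });

// Output has the same shape as the input.
Tensor<float> output = layer.Forward(input);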

Constructors

FlashAttentionLayer(int, int, int, FlashAttentionConfig?, IActivationFunction<T>?)

Creates a new Flash Attention layer with the specified dimensions.

public FlashAttentionLayer(int sequenceLength, int embeddingDimension, int headCount, FlashAttentionConfig? config = null, IActivationFunction<T>? activationFunction = null)

Parameters

sequenceLength int

The length of the input sequence.

embeddingDimension int

The dimension of each embedding vector.

headCount int

The number of attention heads.

config FlashAttentionConfig

Optional Flash Attention configuration.

activationFunction IActivationFunction<T>

Optional activation function (defaults to identity).

Remarks

For Beginners: Creates a Flash Attention layer.

Parameters:

  • sequenceLength: How many tokens/words in your sequence (e.g., 512, 1024, 4096)
  • embeddingDimension: Size of each token's representation (e.g., 768 for BERT, 4096 for GPT-3)
  • headCount: Number of attention heads (e.g., 12 for BERT-base, 96 for GPT-3)

The embeddingDimension must be divisible by headCount. Each head will have dimension = embeddingDimension / headCount.
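
To make the divisibility rule concrete, here is a hedged sketch (whether the constructor throws on an indivisible configuration is an assumption about the implementation):

int embeddingDimension = 768;
int headCount = 12;
int headDimension = embeddingDimension / headCount;   // 768 / 12 = 64, valid

var layer = new FlashAttentionLayer<float>(1024, embeddingDimension, headCount);

// headCount = 10 would give 768 / 10 = 76.8, which is not an integer,
// so the constructor is expected to reject that configuration.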

FlashAttentionLayer(int, int, int, FlashAttentionConfig?, IVectorActivationFunction<T>?)

Creates a new Flash Attention layer with vector activation function.

public FlashAttentionLayer(int sequenceLength, int embeddingDimension, int headCount, FlashAttentionConfig? config, IVectorActivationFunction<T>? vectorActivationFunction)

Parameters

sequenceLength int

The length of the input sequence.

embeddingDimension int

The dimension of each embedding vector.

headCount int

The number of attention heads.

config FlashAttentionConfig

Optional Flash Attention configuration.

vectorActivationFunction IVectorActivationFunction<T>

Optional vector activation function.

Properties

Config

Gets the Flash Attention configuration.

public FlashAttentionConfig Config { get; }

Property Value

FlashAttentionConfig

HeadCount

Gets the number of attention heads.

public int HeadCount { get; }

Property Value

int

HeadDimension

Gets the dimension of each attention head.

public int HeadDimension { get; }

Property Value

int

SupportsJitCompilation

Gets whether this layer supports JIT compilation.

public override bool SupportsJitCompilation { get; }

Property Value

bool

SupportsTraining

Gets whether this layer supports training.

public override bool SupportsTraining { get; }

Property Value

bool

Methods

Backward(Tensor<T>)

Performs the backward pass, computing the gradient of the loss with respect to the layer's input.

public override Tensor<T> Backward(Tensor<T> outputGradient)

Parameters

outputGradient Tensor<T>

Gradient from the next layer.

Returns

Tensor<T>

Gradient to pass to the previous layer.

ExportComputationGraph(List<ComputationNode<T>>)

Exports the computation graph for JIT compilation.

public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)

Parameters

inputNodes List<ComputationNode<T>>

The input nodes of the computation graph.

Returns

ComputationNode<T>

The root node of the exported graph.

Forward(Tensor<T>)

Performs the forward pass using Flash Attention.

public override Tensor<T> Forward(Tensor<T> input)

Parameters

input Tensor<T>

Input tensor of shape [batch, sequenceLength, embeddingDimension].

Returns

Tensor<T>

Output tensor of the same shape as input.

Remarks

For Beginners: This is where the Flash Attention computation happens.

The forward pass:

  1. Projects input to Query, Key, Value using learned weights
  2. Reshapes into multiple heads
  3. Applies Flash Attention (the fast, memory-efficient algorithm)
  4. Concatenates heads and projects output

Flash Attention computes the same result as standard attention but (see the sketch after this list):

  • Never materializes the full N×N attention matrix
  • Processes in tiles that fit in fast cache memory
  • Uses online softmax for numerical stability
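
To illustrate the tiling and online softmax mentioned above, here is a standalone conceptual sketch for a single query row. This is not the layer's internal code; it only shows how a softmax-weighted sum over values can be computed tile by tile, maintaining a running maximum m and normalizer l so the full score row is never held in memory at once.

using System;
using System.Linq;

static float[] FlashAttentionRow(float[][] scoreTiles, float[][][] valueTiles, int headDim)
{
    float m = float.NegativeInfinity;   // running maximum of scores seen so far
    float l = 0f;                       // running softmax normalizer
    var acc = new float[headDim];       // running weighted sum of value vectors

    for (int t = 0; t < scoreTiles.Length; t++)
    {
        float[] s = scoreTiles[t];      // scores q·k for this tile of keys
        float[][] v = valueTiles[t];    // value vectors for the same tile

        float mNew = Math.Max(m, s.Max());
        float correction = (float)Math.Exp(m - mNew);   // rescale old accumulators

        l *= correction;
        for (int d = 0; d < headDim; d++) acc[d] *= correction;

        for (int j = 0; j < s.Length; j++)
        {
            float p = (float)Math.Exp(s[j] - mNew);
            l += p;
            for (int d = 0; d < headDim; d++) acc[d] += p * v[j][d];
        }
        m = mNew;
    }

    for (int d = 0; d < headDim; d++) acc[d] /= l;
    return acc;   // equals softmax(scores) · V, computed without the full row
}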

GetDiagnostics()

Gets diagnostic information about the layer.

public override Dictionary<string, string> GetDiagnostics()

Returns

Dictionary<string, string>
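
A quick way to surface the diagnostics (the exact keys in the dictionary are implementation-defined):

foreach (var entry in layer.GetDiagnostics())
    Console.WriteLine($"{entry.Key}: {entry.Value}");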

GetKeyWeights()

Gets the key projection weights.

public Matrix<T> GetKeyWeights()

Returns

Matrix<T>

GetOutputWeights()

Gets the output projection weights.

public Matrix<T> GetOutputWeights()

Returns

Matrix<T>

GetParameters()

Gets all layer parameters as a single vector.

public override Vector<T> GetParameters()

Returns

Vector<T>

GetQueryWeights()

Gets the query projection weights (for external access/debugging).

public Matrix<T> GetQueryWeights()

Returns

Matrix<T>

GetValueWeights()

Gets the value projection weights.

public Matrix<T> GetValueWeights()

Returns

Matrix<T>
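
Together, the four weight getters expose the projection matrices, e.g. for inspection or export. A hedged sketch (the matrices' exact layout is implementation-defined):

Matrix<float> wq = layer.GetQueryWeights();
Matrix<float> wk = layer.GetKeyWeights();
Matrix<float> wv = layer.GetValueWeights();
Matrix<float> wo = layer.GetOutputWeights();
// A typical layout is [embeddingDimension, embeddingDimension] per projection,
// but verify against your version of the library.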

ResetState()

Resets the layer's internal state.

public override void ResetState()

SetParameters(Vector<T>)

Sets all layer parameters from a single vector.

public override void SetParameters(Vector<T> parameters)

Parameters

parameters Vector<T>
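
GetParameters and SetParameters can be paired to checkpoint a layer. A minimal sketch (it is assumed the restored vector must have exactly the length returned by GetParameters):

// Save a snapshot of all weights as one flat vector.
Vector<float> snapshot = layer.GetParameters();

// ... train, experiment, or mutate the layer ...

// Restore the saved weights.
layer.SetParameters(snapshot);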

UpdateParameters(T)

Updates parameters using computed gradients.

public override void UpdateParameters(T learningRate)

Parameters

learningRate T
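
Putting the training methods together, one hedged end-to-end step might look like this; ComputeLossGradient is a hypothetical helper standing in for your loss function's backward pass:

// Forward pass over a batch.
Tensor<float> output = layer.Forward(input);

// Gradient of the loss w.r.t. the layer output (hypothetical helper).
Tensor<float> outputGradient = ComputeLossGradient(output, target);

// Backward pass returns the gradient to pass to the previous layer.
Tensor<float> inputGradient = layer.Backward(outputGradient);

// Apply the computed gradients with a simple learning-rate update.
layer.UpdateParameters(0.001f);

// Optionally clear cached state before the next independent sequence.
layer.ResetState();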