Class FlashAttentionLayer<T>
Namespace: AiDotNet.NeuralNetworks.Attention
Assembly: AiDotNet.dll
A multi-head attention layer using the Flash Attention algorithm for memory-efficient computation.
public class FlashAttentionLayer<T> : LayerBase<T>, ILayer<T>, IJitCompilable<T>, IDiagnosticsProvider, IWeightLoadable<T>, IDisposable
Type Parameters
T: The numeric type for computations (typically float or double).
Inheritance
LayerBase<T> → FlashAttentionLayer<T>
Implements
ILayer<T>, IJitCompilable<T>, IDiagnosticsProvider, IWeightLoadable<T>, IDisposable
Remarks
FlashAttentionLayer provides the same functionality as MultiHeadAttentionLayer but uses the Flash Attention algorithm, which is 2-4x faster and uses significantly less memory. It can be used as a drop-in replacement in transformer architectures.
For Beginners: This is like MultiHeadAttentionLayer but faster and more memory-efficient.
Flash Attention is a breakthrough algorithm that makes transformers much faster:
- Standard attention: O(N^2) memory, slow for long sequences
- Flash Attention: O(N) memory, 2-4x faster
Use this layer when:
- Training with long sequences (1024+ tokens)
- Training large models with limited GPU memory
- You need faster training/inference
The output is mathematically identical to standard attention - only the computation is different.
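Examples
For Beginners: A minimal sketch of using this layer as a drop-in replacement. The Tensor<float> shape-array constructor below is an assumption for illustration; the actual tensor API may differ.
using AiDotNet.NeuralNetworks.Attention;

// Construct the layer: sequence length, embedding dimension, head count.
var attention = new FlashAttentionLayer<float>(
    sequenceLength: 1024,
    embeddingDimension: 768,
    headCount: 12);

// Hypothetical input of shape [batch, sequenceLength, embeddingDimension].
var input = new Tensor<float>(new[] { 8, 1024, 768 });
Tensor<float> output = attention.Forward(input); // same shape as the input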
Constructors
FlashAttentionLayer(int, int, int, FlashAttentionConfig?, IActivationFunction<T>?)
Creates a new Flash Attention layer with the specified dimensions.
public FlashAttentionLayer(int sequenceLength, int embeddingDimension, int headCount, FlashAttentionConfig? config = null, IActivationFunction<T>? activationFunction = null)
Parameters
sequenceLength (int): The length of the input sequence.
embeddingDimension (int): The dimension of each embedding vector.
headCount (int): The number of attention heads.
config (FlashAttentionConfig): Optional Flash Attention configuration.
activationFunction (IActivationFunction<T>): Optional activation function (defaults to identity).
Remarks
For Beginners: Creates a Flash Attention layer.
Parameters:
- sequenceLength: How many tokens/words in your sequence (e.g., 512, 1024, 4096)
- embeddingDimension: Size of each token's representation (e.g., 768 for BERT, 4096 for GPT-3)
- headCount: Number of attention heads (e.g., 12 for BERT-base, 96 for GPT-3)
The embeddingDimension must be divisible by headCount. Each head will have dimension = embeddingDimension / headCount.
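Examples
For Beginners: A sketch of the divisibility rule; nothing beyond the documented constructor is assumed.
// 768 / 12 = 64, so each head works with 64-dimensional vectors.
var layer = new FlashAttentionLayer<double>(
    sequenceLength: 512,
    embeddingDimension: 768,
    headCount: 12);

// 768 is not divisible by 10, so a call like this violates the rule above:
// var invalid = new FlashAttentionLayer<double>(512, 768, 10);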
FlashAttentionLayer(int, int, int, FlashAttentionConfig?, IVectorActivationFunction<T>?)
Creates a new Flash Attention layer with vector activation function.
public FlashAttentionLayer(int sequenceLength, int embeddingDimension, int headCount, FlashAttentionConfig? config, IVectorActivationFunction<T>? vectorActivationFunction)
Parameters
sequenceLength (int): The length of the input sequence.
embeddingDimension (int): The dimension of each embedding vector.
headCount (int): The number of attention heads.
config (FlashAttentionConfig): Optional Flash Attention configuration.
vectorActivationFunction (IVectorActivationFunction<T>): Optional vector activation function.
Properties
Config
Gets the Flash Attention configuration.
public FlashAttentionConfig Config { get; }
Property Value
FlashAttentionConfig
HeadCount
Gets the number of attention heads.
public int HeadCount { get; }
Property Value
int
HeadDimension
Gets the dimension of each attention head.
public int HeadDimension { get; }
Property Value
int
SupportsJitCompilation
Gets whether this layer supports JIT compilation.
public override bool SupportsJitCompilation { get; }
Property Value
bool
SupportsTraining
Gets whether this layer supports training.
public override bool SupportsTraining { get; }
Property Value
bool
Methods
Backward(Tensor<T>)
Performs the backward pass, computing the gradient with respect to the layer's input.
public override Tensor<T> Backward(Tensor<T> outputGradient)
Parameters
outputGradient (Tensor<T>): Gradient from the next layer.
Returns
- Tensor<T>
Gradient to pass to the previous layer.
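Examples
For Beginners: A sketch of one backward step. It assumes Forward has been called first and that Tensor<float> exposes the shape-array constructor used below; both are illustrative assumptions.
var layer = new FlashAttentionLayer<float>(128, 256, 8);
var input = new Tensor<float>(new[] { 4, 128, 256 });

var output = layer.Forward(input); // run Forward before Backward

// Gradient flowing back from the next layer, same shape as the output.
var outputGradient = new Tensor<float>(new[] { 4, 128, 256 });
var inputGradient = layer.Backward(outputGradient); // passed to the previous layer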
ExportComputationGraph(List<ComputationNode<T>>)
Exports the computation graph for JIT compilation.
public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
Parameters
inputNodes (List<ComputationNode<T>>): The input nodes for the exported graph.
Returns
- ComputationNode<T>
Forward(Tensor<T>)
Performs the forward pass using Flash Attention.
public override Tensor<T> Forward(Tensor<T> input)
Parameters
input (Tensor<T>): Input tensor of shape [batch, sequenceLength, embeddingDimension].
Returns
- Tensor<T>
Output tensor of the same shape as input.
Remarks
For Beginners: This is where the Flash Attention computation happens.
The forward pass:
- Projects input to Query, Key, Value using learned weights
- Reshapes into multiple heads
- Applies Flash Attention (the fast, memory-efficient algorithm)
- Concatenates heads and projects output
Flash Attention computes the same result as standard attention but:
- Never materializes the full N x N attention matrix
- Processes in tiles that fit in fast cache memory
- Uses online softmax for numerical stability
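Examples
For Beginners: A sketch of the forward pass on a long sequence, where Flash Attention's O(N) memory matters most. The Tensor<float> constructor is an assumed illustration.
var layer = new FlashAttentionLayer<float>(4096, 1024, 16);

// HeadDimension = embeddingDimension / headCount = 1024 / 16 = 64.
Console.WriteLine(layer.HeadDimension);

// For 4096 tokens, standard attention would materialize a 4096 x 4096 score
// matrix per head; Flash Attention computes the same result in small tiles.
var input = new Tensor<float>(new[] { 1, 4096, 1024 });
var output = layer.Forward(input); // shape [1, 4096, 1024], same as the input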
GetDiagnostics()
Gets diagnostic information about the layer.
public override Dictionary<string, string> GetDiagnostics()
Returns
- Dictionary<string, string>
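Examples
For Beginners: A sketch of printing the diagnostics map; the keys it contains are not specified here, so none are assumed.
var layer = new FlashAttentionLayer<float>(128, 256, 8);

foreach (var entry in layer.GetDiagnostics())
    Console.WriteLine($"{entry.Key}: {entry.Value}");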
GetKeyWeights()
Gets the key projection weights.
public Matrix<T> GetKeyWeights()
Returns
- Matrix<T>
GetOutputWeights()
Gets the output projection weights.
public Matrix<T> GetOutputWeights()
Returns
- Matrix<T>
GetParameters()
Gets all layer parameters as a single vector.
public override Vector<T> GetParameters()
Returns
- Vector<T>
GetQueryWeights()
Gets the query projection weights (for external access/debugging).
public Matrix<T> GetQueryWeights()
Returns
- Matrix<T>
GetValueWeights()
Gets the value projection weights.
public Matrix<T> GetValueWeights()
Returns
- Matrix<T>
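Examples
For Beginners: A sketch of reading the four projection matrices for inspection or debugging; their exact shapes are not documented here, so none are assumed.
var layer = new FlashAttentionLayer<float>(256, 512, 8);

Matrix<float> queryWeights = layer.GetQueryWeights();
Matrix<float> keyWeights = layer.GetKeyWeights();
Matrix<float> valueWeights = layer.GetValueWeights();
Matrix<float> outputWeights = layer.GetOutputWeights();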
ResetState()
Resets the layer's internal state.
public override void ResetState()
SetParameters(Vector<T>)
Sets all layer parameters from a single vector.
public override void SetParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): Vector containing all layer parameters.
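Examples
For Beginners: A sketch of round-tripping all parameters through a single vector, e.g. for checkpointing or weight averaging; only the documented methods are used.
var layer = new FlashAttentionLayer<float>(128, 256, 8);

Vector<float> parameters = layer.GetParameters();
// ... save the vector, average it with another layer's, etc. ...
layer.SetParameters(parameters); // restores the same weights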
UpdateParameters(T)
Updates parameters using computed gradients.
public override void UpdateParameters(T learningRate)
Parameters
learningRate (T): The learning rate controlling the size of the update step.
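Examples
For Beginners: A sketch of one training step. It assumes, as is typical for LayerBase<T> layers, that Backward stores parameter gradients internally for UpdateParameters to apply; the Tensor<float> constructor is also an assumed illustration.
var layer = new FlashAttentionLayer<float>(128, 256, 8);
var input = new Tensor<float>(new[] { 4, 128, 256 });
var outputGradient = new Tensor<float>(new[] { 4, 128, 256 });

var output = layer.Forward(input);
var inputGradient = layer.Backward(outputGradient); // computes parameter gradients
layer.UpdateParameters(0.001f);                     // applies them at learning rate 0.001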