Class FlashAttentionConfig
- Namespace: AiDotNet.NeuralNetworks.Attention
- Assembly: AiDotNet.dll
Configuration options for Flash Attention algorithm.
public class FlashAttentionConfig
- Inheritance: object → FlashAttentionConfig
Remarks
Flash Attention is a memory-efficient attention algorithm that avoids materializing the full N x N attention matrix. Instead, it processes attention in tiles/blocks, computing online softmax incrementally.
For Beginners: Flash Attention is a faster way to compute attention.
Standard attention creates a huge matrix comparing every position to every other position. For long sequences (like 4096 tokens), this matrix has over 16 million entries (4096 × 4096)!
Flash Attention avoids creating this huge matrix by:
- Processing in small blocks that fit in fast GPU memory (SRAM)
- Computing softmax incrementally as it processes each block
- Never storing the full attention matrix
Benefits (a configuration sketch follows this list):
- 2-4x faster than standard attention
- Uses much less memory (O(N) instead of O(N^2))
- Enables training with longer sequences
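As a minimal sketch, a configuration can be built directly with the properties documented below. This assumes FlashAttentionConfig exposes a parameterless constructor; the values are illustrative, and how the config is passed to an attention layer is not shown here.

```csharp
using AiDotNet.NeuralNetworks.Attention;

// Minimal sketch: configure Flash Attention for autoregressive training.
// (Assumes a parameterless constructor; values are illustrative, not tuned.)
var config = new FlashAttentionConfig
{
    BlockSizeQ = 64,             // queries processed per tile (Br)
    BlockSizeKV = 64,            // keys/values processed per tile (Bc)
    UseCausalMask = true,        // GPT-style: hide future positions
    DropoutProbability = 0.1f,   // applied to attention weights during training only
    RecomputeInBackward = true,  // trade extra compute for lower memory in the backward pass
    UseGpuKernel = true          // use the optimized GPU kernel when available
};
```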
Properties
BlockSizeKV
Block size for key/value processing (Bc in the paper).
public int BlockSizeKV { get; set; }
Property Value
int
Remarks
Controls how many key/value positions are processed together. Should typically match BlockSizeQ for square blocks.
BlockSizeQ
Block size for query processing (Br in the paper).
public int BlockSizeQ { get; set; }
Property Value
int
Remarks
Controls how many query positions are processed together. Larger values may be faster but use more memory. For best performance, it should divide the sequence length evenly.
For Beginners: This is how many "questions" we process at once.
The default of 64 works well for most GPUs; a sizing sketch follows this list:
- RTX 3090/4090: Can use 128
- Older GPUs: May need 32
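A sizing sketch following the guidance above; the specific values are illustrative and should be tuned for the target GPU.

```csharp
// Older GPU with limited SRAM: smaller square tiles (illustrative values).
var lowMemory = new FlashAttentionConfig
{
    BlockSizeQ = 32,
    BlockSizeKV = 32   // keep Bc equal to Br for square blocks
};

// High-end GPU (e.g. RTX 4090): larger tiles for more throughput.
var highEnd = new FlashAttentionConfig
{
    BlockSizeQ = 128,
    BlockSizeKV = 128
};
```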
Causal
Creates a configuration optimized for causal/autoregressive models.
public static FlashAttentionConfig Causal { get; }
Property Value
FlashAttentionConfig
Default
Creates a default configuration suitable for most use cases.
public static FlashAttentionConfig Default { get; }
Property Value
FlashAttentionConfig
DropoutProbability
Dropout probability to apply to attention weights during training.
public float DropoutProbability { get; set; }
Property Value
float
Remarks
Randomly zeros out attention weights to prevent overfitting. Only applied during training, not inference.
HighPerformance
Creates a configuration optimized for speed (uses more memory).
public static FlashAttentionConfig HighPerformance { get; }
Property Value
FlashAttentionConfig
MemoryEfficient
Creates a configuration optimized for memory efficiency.
public static FlashAttentionConfig MemoryEfficient { get; }
Property Value
FlashAttentionConfig
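The four static presets (Default, Causal, HighPerformance, MemoryEfficient) can be used as starting points and then adjusted. The sketch below assumes each preset property returns a fresh instance on every access; if it returns a shared instance, copy it before mutating.

```csharp
// Start from a preset, then override individual options as needed.
// (Assumes each preset property returns a new instance on each access.)
var trainingConfig = FlashAttentionConfig.MemoryEfficient;
trainingConfig.DropoutProbability = 0.1f;

var generationConfig = FlashAttentionConfig.Causal;           // autoregressive decoding
var throughputConfig = FlashAttentionConfig.HighPerformance;  // speed over memory
var generalConfig = FlashAttentionConfig.Default;             // sensible defaults
```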
Precision
Numerical precision mode for attention computation.
public FlashAttentionPrecision Precision { get; set; }
Property Value
FlashAttentionPrecision
Remarks
Controls the precision used for intermediate computations. Higher precision is more accurate but slower and uses more memory.
RecomputeInBackward
Whether to enable memory-efficient backward pass with recomputation.
public bool RecomputeInBackward { get; set; }
Property Value
bool
Remarks
When true, the backward pass recomputes attention weights instead of storing them. This significantly reduces memory usage at the cost of some additional computation.
For Beginners: This trades speed for memory during training.
- Standard approach: store the attention weights from the forward pass and reuse them in the backward pass.
- Recomputation: recompute the attention weights on the fly during the backward pass.
Enable this when (see the sketch after this list):
- Training with limited GPU memory
- Using very long sequences
- Training large models
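A sketch of the memory-constrained training case described above; the specific values are illustrative.

```csharp
// Long-sequence training on limited GPU memory (illustrative settings).
var longSequenceConfig = new FlashAttentionConfig
{
    RecomputeInBackward = true,    // do not store attention weights; recompute them in the backward pass
    BlockSizeQ = 64,
    BlockSizeKV = 64,
    ReturnAttentionWeights = false // materializing weights would defeat the memory savings
};
```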
ReturnAttentionWeights
Whether to return attention weights (for visualization/debugging).
public bool ReturnAttentionWeights { get; set; }
Property Value
bool
Remarks
When true, materializes and returns the attention weights. This negates some memory benefits of Flash Attention but is useful for debugging. Should typically be false in production.
ScaleFactor
Scale factor for attention scores. If null, uses 1/sqrt(head_dim).
public float? ScaleFactor { get; set; }
Property Value
float?
Remarks
The standard scale factor of 1/sqrt(d_k) prevents attention scores from becoming too large, which would cause softmax to produce very peaked distributions.
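Leaving ScaleFactor null applies the standard 1/sqrt(head_dim) scaling automatically; setting it explicitly gives the same result for a known head dimension. The head dimension of 64 below is an assumed example.

```csharp
// null (the default) lets the layer use the standard 1/sqrt(head_dim) scaling.
var defaultScaling = new FlashAttentionConfig { ScaleFactor = null };

// Explicit equivalent for an assumed head dimension of 64: 1/sqrt(64) = 0.125.
var explicitScaling = new FlashAttentionConfig { ScaleFactor = 0.125f };
```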
UseCausalMask
Whether to apply causal masking (for autoregressive models).
public bool UseCausalMask { get; set; }
Property Value
bool
Remarks
When true, position i can only attend to positions j where j <= i. This is essential for language models like GPT where future tokens should not influence current predictions.
For Beginners: Causal masking prevents "cheating" in text generation.
When generating text word by word:
- The model shouldn't see future words when predicting the next word
- Causal masking hides future positions
- Set to true for GPT-style models
- Set to false for BERT-style models (bidirectional); a configuration sketch follows this list
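A sketch contrasting the two settings; the "GPT-style" and "BERT-style" labels follow the remarks above.

```csharp
// Decoder-only (GPT-style) generation: position i may only attend to positions j <= i.
var decoderConfig = new FlashAttentionConfig { UseCausalMask = true };

// Encoder (BERT-style) bidirectional attention: every position can attend to every other.
var encoderConfig = new FlashAttentionConfig { UseCausalMask = false };
```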
UseGpuKernel
Whether to use the optimized GPU kernel (when available).
public bool UseGpuKernel { get; set; }
Property Value
bool
Remarks
When true and GPU is available, uses optimized DirectGpu kernels for Flash Attention. Falls back to CPU implementation if GPU is not available.