Class MultiHeadAttentionLayer<T>
- Namespace
- AiDotNet.NeuralNetworks.Layers
- Assembly
- AiDotNet.dll
Implements a multi-head attention layer for neural networks, a key component in transformer architectures.
public class MultiHeadAttentionLayer<T> : LayerBase<T>, ILayer<T>, IJitCompilable<T>, IWeightLoadable<T>, IDisposable, IAuxiliaryLossLayer<T>, IDiagnosticsProvider
Type Parameters
T: The numeric type used for computations (typically float or double).
- Inheritance
- LayerBase<T> → MultiHeadAttentionLayer<T>
- Implements
- ILayer<T>, IJitCompilable<T>, IWeightLoadable<T>, IDisposable, IAuxiliaryLossLayer<T>, IDiagnosticsProvider
Remarks
For Beginners: Multi-head attention is like having multiple "experts" look at the same information from different perspectives. Each "head" focuses on different parts of the input, allowing the model to capture various relationships in the data simultaneously. This is similar to how you might ask several friends for advice on a decision - each person might notice different important factors.
Thread Safety: This layer is not thread-safe. Each layer instance maintains internal state during forward and backward passes. If you need concurrent execution, use separate layer instances per thread or synchronize access to shared instances.
Constructors
MultiHeadAttentionLayer(int, int, int, IActivationFunction<T>?)
Creates a new multi-head attention layer with the specified dimensions and head count.
public MultiHeadAttentionLayer(int sequenceLength, int embeddingDimension, int headCount, IActivationFunction<T>? activationFunction = null)
Parameters
sequenceLength (int): The length of the input sequence.
embeddingDimension (int): The dimension of each element in the sequence.
headCount (int): The number of attention heads to use.
activationFunction (IActivationFunction<T>?): The activation function to apply (defaults to the identity function if null).
Remarks
For Beginners: This constructor sets up the attention mechanism with:
- sequenceLength: How many items are in your sequence (like words in a sentence)
- embeddingDimension: How much information is stored about each item
- headCount: How many different "perspectives" or "experts" will analyze the data
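A minimal construction sketch for illustration (the constructor signature is as documented; choosing an embedding dimension divisible by the head count is a common multi-head convention and an assumption here):

using AiDotNet.NeuralNetworks.Layers;

// 8 tokens per sequence, 64 values per token, 4 attention heads.
// The embedding dimension is typically divisible by the head count,
// so each head gets an equal slice (64 / 4 = 16 dimensions per head).
var attention = new MultiHeadAttentionLayer<float>(
    sequenceLength: 8,
    embeddingDimension: 64,
    headCount: 4);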
MultiHeadAttentionLayer(int, int, int, IVectorActivationFunction<T>)
Creates a new multi-head attention layer with the specified dimensions and head count.
public MultiHeadAttentionLayer(int sequenceLength, int embeddingDimension, int headCount, IVectorActivationFunction<T> vectorActivationFunction)
Parameters
sequenceLength (int): The length of the input sequence.
embeddingDimension (int): The dimension of each element in the sequence.
headCount (int): The number of attention heads to use.
vectorActivationFunction (IVectorActivationFunction<T>): The vector activation function to apply (required to disambiguate from the IActivationFunction<T> overload).
Properties
AuxiliaryLossWeight
Gets or sets the weight for the attention entropy auxiliary loss.
public T AuxiliaryLossWeight { get; set; }
Property Value
- T
Remarks
This weight controls how much attention entropy regularization contributes to the total loss. Typical values range from 0.001 to 0.01.
For Beginners: This controls how much we encourage diverse attention patterns.
Common values:
- 0.005 (default): Balanced entropy regularization
- 0.001-0.003: Light regularization
- 0.008-0.01: Strong regularization
Higher values encourage more distributed attention.
HeadCount
Gets the number of attention heads in this layer.
public int HeadCount { get; }
Property Value
- int
HeadDiversityWeight
Gets or sets the weight for head diversity penalty.
public T HeadDiversityWeight { get; set; }
Property Value
- T
Remarks
For Beginners: This encourages different heads to learn different patterns.
Common values:
- 0.01 (default): Moderate diversity encouragement
- 0.005-0.008: Light diversity
- 0.015-0.02: Strong diversity
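A configuration sketch with T = float, reusing the attention layer constructed above and the documented default values:

// Enable attention regularization (UseAuxiliaryLoss is documented below)
// and set both penalty weights explicitly.
attention.UseAuxiliaryLoss = true;
attention.AuxiliaryLossWeight = 0.005f;  // entropy term (documented default)
attention.HeadDiversityWeight = 0.01f;   // diversity term (documented default)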
ParameterCount
Gets the total number of trainable parameters in this layer.
public override int ParameterCount { get; }
Property Value
- int
Remarks
Multi-head attention parameters are stored in multiple internal tensors (Q/K/V/O projections + output bias).
SupportsGpuExecution
Indicates whether this layer supports GPU-resident execution.
protected override bool SupportsGpuExecution { get; }
Property Value
- bool
SupportsJitCompilation
Gets whether this multi-head attention layer supports JIT compilation.
public override bool SupportsJitCompilation { get; }
Property Value
- bool
True if the layer parameters are initialized.
Remarks
This property indicates whether the layer can be JIT compiled. The layer supports JIT if:
- Query, Key, and Value projection weights are initialized
- Output projection weights are initialized
- The multi-head structure is properly configured
For Beginners: This tells you if this layer can use JIT compilation for faster inference.
The layer can be JIT compiled if:
- All projection weight matrices are initialized (Wq, Wk, Wv, Wo)
- The number of attention heads is configured
Multi-head attention is one of the most expensive operations in modern deep learning:
- Used extensively in Transformers (BERT-base has 144 attention heads across its 12 layers; GPT-3 has 96 layers)
- Each forward pass computes attention scores for all position pairs (O(n²) in sequence length)
- Multiple heads process in parallel
JIT compilation provides significant speedup (5-10x) by optimizing:
- Parallel matrix multiplications for all heads
- Attention score computation across heads
- Softmax operations
- Head concatenation and output projection
- Memory access patterns for cache efficiency
This optimization is critical for:
- Real-time NLP applications (translation, summarization, chat)
- Large language models (GPT, BERT, T5)
- Vision Transformers processing high-resolution images
- Any application using Transformer architecture
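A guarded export sketch; what the inputNodes list must contain is not specified on this page, so treat the empty list as a placeholder assumption:

using System.Collections.Generic;

if (attention.SupportsJitCompilation)
{
    var inputNodes = new List<ComputationNode<float>>(); // assumed: filled by the graph builder
    ComputationNode<float> graphRoot = attention.ExportComputationGraph(inputNodes);
    // graphRoot can now be handed to the JIT compiler.
}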
SupportsTraining
Gets a value indicating whether this layer supports training through backpropagation.
public override bool SupportsTraining { get; }
Property Value
- bool
UseAuxiliaryLoss
Gets or sets whether auxiliary loss (attention regularization) should be used during training.
public bool UseAuxiliaryLoss { get; set; }
Property Value
- bool
Remarks
Attention regularization includes entropy regularization per head and head diversity penalties. This prevents attention collapse and encourages heads to learn different patterns.
For Beginners: This helps ensure attention heads learn diverse patterns.
Multi-head attention works best when each head specializes in different aspects:
- Without regularization: Heads might learn redundant patterns
- With regularization: Each head focuses on unique relationships
Two types of regularization:
- Entropy: Prevents attention from being too sharp (focused on one position)
- Diversity: Prevents heads from being too similar to each other
This helps the model:
- Learn more robust representations
- Utilize all attention heads effectively
- Improve generalization to new data
Methods
Backward(Tensor<T>)
Performs the backward pass of the multi-head attention layer, calculating gradients for learning.
public override Tensor<T> Backward(Tensor<T> outputGradient)
Parameters
outputGradient (Tensor<T>): The gradient flowing back from the next layer.
Returns
- Tensor<T>
The gradient to be passed to the previous layer.
Remarks
For Beginners: The backward pass is how neural networks learn. Think of it like figuring out which parts of a recipe need adjustment after tasting the final dish:
- We first check how our output differs from what was expected (the gradient)
- Then we trace backward through all the calculations we did in the forward pass
- We determine how much each weight contributed to any errors
- These contributions become our gradients, which we'll use to update the weights
The complex matrix operations are just a mathematical way of figuring out "if I change this weight a little bit, how much would it improve the output?"
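A backward-pass sketch with T = float; ComputeLossGradient, input, and target are hypothetical stand-ins for your loss function and data:

Tensor<float> output = attention.Forward(input);
Tensor<float> outputGradient = ComputeLossGradient(output, target); // hypothetical helper
Tensor<float> inputGradient = attention.Backward(outputGradient);   // flows to the previous layer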
BackwardGpu(IGpuTensor<T>)
Performs the backward pass using GPU-resident tensors.
public override IGpuTensor<T> BackwardGpu(IGpuTensor<T> outputGradient)
Parameters
outputGradient (IGpuTensor<T>): GPU-resident gradient of the loss w.r.t. the output.
Returns
- IGpuTensor<T>
GPU-resident gradient of the loss w.r.t. input.
ComputeAuxiliaryLoss()
Computes the auxiliary loss for attention regularization (entropy + head diversity).
public T ComputeAuxiliaryLoss()
Returns
- T
The computed attention regularization auxiliary loss.
Remarks
This method computes two types of regularization:
1. Attention Entropy: Encourages attention to be distributed (not too peaked)
2. Head Diversity: Encourages different heads to learn different patterns
Formula: L = entropy_weight * Σ_heads H(attention_head) + diversity_weight * Σ_pairs CosineSim(head_i, head_j)
For Beginners: This calculates penalties to improve attention quality.
Attention regularization works by:
- Measuring attention entropy for each head (prevents over-focusing)
- Measuring similarity between different heads (prevents redundancy)
- Combining these into a single auxiliary loss
This helps because:
- Prevents attention from collapsing to single positions
- Ensures different heads specialize in different patterns
- Improves model robustness and interpretability
The auxiliary loss is minimized during training alongside the main task loss.
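A sketch of folding the auxiliary loss into the training loss, assuming T = float and a hypothetical ComputeTaskLoss helper; per the formula above, the entropy and diversity weights are already applied inside ComputeAuxiliaryLoss:

float mainLoss = ComputeTaskLoss(output, target);   // hypothetical task-loss helper
float auxLoss = attention.ComputeAuxiliaryLoss();   // weighted entropy + diversity penalty
float totalLoss = mainLoss + auxLoss;               // minimize both together during training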
ExportComputationGraph(List<ComputationNode<T>>)
Exports this layer's computation as a computation graph for JIT compilation.
public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
Parameters
inputNodes (List<ComputationNode<T>>): The input nodes to wire into the exported graph.
Returns
- ComputationNode<T>
The root node of the exported computation graph.
Forward(Tensor<T>)
Performs the forward pass of the multi-head attention layer.
public override Tensor<T> Forward(Tensor<T> input)
Parameters
input (Tensor<T>): The input tensor.
Returns
- Tensor<T>
The output tensor after applying multi-head attention.
Remarks
For Beginners: The forward pass is where the layer processes the input data. Here's what happens:
1. The input is transformed into three different representations: queries, keys, and values
2. These are split into multiple "heads" (different perspectives)
3. Each head calculates how much attention to pay to different parts of the input
4. The results from all heads are combined to create the final output
Think of it like this: If you're reading a book, you might pay attention to different aspects like characters, plot, and setting all at once. Each "head" is like focusing on one of these aspects.
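A forward-pass sketch; the [batch, sequenceLength, embeddingDimension] layout and the Tensor<T> shape constructor are assumptions, matching the dimensions used in the constructor example above:

var input = new Tensor<float>(new[] { 1, 8, 64 });  // assumed shape constructor
Tensor<float> output = attention.Forward(input);    // attention re-weights values; shape is preserved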
Forward(params Tensor<T>[])
Performs the forward pass of the layer with multiple input tensors.
public override Tensor<T> Forward(params Tensor<T>[] inputs)
Parameters
inputs (Tensor<T>[]): The input tensors to process.
Returns
- Tensor<T>
The output tensor after processing.
Remarks
This method implements a default forward pass for layers that accept multiple inputs. By default, it concatenates the inputs along the channel dimension. Derived classes can override this method to implement more specific behavior for multiple inputs.
For Beginners: This method handles processing multiple inputs through the layer.
When a layer needs to combine multiple data sources:
- This method takes all the input tensors
- By default, it combines them by stacking them along the channel dimension
- It checks that the inputs are compatible (same shape except for channels)
- It then passes the combined data forward
For example, if combining features from two sources each with 10 channels, this would create a tensor with 20 channels by default.
Specialized layers can override this to combine inputs in different ways.
Exceptions
- ArgumentException
Thrown when no input tensors are provided or when input tensors have incompatible shapes.
ForwardGpu(params IGpuTensor<T>[])
GPU-resident forward pass for multi-head attention. Performs all projections and attention computation on GPU without downloading intermediate results.
public override IGpuTensor<T> ForwardGpu(params IGpuTensor<T>[] inputs)
Parameters
inputs (IGpuTensor<T>[]): The GPU-resident input tensors to process.
Returns
- IGpuTensor<T>
GPU-resident output tensor.
GetAuxiliaryLossDiagnostics()
Gets diagnostic information about the attention regularization auxiliary loss.
public Dictionary<string, string> GetAuxiliaryLossDiagnostics()
Returns
- Dictionary<string, string>
A dictionary containing diagnostic information about attention regularization.
Remarks
This method returns detailed diagnostics about attention regularization, including entropy loss, diversity loss, and configuration parameters. This information is useful for monitoring training progress and debugging.
For Beginners: This provides information about how attention regularization is working.
The diagnostics include:
- Total entropy loss (how distributed attention patterns are)
- Total diversity loss (how different heads are from each other)
- Weights applied to each loss component
- Whether regularization is enabled
- Number of attention heads
This helps you:
- Monitor if attention is becoming too sharp or redundant
- Debug issues with head specialization
- Understand the impact of regularization on learning
You can use this information to adjust regularization weights for better results.
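A monitoring sketch; the dictionary keys are implementation-defined, so this simply prints whatever the layer reports:

using System;
using System.Collections.Generic;

Dictionary<string, string> diagnostics = attention.GetAuxiliaryLossDiagnostics();
foreach (KeyValuePair<string, string> entry in diagnostics)
    Console.WriteLine($"{entry.Key}: {entry.Value}");  // e.g. entropy loss, diversity loss, weights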
GetDiagnostics()
Gets diagnostic information about this component's state and behavior. Overrides GetDiagnostics() to include auxiliary loss diagnostics.
public override Dictionary<string, string> GetDiagnostics()
Returns
- Dictionary<string, string>
A dictionary containing diagnostic metrics including both base layer diagnostics and auxiliary loss diagnostics from GetAuxiliaryLossDiagnostics().
GetKeyWeights()
Gets the key projection weights tensor for JIT compilation.
public Tensor<T> GetKeyWeights()
Returns
- Tensor<T>
GetOutputWeights()
Gets the output projection weights tensor for JIT compilation.
public Tensor<T> GetOutputWeights()
Returns
- Tensor<T>
GetParameters()
Extracts all parameters (weights and biases) from the layer into a single vector.
public override Vector<T> GetParameters()
Returns
- Vector<T>
A vector containing all parameters of the layer.
Remarks
For Beginners: This method collects all the layer's adjustable values (weights and biases) into a single list. Think of it like taking inventory of all the ingredients in a recipe. This is useful for saving the model's state or for optimization algorithms that need to work with all parameters at once.
GetQueryWeights()
Gets the query projection weights tensor for JIT compilation.
public Tensor<T> GetQueryWeights()
Returns
- Tensor<T>
GetValueWeights()
Gets the value projection weights tensor for JIT compilation.
public Tensor<T> GetValueWeights()
Returns
- Tensor<T>
ResetState()
Resets the internal state of the multi-head attention layer.
public override void ResetState()
Remarks
This method clears all cached values from previous forward and backward passes, effectively resetting the layer to its initial state but keeping the learned weights. This is useful when starting a new training sequence or when you want to clear any temporary data without losing the layer's learned parameters.
For Beginners: Think of this like clearing your scratch paper after solving a math problem. You're keeping all the knowledge you've gained (the weights), but you're getting rid of all the intermediate calculations (cached values) to make room for new work.
This is particularly important in neural networks because:
- It frees up memory by removing data we no longer need
- It ensures that each new input is processed with a "clean slate"
- It prevents old calculations from affecting new ones, which could lead to incorrect results
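A usage sketch clearing cached state between two unrelated sequences (sequenceA and sequenceB are hypothetical input tensors):

Tensor<float> outA = attention.Forward(sequenceA);
attention.ResetState(); // drop cached activations, keep learned weights
Tensor<float> outB = attention.Forward(sequenceB);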
SetParameters(Vector<T>)
Sets all parameters (weights and biases) of the layer from a single vector.
public override void SetParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): A vector containing all parameters to set in the layer.
Remarks
For Beginners: This method does the opposite of GetParameters - it takes a list of values and distributes them back into the layer's weights and biases. It's like restocking all the ingredients in your kitchen from a single shopping bag, putting each item in its proper place. This is useful when loading a saved model or when optimization algorithms have computed improved parameter values.
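A snapshot-and-restore sketch built from GetParameters and SetParameters (the namespace providing Vector<T> is not shown on this page):

Vector<float> snapshot = attention.GetParameters(); // inventory of all weights and biases
// ... train or experiment with the layer ...
attention.SetParameters(snapshot);                  // restore the saved state exactly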
UpdateParameters(T)
Updates the layer's parameters (weights and biases) using the calculated gradients.
public override void UpdateParameters(T learningRate)
Parameters
learningRate (T): The learning rate that controls how much to adjust the parameters.
Remarks
For Beginners: This method is like adjusting a recipe based on feedback. The learning rate is how bold we are with our changes - a higher rate means bigger adjustments, while a lower rate means more cautious, smaller adjustments. The gradients tell us which direction to adjust each parameter to improve the network's performance.
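A single simplified training step tying Forward, Backward, and UpdateParameters together; ComputeLossGradient is the same hypothetical helper as above:

Tensor<float> prediction = attention.Forward(input);
Tensor<float> lossGradient = ComputeLossGradient(prediction, target); // hypothetical helper
attention.Backward(lossGradient);    // computes and caches weight gradients internally
attention.UpdateParameters(0.001f);  // small learning rate = cautious adjustments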