Class MultiHeadAttentionLayer<T>

Namespace
AiDotNet.NeuralNetworks.Layers
Assembly
AiDotNet.dll

Implements a multi-head attention layer for neural networks, a key component in transformer architectures.

public class MultiHeadAttentionLayer<T> : LayerBase<T>, ILayer<T>, IJitCompilable<T>, IWeightLoadable<T>, IDisposable, IAuxiliaryLossLayer<T>, IDiagnosticsProvider

Type Parameters

T

The numeric type used for computations (typically float or double).

Inheritance
MultiHeadAttentionLayer<T>

Remarks

For Beginners: Multi-head attention is like having multiple "experts" look at the same information from different perspectives. Each "head" focuses on different parts of the input, allowing the model to capture various relationships in the data simultaneously. This is similar to how you might ask several friends for advice on a decision - each person might notice different important factors.

Thread Safety: This layer is not thread-safe. Each layer instance maintains internal state during forward and backward passes. If you need concurrent execution, use separate layer instances per thread or synchronize access to shared instances.
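
The synchronization option mentioned above can be sketched roughly as follows. `StatefulWorker` is a stand-in for any object that keeps per-call state (such as a layer instance), and `SynchronizedWorker` is a hypothetical wrapper; neither type is part of AiDotNet.

```csharp
using System;

// Stand-in for a stateful, non-thread-safe object such as a layer instance.
public sealed class StatefulWorker
{
    private double _cache; // state written on every call

    public double Process(double input)
    {
        _cache = input * 2.0;
        return _cache;
    }
}

// Hypothetical wrapper that serializes access to one shared instance.
public sealed class SynchronizedWorker
{
    private readonly StatefulWorker _inner = new StatefulWorker();
    private readonly object _gate = new object();

    public double Process(double input)
    {
        // Only one thread at a time may touch the inner state.
        lock (_gate)
        {
            return _inner.Process(input);
        }
    }
}
```

A lock serializes all calls, so it protects correctness but gives no parallel speedup. Separate instances per thread avoid that contention at the cost of extra memory.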

Constructors

MultiHeadAttentionLayer(int, int, int, IActivationFunction<T>?)

Creates a new multi-head attention layer with the specified dimensions and head count.

public MultiHeadAttentionLayer(int sequenceLength, int embeddingDimension, int headCount, IActivationFunction<T>? activationFunction = null)

Parameters

sequenceLength int

The length of the input sequence.

embeddingDimension int

The dimension of each element in the sequence.

headCount int

The number of attention heads to use.

activationFunction IActivationFunction<T>

The activation function to apply (defaults to identity function if null).

Remarks

For Beginners: This constructor sets up the attention mechanism with:

  • sequenceLength: How many items are in your sequence (like words in a sentence)
  • embeddingDimension: How much information is stored about each item
  • headCount: How many different "perspectives" or "experts" will analyze the data
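
As a rough illustration of how these values relate, the sketch below computes the per-head dimension under the common convention that the embedding is split evenly across heads. The `HeadLayout` helper and its divisibility check are assumptions for illustration; this page does not state how the layer validates its arguments.

```csharp
using System;

// Hypothetical helper, not part of AiDotNet.
public static class HeadLayout
{
    // Returns the width of each head's slice when embeddingDimension
    // is split evenly across headCount heads.
    public static int HeadDimension(int embeddingDimension, int headCount)
    {
        if (headCount <= 0)
            throw new ArgumentOutOfRangeException(nameof(headCount));
        if (embeddingDimension <= 0 || embeddingDimension % headCount != 0)
            throw new ArgumentException(
                "embeddingDimension must be a positive multiple of headCount.",
                nameof(embeddingDimension));

        return embeddingDimension / headCount;
    }
}
```

For example, `HeadLayout.HeadDimension(512, 8)` returns 64, so each of the 8 heads would work on a 64-dimensional slice.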

MultiHeadAttentionLayer(int, int, int, IVectorActivationFunction<T>)

Creates a new multi-head attention layer with the specified dimensions and head count.

public MultiHeadAttentionLayer(int sequenceLength, int embeddingDimension, int headCount, IVectorActivationFunction<T> vectorActivationFunction)

Parameters

sequenceLength int

The length of the input sequence.

embeddingDimension int

The dimension of each element in the sequence.

headCount int

The number of attention heads to use.

vectorActivationFunction IVectorActivationFunction<T>

The vector activation function to apply. This parameter is required, which distinguishes this overload from the IActivationFunction<T> overload.

Properties

AuxiliaryLossWeight

Gets or sets the weight for the attention entropy auxiliary loss.

public T AuxiliaryLossWeight { get; set; }

Property Value

T

Remarks

This weight controls how much attention entropy regularization contributes to the total loss. Typical values range from 0.001 to 0.01.

For Beginners: This controls how much we encourage diverse attention patterns.

Common values:

  • 0.005 (default): Balanced entropy regularization
  • 0.001-0.003: Light regularization
  • 0.008-0.01: Strong regularization

Higher values encourage more distributed attention.

HeadCount

Gets the number of attention heads in this layer.

public int HeadCount { get; }

Property Value

int

HeadDiversityWeight

Gets or sets the weight for head diversity penalty.

public T HeadDiversityWeight { get; set; }

Property Value

T

Remarks

For Beginners: This encourages different heads to learn different patterns.

Common values:

  • 0.01 (default): Moderate diversity encouragement
  • 0.005-0.008: Light diversity
  • 0.015-0.02: Strong diversity

ParameterCount

Gets the total number of trainable parameters in this layer.

public override int ParameterCount { get; }

Property Value

int

Remarks

Multi-head attention parameters are stored in multiple internal tensors (Q/K/V/O projections + output bias).
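
A back-of-the-envelope estimate of the count can be sketched as below. It assumes each of the four projections (Q, K, V, O) is a square embeddingDimension × embeddingDimension matrix and that the output bias has embeddingDimension entries. These shapes are the usual transformer convention, not something this page confirms, and `AttentionParameterCount` is a hypothetical helper.

```csharp
// Hypothetical helper, not part of AiDotNet.
public static class AttentionParameterCount
{
    // Assumed layout: four d x d projection matrices plus one bias of length d.
    public static long Estimate(int embeddingDimension)
    {
        long d = embeddingDimension;
        return 4 * d * d + d;
    }
}
```

Under these assumptions, `AttentionParameterCount.Estimate(512)` returns 1049088. Compare the estimate against the actual ParameterCount value before relying on it.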

SupportsGpuExecution

Indicates whether this layer supports GPU-resident execution.

protected override bool SupportsGpuExecution { get; }

Property Value

bool

SupportsJitCompilation

Gets whether this multi-head attention layer supports JIT compilation.

public override bool SupportsJitCompilation { get; }

Property Value

bool

True if the layer parameters are initialized.

Remarks

This property indicates whether the layer can be JIT compiled. The layer supports JIT if:

  • Query, Key, Value projection weights are initialized
  • Output projection weights are initialized
  • The multi-head structure is properly configured

For Beginners: This tells you if this layer can use JIT compilation for faster inference.

The layer can be JIT compiled if:

  • All projection weight matrices are initialized (Wq, Wk, Wv, Wo)
  • The number of attention heads is configured

Multi-head attention is one of the most expensive operations in modern deep learning:

  • Used extensively in Transformers (BERT-base has 12 attention layers with 12 heads each, 144 heads in total; GPT-3 has 96 layers)
  • Each forward pass computes attention scores for all position pairs (O(n²))
  • Multiple heads process in parallel

JIT compilation can provide a significant speedup (up to 5-10x, depending on the workload) by optimizing:

  • Parallel matrix multiplications for all heads
  • Attention score computation across heads
  • Softmax operations
  • Head concatenation and output projection
  • Memory access patterns for cache efficiency

This optimization is critical for:

  • Real-time NLP applications (translation, summarization, chat)
  • Large language models (GPT, BERT, T5)
  • Vision Transformers processing high-resolution images
  • Any application using Transformer architecture

SupportsTraining

Gets a value indicating whether this layer supports training.

public override bool SupportsTraining { get; }

Property Value

bool

UseAuxiliaryLoss

Gets or sets whether auxiliary loss (attention regularization) should be used during training.

public bool UseAuxiliaryLoss { get; set; }

Property Value

bool

Remarks

Attention regularization includes entropy regularization per head and head diversity penalties. This prevents attention collapse and encourages heads to learn different patterns.

For Beginners: This helps ensure attention heads learn diverse patterns.

Multi-head attention works best when each head specializes in different aspects:

  • Without regularization: Heads might learn redundant patterns
  • With regularization: Each head focuses on unique relationships

Two types of regularization:

  1. Entropy: Prevents attention from being too sharp (focused on one position)
  2. Diversity: Prevents heads from being too similar to each other

This helps the model:

  • Learn more robust representations
  • Utilize all attention heads effectively
  • Improve generalization to new data
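
A minimal configuration sketch, assuming the AiDotNet package is referenced and that `float` is used as `T`. The sequence length, embedding dimension, and head count are illustrative values only; the two weights are the defaults listed on this page.

```csharp
using AiDotNet.NeuralNetworks.Layers;

// 128 positions, 512-dimensional embeddings, 8 heads (illustrative values).
var layer = new MultiHeadAttentionLayer<float>(
    sequenceLength: 128,
    embeddingDimension: 512,
    headCount: 8);

// Turn on attention regularization and set both weights explicitly.
layer.UseAuxiliaryLoss = true;
layer.AuxiliaryLossWeight = 0.005f;  // entropy term, documented default
layer.HeadDiversityWeight = 0.01f;   // diversity term, documented default
```

Omitting the activation argument selects the IActivationFunction<T> overload, since the other overload requires its fourth argument.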

Methods

Backward(Tensor<T>)

Performs the backward pass of the multi-head attention layer, calculating gradients for learning.

public override Tensor<T> Backward(Tensor<T> outputGradient)

Parameters

outputGradient Tensor<T>

The gradient flowing back from the next layer.

Returns

Tensor<T>

The gradient to be passed to the previous layer.

Remarks

For Beginners: The backward pass is how neural networks learn. Think of it like figuring out which parts of a recipe need adjustment after tasting the final dish:

  1. We first check how our output differs from what was expected (the gradient)
  2. Then we trace backward through all the calculations we did in the forward pass
  3. We determine how much each weight contributed to any errors
  4. These contributions become our gradients, which we'll use to update the weights

The complex matrix operations are just a mathematical way of figuring out "if I change this weight a little bit, how much would it improve the output?"

BackwardGpu(IGpuTensor<T>)

Performs the backward pass using GPU-resident tensors.

public override IGpuTensor<T> BackwardGpu(IGpuTensor<T> outputGradient)

Parameters

outputGradient IGpuTensor<T>

GPU-resident gradient of the loss w.r.t. output.

Returns

IGpuTensor<T>

GPU-resident gradient of the loss w.r.t. input.

ComputeAuxiliaryLoss()

Computes the auxiliary loss for attention regularization (entropy + head diversity).

public T ComputeAuxiliaryLoss()

Returns

T

The computed attention regularization auxiliary loss.

Remarks

This method computes two types of regularization:

  1. Attention Entropy: Encourages attention to be distributed (not too peaked)
  2. Head Diversity: Encourages different heads to learn different patterns

Formula: L = entropy_weight * Σ_heads H(attention) + diversity_weight * Σ_pairs CosineSim(head_i, head_j)

For Beginners: This calculates penalties to improve attention quality.

Attention regularization works by:

  1. Measuring attention entropy for each head (prevents over-focusing)
  2. Measuring similarity between different heads (prevents redundancy)
  3. Combining these into a single auxiliary loss

This helps because:

  • Prevents attention from collapsing to single positions
  • Ensures different heads specialize in different patterns
  • Improves model robustness and interpretability

The auxiliary loss is minimized during training alongside the main task loss.
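
The two terms in the formula from the remarks above can be sketched with plain arrays as follows. This is an illustration only: `AttentionRegularization` is a hypothetical helper, each head is represented by a single attention distribution for simplicity, and the terms are combined exactly as the formula is written. How the library signs or normalizes the entropy term internally is not stated on this page.

```csharp
using System;

// Hypothetical helper, not part of AiDotNet.
public static class AttentionRegularization
{
    // Shannon entropy (natural log) of one attention distribution.
    public static double Entropy(double[] weights)
    {
        double h = 0.0;
        foreach (double p in weights)
        {
            if (p > 0.0) h -= p * Math.Log(p); // zero weights contribute nothing
        }
        return h;
    }

    // Cosine similarity between two heads' attention patterns.
    public static double CosineSimilarity(double[] a, double[] b)
    {
        if (a.Length != b.Length)
            throw new ArgumentException("Patterns must have the same length.");

        double dot = 0.0, normA = 0.0, normB = 0.0;
        for (int i = 0; i < a.Length; i++)
        {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        double denominator = Math.Sqrt(normA) * Math.Sqrt(normB);
        return denominator == 0.0 ? 0.0 : dot / denominator;
    }

    // entropyWeight * (sum of per-head entropies)
    // + diversityWeight * (sum of cosine similarities over distinct head pairs).
    public static double Combine(double[][] heads, double entropyWeight, double diversityWeight)
    {
        double entropySum = 0.0, similaritySum = 0.0;
        for (int i = 0; i < heads.Length; i++)
        {
            entropySum += Entropy(heads[i]);
            for (int j = i + 1; j < heads.Length; j++)
                similaritySum += CosineSimilarity(heads[i], heads[j]);
        }
        return entropyWeight * entropySum + diversityWeight * similaritySum;
    }
}
```

Two identical heads give a cosine similarity of 1, the largest possible diversity penalty for a pair, while heads that attend to disjoint positions give 0.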

ExportComputationGraph(List<ComputationNode<T>>)

public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)

Parameters

inputNodes List<ComputationNode<T>>

Returns

ComputationNode<T>

Forward(Tensor<T>)

Performs the forward pass of the multi-head attention layer.

public override Tensor<T> Forward(Tensor<T> input)

Parameters

input Tensor<T>

The input tensor.

Returns

Tensor<T>

The output tensor after applying multi-head attention.

Remarks

For Beginners: The forward pass is where the layer processes the input data. Here's what happens:

  1. The input is transformed into three different representations: queries, keys, and values
  2. These are split into multiple "heads" (different perspectives)
  3. Each head calculates how much attention to pay to different parts of the input
  4. The results from all heads are combined to create the final output

Think of it like this: If you're reading a book, you might pay attention to different aspects like characters, plot, and setting all at once. Each "head" is like focusing on one of these aspects.
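
What a single head computes in step 3 can be sketched with the standard scaled dot-product formulation, softmax(QKᵀ / √d) V. The code below is a plain-array illustration under that assumption, not the layer's actual implementation; `SingleHeadAttention` is a hypothetical helper and the query, key, and value matrices are assumed to be already projected.

```csharp
using System;

// Hypothetical helper, not part of AiDotNet.
public static class SingleHeadAttention
{
    // q: [queries][d], k: [keys][d], v: [keys][dv]. Returns [queries][dv].
    public static double[][] Attend(double[][] q, double[][] k, double[][] v)
    {
        int queryCount = q.Length, keyCount = k.Length;
        int d = q[0].Length, dv = v[0].Length;
        double scale = 1.0 / Math.Sqrt(d);
        var output = new double[queryCount][];

        for (int i = 0; i < queryCount; i++)
        {
            // Scaled dot product of query i with every key.
            var scores = new double[keyCount];
            double max = double.NegativeInfinity;
            for (int j = 0; j < keyCount; j++)
            {
                double s = 0.0;
                for (int t = 0; t < d; t++) s += q[i][t] * k[j][t];
                scores[j] = s * scale;
                if (scores[j] > max) max = scores[j];
            }

            // Softmax, shifted by the row maximum for numerical stability.
            double sum = 0.0;
            for (int j = 0; j < keyCount; j++)
            {
                scores[j] = Math.Exp(scores[j] - max);
                sum += scores[j];
            }

            // Attention-weighted sum of the value rows.
            output[i] = new double[dv];
            for (int j = 0; j < keyCount; j++)
            {
                double weight = scores[j] / sum;
                for (int t = 0; t < dv; t++) output[i][t] += weight * v[j][t];
            }
        }
        return output;
    }
}
```

The nested loops over queries and keys show why attention cost grows quadratically with sequence length. A multi-head layer would run this once per head on separate slices and then concatenate and project the results.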

Forward(params Tensor<T>[])

Performs the forward pass of the layer with multiple input tensors.

public override Tensor<T> Forward(params Tensor<T>[] inputs)

Parameters

inputs Tensor<T>[]

The input tensors to process.

Returns

Tensor<T>

The output tensor after processing.

Remarks

This method implements a default forward pass for layers that accept multiple inputs. By default, it concatenates the inputs along the channel dimension. Derived classes can override this method to implement more specific behavior for multiple inputs.

For Beginners: This method handles processing multiple inputs through the layer.

When a layer needs to combine multiple data sources:

  • This method takes all the input tensors
  • By default, it combines them by stacking them along the channel dimension
  • It checks that the inputs are compatible (same shape except for channels)
  • It then passes the combined data forward

For example, if combining features from two sources each with 10 channels, this would create a tensor with 20 channels by default.

Specialized layers can override this to combine inputs in different ways.

Exceptions

ArgumentException

Thrown when no input tensors are provided or when input tensors have incompatible shapes.

ForwardGpu(params IGpuTensor<T>[])

GPU-resident forward pass for multi-head attention. Performs all projections and attention computation on GPU without downloading intermediate results.

public override IGpuTensor<T> ForwardGpu(params IGpuTensor<T>[] inputs)

Parameters

inputs IGpuTensor<T>[]

Returns

IGpuTensor<T>

GPU-resident output tensor.

GetAuxiliaryLossDiagnostics()

Gets diagnostic information about the attention regularization auxiliary loss.

public Dictionary<string, string> GetAuxiliaryLossDiagnostics()

Returns

Dictionary<string, string>

A dictionary containing diagnostic information about attention regularization.

Remarks

This method returns detailed diagnostics about attention regularization, including entropy loss, diversity loss, and configuration parameters. This information is useful for monitoring training progress and debugging.

For Beginners: This provides information about how attention regularization is working.

The diagnostics include:

  • Total entropy loss (how distributed attention patterns are)
  • Total diversity loss (how different heads are from each other)
  • Weights applied to each loss component
  • Whether regularization is enabled
  • Number of attention heads

This helps you:

  • Monitor if attention is becoming too sharp or redundant
  • Debug issues with head specialization
  • Understand the impact of regularization on learning

You can use this information to adjust regularization weights for better results.
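
One way to inspect the returned dictionary is to print it in a stable order, as sketched below. `DiagnosticsFormatter` is a hypothetical helper, and because this page does not list the dictionary's key names, the code avoids depending on any particular key.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical helper, not part of AiDotNet.
public static class DiagnosticsFormatter
{
    // Renders "key = value" lines sorted by key so successive logs are easy to diff.
    public static string Format(IReadOnlyDictionary<string, string> diagnostics)
    {
        return string.Join(
            Environment.NewLine,
            diagnostics
                .OrderBy(pair => pair.Key, StringComparer.Ordinal)
                .Select(pair => $"{pair.Key} = {pair.Value}"));
    }
}
```

With a layer instance in hand, `Console.WriteLine(DiagnosticsFormatter.Format(layer.GetAuxiliaryLossDiagnostics()));` would log one line per entry.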

GetDiagnostics()

Gets diagnostic information about this component's state and behavior. Overrides the base GetDiagnostics() implementation to include auxiliary loss diagnostics.

public override Dictionary<string, string> GetDiagnostics()

Returns

Dictionary<string, string>

A dictionary containing diagnostic metrics including both base layer diagnostics and auxiliary loss diagnostics from GetAuxiliaryLossDiagnostics().

GetKeyWeights()

Gets the key projection weights tensor for JIT compilation.

public Tensor<T> GetKeyWeights()

Returns

Tensor<T>

GetOutputWeights()

Gets the output projection weights tensor for JIT compilation.

public Tensor<T> GetOutputWeights()

Returns

Tensor<T>

GetParameters()

Extracts all parameters (weights and biases) from the layer into a single vector.

public override Vector<T> GetParameters()

Returns

Vector<T>

A vector containing all parameters of the layer.

Remarks

For Beginners: This method collects all the layer's adjustable values (weights and biases) into a single list. Think of it like taking inventory of all the ingredients in a recipe. This is useful for saving the model's state or for optimization algorithms that need to work with all parameters at once.

GetQueryWeights()

Gets the query projection weights tensor for JIT compilation.

public Tensor<T> GetQueryWeights()

Returns

Tensor<T>

GetValueWeights()

Gets the value projection weights tensor for JIT compilation.

public Tensor<T> GetValueWeights()

Returns

Tensor<T>

ResetState()

Resets the internal state of the multi-head attention layer.

public override void ResetState()

Remarks

This method clears all cached values from previous forward and backward passes, effectively resetting the layer to its initial state but keeping the learned weights. This is useful when starting a new training sequence or when you want to clear any temporary data without losing the layer's learned parameters.

For Beginners: Think of this like clearing your scratch paper after solving a math problem. You're keeping all the knowledge you've gained (the weights), but you're getting rid of all the intermediate calculations (cached values) to make room for new work.

This is particularly important in neural networks because:

  1. It frees up memory by removing data we no longer need
  2. It ensures that each new input is processed with a "clean slate"
  3. It prevents old calculations from affecting new ones, which could lead to incorrect results

SetParameters(Vector<T>)

Sets all parameters (weights and biases) of the layer from a single vector.

public override void SetParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

A vector containing all parameters to set in the layer.

Remarks

For Beginners: This method does the opposite of GetParameters - it takes a list of values and distributes them back into the layer's weights and biases. It's like restocking all the ingredients in your kitchen from a single shopping bag, putting each item in its proper place. This is useful when loading a saved model or when optimization algorithms have computed improved parameter values.

UpdateParameters(T)

Updates the layer's parameters (weights and biases) using the calculated gradients.

public override void UpdateParameters(T learningRate)

Parameters

learningRate T

The learning rate that controls how much to adjust the parameters.

Remarks

For Beginners: This method is like adjusting a recipe based on feedback. The learning rate is how bold we are with our changes - a higher rate means bigger adjustments, while a lower rate means more cautious, smaller adjustments. The gradients tell us which direction to adjust each parameter to improve the network's performance.
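
The role of the learning rate can be illustrated with a plain gradient-descent step. This sketch assumes the simplest update rule, parameter minus learning rate times gradient; the page does not say which rule the layer applies internally, and `SgdStep` is a hypothetical helper.

```csharp
using System;

// Hypothetical helper, not part of AiDotNet.
public static class SgdStep
{
    // In-place update: parameters[i] -= learningRate * gradients[i].
    public static void Apply(double[] parameters, double[] gradients, double learningRate)
    {
        if (parameters.Length != gradients.Length)
            throw new ArgumentException("parameters and gradients must have the same length.");

        for (int i = 0; i < parameters.Length; i++)
        {
            parameters[i] -= learningRate * gradients[i];
        }
    }
}
```

Doubling the learning rate doubles the size of every adjustment, which speeds up progress but makes overshooting more likely.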