Class AttentionNetwork<T>

Namespace
AiDotNet.NeuralNetworks
Assembly
AiDotNet.dll

Represents a neural network that utilizes attention mechanisms for sequence processing.

public class AttentionNetwork<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable, IAuxiliaryLossLayer<T>, IDiagnosticsProvider

Type Parameters

T

The numeric type used for calculations, typically float or double.

Inheritance
object
NeuralNetworkBase<T>
AttentionNetwork<T>
Implements
INeuralNetworkModel<T>
INeuralNetwork<T>
IFullModel<T, Tensor<T>, Tensor<T>>
IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>
IModelSerializer
ICheckpointableModel
IParameterizable<T, Tensor<T>, Tensor<T>>
IFeatureAware
IFeatureImportance<T>
ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>
IGradientComputable<T, Tensor<T>, Tensor<T>>
IJitCompilable<T>
IInterpretableModel<T>
IInputGradientComputable<T>
IDisposable
IAuxiliaryLossLayer<T>
IDiagnosticsProvider

Remarks

An attention network is a specialized neural network architecture designed for sequence processing tasks. It uses attention mechanisms to dynamically focus on different parts of the input sequence when generating outputs. This allows the network to capture long-range dependencies and relationships between elements in the sequence, making it particularly effective for tasks like natural language processing, time series analysis, and other sequence-to-sequence problems.

For Beginners: This network mimics how humans pay attention to different parts of information.

Think of it like reading a complex paragraph:

  • When you try to understand a sentence, you don't focus equally on all words
  • You focus more on the important words that carry meaning
  • You also connect related words even if they're far apart

For example, in the sentence "The cat, which had a white spot on its tail, chased the mouse":

  • An attention network would connect "cat" with "chased" even though they're separated
  • It would assign different importance to different words based on context
  • This helps it understand the overall meaning better than networks that process words in isolation

This ability to selectively focus and connect distant information makes attention networks powerful for language tasks, time series prediction, and many other sequence-based problems.

Constructors

AttentionNetwork(NeuralNetworkArchitecture<T>, int, int, ILossFunction<T>?)

Initializes a new instance of the AttentionNetwork<T> class.

public AttentionNetwork(NeuralNetworkArchitecture<T> architecture, int sequenceLength, int embeddingSize, ILossFunction<T>? lossFunction = null)

Parameters

architecture NeuralNetworkArchitecture<T>

The architecture specification for the network.

sequenceLength int

The maximum length of sequences this network can process.

embeddingSize int

The size of the embeddings used to represent each element in the sequence.

lossFunction ILossFunction<T>

The loss function to use for training. If null, a default Cross-Entropy loss function will be used.

Remarks

This constructor creates an attention network with the specified architecture, sequence length, and embedding size. It initializes the network's layers according to the architecture specification or uses default layers if none are provided. If no loss function is specified, it uses Cross-Entropy Loss, which is commonly used for attention networks in tasks like machine translation or text summarization.

For Beginners: This constructor creates a new attention network with the specified settings.

The parameters you provide determine:

  • architecture: The overall design of the network (layers, connections, etc.)
  • sequenceLength: How many elements (like words) the network can process at once
  • embeddingSize: How rich the representation of each element is
  • lossFunction: How the network measures its mistakes during training (optional)

These settings control the capacity, expressiveness, and computational requirements of the network. Larger values for sequenceLength and embeddingSize give the network more capacity to handle complex tasks but require more memory and processing power.

The loss function helps the network learn by measuring how far off its predictions are. Cross-Entropy Loss is used by default because it works well for many language-related tasks.
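
Example

A minimal construction sketch. Only the AttentionNetwork<T> constructor signature comes from this page; how NeuralNetworkArchitecture<double> is created below is an assumption for illustration, so check that type's own documentation for its actual options.

// Assumed: NeuralNetworkArchitecture<double> can be created or configured like this.
var architecture = new NeuralNetworkArchitecture<double>();

// 128-element sequences with 64-dimensional embeddings; the default
// Cross-Entropy loss is used because lossFunction is omitted.
var network = new AttentionNetwork<double>(
    architecture,
    sequenceLength: 128,
    embeddingSize: 64);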

Properties

AuxiliaryLossWeight

Gets or sets the weight for attention entropy regularization. Default is 0.01. Controls the strength of entropy regularization across attention layers.

public T AuxiliaryLossWeight { get; set; }

Property Value

T

UseAuxiliaryLoss

Gets or sets whether to use auxiliary loss (attention entropy regularization) during training. Default is false. Enable to prevent attention collapse across attention layers.

public bool UseAuxiliaryLoss { get; set; }

Property Value

bool
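
Example

Both properties can be set after construction. The values below are illustrative only; network is assumed to be an existing AttentionNetwork<double>.

// Enable attention entropy regularization and raise its strength above the 0.01 default.
network.UseAuxiliaryLoss = true;
network.AuxiliaryLossWeight = 0.05;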

Methods

ComputeAuxiliaryLoss()

Computes the auxiliary loss for the AttentionNetwork, which aggregates attention entropy losses from all attention layers.

public T ComputeAuxiliaryLoss()

Returns

T

The total attention entropy loss value from all attention layers.

Remarks

This method aggregates attention entropy regularization from all attention layers in the network. It prevents attention collapse by encouraging diverse attention patterns across all layers. The loss is computed by summing entropy regularization from each AttentionLayer that has it enabled.

For Beginners: This calculates penalties from all attention mechanisms to prevent them from becoming too focused.

Attention entropy regularization:

  • Collects regularization losses from all attention layers
  • Prevents any attention layer from collapsing to single positions
  • Encourages diverse attention patterns throughout the network
  • Helps maintain robust and generalizable attention mechanisms

Why this is important:

  • Prevents attention heads from becoming redundant
  • Ensures the network uses all its attention capacity effectively
  • Improves model robustness and generalization
  • Helps prevent overfitting to specific attention patterns

Think of it like ensuring all team members (attention layers) contribute meaningfully rather than everyone just following one person's lead.
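
Example

A short sketch that reads the aggregate entropy penalty; it assumes network is an AttentionNetwork<double> with UseAuxiliaryLoss enabled.

// Inspect the total attention entropy penalty reported by all attention layers.
double entropyLoss = network.ComputeAuxiliaryLoss();
Console.WriteLine($"Attention entropy loss: {entropyLoss}");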

CreateNewInstance()

Creates a new instance of the attention network model.

protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()

Returns

IFullModel<T, Tensor<T>, Tensor<T>>

A new instance of the attention network model with the same configuration.

Remarks

This method creates a new instance of the attention network model with the same configuration as the current instance. It is used internally during serialization/deserialization processes to create a fresh instance that can be populated with the serialized data.

For Beginners: This method creates a copy of the model structure without copying the learned data.

Think of it like creating a blueprint of the network's architecture:

  • It includes the same structure (layers, connections, sizes)
  • It preserves the configuration settings (sequence length, embedding size)
  • It doesn't copy over any of the learned knowledge (weights, biases)

This is particularly useful when you want to save or load models, as it provides the framework that learned parameters can be loaded into.

DeserializeNetworkSpecificData(BinaryReader)

Deserializes network-specific data for the Attention Network.

protected override void DeserializeNetworkSpecificData(BinaryReader reader)

Parameters

reader BinaryReader

The BinaryReader to read the data from.

Remarks

This method reads the specific configuration and state of the Attention Network from a binary stream. It reconstructs the network-specific parameters to match the state of the network when it was serialized.

For Beginners: This method loads the unique settings of your Attention Network.

It reads:

  • The sequence length and embedding size
  • The configuration of each layer
  • Any other Attention Network-specific parameters

Loading these details allows you to recreate the exact same network structure that was previously saved. It's like following a detailed recipe to recreate a dish exactly as it was made before.

GetAuxiliaryLossDiagnostics()

Gets diagnostic information about the attention entropy regularization.

public Dictionary<string, string> GetAuxiliaryLossDiagnostics()

Returns

Dictionary<string, string>

A dictionary containing diagnostic information about attention patterns across all layers.

Remarks

This method provides insights into attention behavior across all attention layers, including:

  • Total attention entropy loss
  • Number of attention layers with regularization enabled
  • Regularization weight

For Beginners: This gives you information to monitor attention health across the entire network.

The diagnostics include:

  • Total Attention Entropy Loss: Aggregate entropy from all attention layers
  • Attention Layers Count: How many layers contribute to regularization
  • Entropy Weight: How much the regularization influences training
  • Use Auxiliary Loss: Whether network-level regularization is enabled

These values help you:

  • Monitor attention collapse across the entire network
  • Detect if attention patterns are becoming too focused
  • Tune the entropy regularization weight
  • Ensure all attention layers maintain diverse patterns
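
Example

A sketch that prints every diagnostic entry; network is assumed to be an existing AttentionNetwork<double>.

// Dump all attention-health metrics reported by the network.
foreach (var entry in network.GetAuxiliaryLossDiagnostics())
{
    Console.WriteLine($"{entry.Key}: {entry.Value}");
}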

GetDiagnostics()

Gets diagnostic information about this component's state and behavior. Provides auxiliary loss diagnostics specific to attention networks.

public Dictionary<string, string> GetDiagnostics()

Returns

Dictionary<string, string>

A dictionary containing diagnostic metrics about auxiliary loss computation.

GetModelMetadata()

Gets metadata about the Attention Network model.

public override ModelMetadata<T> GetModelMetadata()

Returns

ModelMetadata<T>

A ModelMetadata<T> object containing information about the model.

Remarks

This method returns metadata that describes the Attention Network, including its type, architecture details, and other relevant information. This metadata can be useful for model management, documentation, and versioning.

For Beginners: This provides a summary of your network's setup and characteristics.

The metadata includes:

  • The type of model (Attention Network)
  • Details about the network's structure and capacity
  • Information about the input and output shapes

It's like a spec sheet for your network, useful for keeping track of different versions or comparing different network configurations.
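
Example

A minimal sketch. The members exposed by ModelMetadata<T> are not documented on this page, so the metadata object is simply printed; network is assumed to be an existing AttentionNetwork<double>.

// Retrieve the model's metadata for logging or version tracking.
ModelMetadata<double> metadata = network.GetModelMetadata();
Console.WriteLine(metadata);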

InitializeLayers()

Initializes the layers of the attention network.

protected override void InitializeLayers()

Remarks

This method initializes the layers of the attention network either by using the layers provided by the user in the architecture specification or by creating default attention layers if none are provided.

For Beginners: This method sets up the building blocks of the attention network.

It does one of two things:

  1. If you provided specific layers in the architecture, it uses those
  2. If you didn't provide layers, it creates a default set of attention layers

The default layers typically include:

  • Embedding layers to convert inputs to vector representations
  • Attention layers to focus on relevant parts of the sequence
  • Feed-forward layers to process the attended information
  • Output layers to produce the final results

This flexibility allows both beginners and experts to use the network effectively.

Predict(Tensor<T>)

Makes a prediction using the current state of the Attention Network.

public override Tensor<T> Predict(Tensor<T> input)

Parameters

input Tensor<T>

The input tensor to make a prediction for.

Returns

Tensor<T>

The predicted output tensor after passing through all layers of the network.

Remarks

This method performs a forward pass through the network, transforming the input data through each layer to produce a final prediction. It includes input validation to ensure the provided tensor matches the expected input shape of the network.

For Beginners: This is how the network processes new data to make predictions.

The prediction process:

  1. Checks if the input data is valid and not too long
  2. Passes the input through each layer of the network
  3. Each layer transforms the data, with attention layers focusing on relevant parts
  4. The final layer produces the network's prediction

Think of it like a series of experts each looking at the data and passing their insights to the next expert, with the last one making the final decision.

Exceptions

ArgumentException

Thrown when the input sequence length exceeds the maximum allowed length.
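
Example

A forward-pass sketch. How the input Tensor<double> is built is an assumption; only the Predict signature and the ArgumentException behavior come from this page.

try
{
    // inputTensor is assumed to be a Tensor<double> shaped for this network.
    Tensor<double> output = network.Predict(inputTensor);
    Console.WriteLine(output);
}
catch (ArgumentException ex)
{
    // Thrown when the input sequence exceeds the configured sequenceLength.
    Console.WriteLine($"Input too long: {ex.Message}");
}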

SerializeNetworkSpecificData(BinaryWriter)

Serializes network-specific data for the Attention Network.

protected override void SerializeNetworkSpecificData(BinaryWriter writer)

Parameters

writer BinaryWriter

The BinaryWriter to write the data to.

Remarks

This method writes the specific configuration and state of the Attention Network to a binary stream. It includes network-specific parameters that are essential for later reconstruction of the network.

For Beginners: This method saves the unique settings of your Attention Network.

It writes:

  • The sequence length and embedding size
  • The configuration of each layer
  • Any other Attention Network-specific parameters

Saving these details allows you to recreate the exact same network structure later. It's like writing down a detailed recipe so you can make the same dish again in the future.

Train(Tensor<T>, Tensor<T>)

Trains the Attention Network using the provided input and expected output.

public override void Train(Tensor<T> input, Tensor<T> expectedOutput)

Parameters

input Tensor<T>

The input tensor used for training.

expectedOutput Tensor<T>

The expected output tensor for the given input.

Remarks

This method implements the training process for the Attention Network. It performs a forward pass, calculates the loss between the network's prediction and the expected output, and then backpropagates this error to adjust the network's parameters.

For Beginners: This is how the network learns from examples.

The training process:

  1. Makes a prediction using the current network state
  2. Compares the prediction to the correct answer to calculate the error
  3. Figures out how to adjust the network to reduce this error
  4. Updates the network's internal settings to improve future predictions

It's like a student doing practice problems, checking their answers, and learning from their mistakes to do better next time.
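
Example

A minimal training-loop sketch. The trainingData collection of (input, expectedOutput) tensor pairs is an assumption for illustration; only the Train signature comes from this page.

// trainingData is assumed to be IEnumerable<(Tensor<double> Input, Tensor<double> Expected)>.
for (int epoch = 0; epoch < 10; epoch++)
{
    foreach (var (input, expectedOutput) in trainingData)
    {
        network.Train(input, expectedOutput);
    }
}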

UpdateParameters(Vector<T>)

Updates the parameters of the attention network.

public override void UpdateParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

The parameters to update the network with.

Remarks

This method updates the parameters of each layer in the network with the provided parameter values. It distributes the parameters to each layer based on the number of parameters in each layer.

For Beginners: This method adjusts the network's internal values to improve its performance.

During training:

  • The learning algorithm calculates how the parameters should change
  • This method applies those changes to the actual parameters
  • Each layer gets its own portion of the parameter updates

Think of it like fine-tuning all the components of the network based on feedback:

  • Attention mechanisms learn to focus on more relevant parts
  • Embedding layers learn better representations of the input
  • Feed-forward layers learn to process the information more effectively
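
Example

UpdateParameters is normally driven by an optimizer rather than called by hand. The GetParameters() accessor used below is an assumption (check IParameterizable<T, Tensor<T>, Tensor<T>> for the actual member name); only UpdateParameters itself is documented on this page.

// Assumed accessor: retrieve the current flattened parameter vector.
Vector<double> parameters = network.GetParameters();

// ...an optimizer would adjust the values here...

// Distribute the adjusted values back across the network's layers.
network.UpdateParameters(parameters);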