Class Transformer<T>
- Namespace
- AiDotNet.NeuralNetworks
- Assembly
- AiDotNet.dll
Represents a Transformer neural network architecture, which is particularly effective for sequence-based tasks like natural language processing.
public class Transformer<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable, IAuxiliaryLossLayer<T>, IDiagnosticsProvider
Type Parameters
T: The data type used for calculations (typically float or double).
- Inheritance
- NeuralNetworkBase<T> → Transformer<T>
Remarks
The Transformer architecture is a type of neural network design that uses self-attention mechanisms instead of recurrence or convolution. This approach allows the model to weigh the importance of different parts of the input sequence when producing each part of the output sequence.
The key components of a Transformer include:
- Multi-head attention layers: allow the model to focus on different parts of the input
- Feed-forward networks: process the attended information
- Layer normalization: stabilizes the network during training
- Residual connections: help information flow through the network
For Beginners: A Transformer is a modern type of neural network that excels at understanding sequences of data, like sentences or time series.
Think of it like reading a book:
- When you read a sentence, some words are more important than others for understanding the meaning
- A Transformer can "pay attention" to different words based on their importance
- It can look at the entire context at once, rather than reading one word at a time
For example, in the sentence "The animal didn't cross the street because it was too wide", the Transformer can figure out that "it" refers to "the street" by paying attention to the relationship between these words.
Transformers are behind many recent AI advances, including large language models like GPT and BERT.
Constructors
Transformer(TransformerArchitecture<T>, ILossFunction<T>?, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?)
Creates a new Transformer neural network with the specified architecture.
public Transformer(TransformerArchitecture<T> architecture, ILossFunction<T>? lossFunction = null, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? optimizer = null)
Parameters
architecture (TransformerArchitecture<T>): The architecture configuration that defines how this Transformer will be structured. This includes settings like embedding size, number of attention heads, and feed-forward dimensions.
lossFunction (ILossFunction<T>?): The loss function used during training.
optimizer (IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?): The gradient-based optimizer used during training.
Remarks
This constructor initializes a new Transformer neural network with the provided architecture configuration. It passes the architecture to the base class constructor and also stores it for use in initializing the Transformer-specific layers.
For Beginners: This is where we create our Transformer network.
When you create a new Transformer, you provide a blueprint (the architecture) that specifies:
- How many layers it should have
- How attention works in the network
- How large the various components should be
This is similar to how you might specify the size, number of rooms, and layout when building a house.
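A minimal construction sketch follows. The exact constructor arguments of TransformerArchitecture<T> are assumptions for illustration and may differ in your version of the library; the loss function and optimizer parameters are optional.

```csharp
using AiDotNet.NeuralNetworks;

// Hypothetical architecture setup — the TransformerArchitecture<T>
// constructor arguments shown as a comment are assumptions.
var architecture = new TransformerArchitecture<double>(
    /* embedding size, attention heads, layer count, feed-forward dimensions, ... */);

// lossFunction and optimizer default to null, so library defaults apply.
var transformer = new Transformer<double>(architecture);
```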
Properties
AttentionMask
Gets or sets the attention mask used in the Transformer.
public Tensor<T>? AttentionMask { get; set; }
Property Value
- Tensor<T>?
Remarks
This mask is used to control which positions are attended to in the self-attention layers. It's particularly useful for tasks like sequence generation where future tokens should be masked.
AuxiliaryLossWeight
Gets or sets the weight for the attention regularization auxiliary loss.
public T AuxiliaryLossWeight { get; set; }
Property Value
- T
Remarks
This weight controls how much network-level attention regularization contributes to the total loss. Typical values range from 0.001 to 0.01.
For Beginners: This controls how much to encourage good attention throughout the network.
Common values:
- 0.005 (default): Balanced network-level regularization
- 0.001-0.003: Light regularization
- 0.008-0.01: Strong regularization
Higher values enforce stronger attention quality constraints.
UseAuxiliaryLoss
Gets or sets whether auxiliary loss (attention regularization) should be used during training.
public bool UseAuxiliaryLoss { get; set; }
Property Value
- bool
Remarks
Attention regularization aggregates auxiliary losses from all MultiHeadAttentionLayers in the network. This includes both entropy regularization and head diversity penalties.
For Beginners: This controls attention quality across the entire Transformer.
When enabled, the Transformer:
- Collects regularization from all attention layers
- Prevents attention collapse across the network
- Encourages diverse attention patterns at all levels
This is especially important for:
- Deep transformers (many layers)
- Models with many attention heads
- Tasks requiring robust attention patterns
The auxiliary loss helps maintain attention quality throughout training.
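Putting the two auxiliary-loss properties together, given a Transformer<double> instance (here called transformer), configuration is a pair of assignments; the specific weight values mirror the ranges documented above.

```csharp
// Enable network-wide attention regularization and tune its strength.
transformer.UseAuxiliaryLoss = true;
transformer.AuxiliaryLossWeight = 0.005; // balanced default;
// use ~0.001–0.003 for light or ~0.008–0.01 for strong regularization.
```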
Methods
ComputeAuxiliaryLoss()
Computes the auxiliary loss for attention regularization across all attention layers.
public T ComputeAuxiliaryLoss()
Returns
- T
The computed attention regularization auxiliary loss.
Remarks
This method aggregates auxiliary losses from all MultiHeadAttentionLayers in the Transformer. It collects both entropy regularization and head diversity penalties from each attention layer.
Formula: L_aux = (1/N) * Σ_i auxloss_i, where N is the number of attention layers.
For Beginners: This calculates network-wide attention quality.
Transformer attention regularization works by:
- Finding all attention layers in the network
- Computing auxiliary loss for each layer (if enabled)
- Averaging these losses across all layers
- Returning the network-level regularization penalty
This helps because:
- Maintains attention quality throughout the entire network
- Prevents attention collapse at any level
- Encourages diverse attention patterns across all layers
- Improves interpretability and robustness
The auxiliary loss is added to the main task loss during training.
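For monitoring, the regularization term can be read out directly during training. This sketch assumes a Transformer<double> instance named transformer with UseAuxiliaryLoss enabled.

```csharp
// Inspect the network-level attention regularization term.
double auxLoss = transformer.ComputeAuxiliaryLoss();
Console.WriteLine($"Attention regularization loss: {auxLoss}");
```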
CreateNewInstance()
Creates a new instance of the Transformer with the same architecture and configuration.
protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
A new instance of the Transformer with the same configuration as the current instance.
Remarks
This method creates a new Transformer neural network with the same architecture, loss function, and optimizer as the current instance. The new instance has freshly initialized parameters, making it useful for creating separate instances with identical configurations or for resetting a network while preserving its structure.
For Beginners: This creates a brand new Transformer with the same setup.
Think of it like creating a blueprint copy:
- It has the same architecture (number of layers, attention heads, etc.)
- It uses the same loss function to measure performance
- It uses the same optimizer to learn from data
- But it starts with fresh parameters (weights and biases)
This is useful when you want to:
- Start over with a fresh network but keep the same design
- Create multiple networks with identical settings for comparison
- Reset a network to its initial state
The new Transformer will need to be trained from scratch, as it doesn't inherit any of the learned knowledge from the original.
DeserializeNetworkSpecificData(BinaryReader)
Deserializes Transformer-specific data from a binary stream.
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
Parameters
reader (BinaryReader): The BinaryReader to read the data from.
Remarks
This method reads Transformer-specific configuration and state data from a binary stream. It reconstructs the Transformer's state from previously serialized data.
For Beginners: This method loads all the important details about a Transformer from a file.
It's like reconstructing the Transformer from a saved snapshot, including:
- Rebuilding its configuration (how it was set up)
- Restoring its parameter values (what it had learned)
This allows you to load a previously trained Transformer and use it immediately without having to retrain it.
GetAuxiliaryLossDiagnostics()
Gets diagnostic information about the attention regularization auxiliary loss.
public Dictionary<string, string> GetAuxiliaryLossDiagnostics()
Returns
- Dictionary<string, string>
A dictionary containing diagnostic information about network-level attention regularization.
Remarks
This method returns detailed diagnostics about attention regularization across the Transformer, including aggregated losses, layer counts, and configuration parameters. This information is useful for monitoring training progress and debugging attention issues.
For Beginners: This provides information about how attention works across the network.
The diagnostics include:
- Total attention regularization loss (averaged across layers)
- Weight applied to the regularization
- Number of attention layers with regularization enabled
- Whether network-level regularization is enabled
This helps you:
- Monitor attention quality throughout the network
- Debug issues with attention collapse
- Understand the impact of regularization at the network level
- Track which layers are contributing to regularization
You can use this information to adjust regularization settings for better training results.
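A simple way to surface these diagnostics is to print each key/value pair; this sketch assumes a Transformer<double> instance named transformer.

```csharp
// Dump network-level attention regularization diagnostics.
foreach (var kvp in transformer.GetAuxiliaryLossDiagnostics())
{
    Console.WriteLine($"{kvp.Key}: {kvp.Value}");
}
```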
GetDiagnostics()
Gets diagnostic information about this component's state and behavior. Overrides GetDiagnostics() to include auxiliary loss diagnostics.
public Dictionary<string, string> GetDiagnostics()
Returns
- Dictionary<string, string>
A dictionary containing diagnostic metrics including both base layer diagnostics and auxiliary loss diagnostics from GetAuxiliaryLossDiagnostics().
GetModelMetadata()
Retrieves metadata about the Transformer model.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
A ModelMetadata<T> object containing information about the Transformer.
Remarks
This method collects and returns various pieces of information about the Transformer, including its type, architecture details, and current state.
For Beginners: This method creates a summary of the Transformer's current state and structure.
It's like creating a report card for the Transformer, including:
- What type of model it is (Transformer)
- How it's structured (number of layers, size of each layer, etc.)
- Its current training progress
- Other important details about its configuration
This information is useful for keeping track of different models, especially when you're experimenting with multiple configurations.
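Retrieving the metadata is a single call; this sketch assumes a Transformer<double> instance named transformer.

```csharp
// Fetch a summary of the model's type, structure, and state —
// useful when comparing multiple experiment configurations.
ModelMetadata<double> metadata = transformer.GetModelMetadata();
```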
InitializeLayers()
Sets up the layers of the Transformer network based on the provided architecture.
protected override void InitializeLayers()
Remarks
This method either uses custom layers provided by the user or creates default Transformer layers. A typical Transformer consists of attention mechanisms, normalization layers, and feed-forward networks.
For Beginners: This method builds the actual structure of the Transformer.
It works in one of two ways:
- If you've provided your own custom layers, it uses those
- Otherwise, it creates a standard set of Transformer layers
These layers typically include:
- Attention layers (which let the model focus on relevant parts of the input)
- Normalization layers (which keep the numbers from getting too large or small)
- Feed-forward layers (which process the information)
It's like assembling the rooms and sections of a house according to the blueprint.
Predict(Tensor<T>)
Performs a forward pass through the Transformer network to generate predictions.
public override Tensor<T> Predict(Tensor<T> input)
Parameters
input (Tensor<T>): The input tensor to process.
Returns
- Tensor<T>
The output tensor containing the predictions.
Remarks
This method passes the input through each layer of the Transformer sequentially. It handles both the encoder and decoder parts of the Transformer if present.
For Beginners: This method takes your input data and runs it through the entire Transformer.
It's like sending a message through a complex machine:
- The input goes through each part of the Transformer in order
- Each layer processes the data in its own way (attention, normalization, etc.)
- The final output is the Transformer's prediction or transformation of your input
This is used when you want to use a trained Transformer to process new data.
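Inference is a single call on a trained model. This sketch assumes a trained Transformer<double> named transformer and an input Tensor<double> already shaped for your sequence task; the tensor construction API itself is not shown here.

```csharp
// Run a forward pass on new data.
Tensor<double> output = transformer.Predict(input);
```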
SerializeNetworkSpecificData(BinaryWriter)
Serializes Transformer-specific data to a binary stream.
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
Parameters
writer (BinaryWriter): The BinaryWriter to write the data to.
Remarks
This method writes Transformer-specific configuration and state data to a binary stream. It allows the Transformer's current state to be saved and later reconstructed.
For Beginners: This method saves all the important details about the Transformer to a file.
It's like taking a snapshot of the Transformer's current state, including:
- Its configuration (how it's set up)
- Its current parameter values (what it has learned so far)
This allows you to save your trained Transformer and use it again later without having to retrain it.
SetAttentionMask(Tensor<T>)
Sets the attention mask for the Transformer.
public void SetAttentionMask(Tensor<T> mask)
Parameters
mask (Tensor<T>): The attention mask to be used in self-attention layers.
Remarks
Call this method before training or prediction to set a mask for controlling attention.
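A common use is a causal (lower-triangular) mask for sequence generation, so position i cannot attend to later positions j > i. The Tensor<double> constructor and indexer used here are assumptions for illustration; the actual Tensor<T> API may differ.

```csharp
int seqLen = 16;

// Assumed tensor API: shape-based constructor and 2D indexer.
var mask = new Tensor<double>(new[] { seqLen, seqLen });
for (int i = 0; i < seqLen; i++)
{
    for (int j = 0; j <= i; j++)
    {
        mask[i, j] = 1.0; // allow attention to current and past positions
    }
}

transformer.SetAttentionMask(mask);
```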
Train(Tensor<T>, Tensor<T>)
Trains the Transformer network on a single batch of data.
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)
Parameters
input (Tensor<T>): The input tensor for training.
expectedOutput (Tensor<T>): The expected output tensor.
Remarks
This method performs a forward pass, calculates the loss, and then backpropagates the error to update the network's parameters. It uses the specified loss function and optimizer.
For Beginners: This method teaches the Transformer using example data.
The process works like this:
- The Transformer makes a prediction based on the input
- We compare this prediction to the expected output
- We calculate how wrong the prediction was (the "loss")
- We adjust the Transformer's internal values to make it a little more accurate next time
This process is repeated many times with different examples to train the Transformer.
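Since Train processes one batch at a time, a training run is typically a loop over pre-batched data. This sketch assumes a Transformer<double> named transformer and parallel lists of input and target Tensor<double> batches.

```csharp
// One epoch over pre-batched training data.
for (int epoch = 0; epoch < 10; epoch++)
{
    for (int i = 0; i < inputs.Count; i++)
    {
        transformer.Train(inputs[i], targets[i]);
    }
}
```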
UpdateParameters(Vector<T>)
Updates the parameters of all layers in the Transformer network.
public override void UpdateParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): A vector containing all parameters for the network.
Remarks
This method distributes the parameters to each layer based on their parameter counts. It's typically used during training when applying gradient updates.
For Beginners: This method updates the Transformer's internal values during training.
Think of parameters as the "settings" of the Transformer:
- Each layer needs a certain number of parameters to function
- During training, these parameters are constantly adjusted to improve performance
- This method takes a big list of new parameter values and gives each layer its share
It's like distributing updated parts to each section of a machine so it works better. Each layer gets exactly the number of parameters it needs.
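The distribution step can be pictured as slicing one flat array by each layer's parameter count, in order. This is a library-independent conceptual sketch, not AiDotNet code; the counts and the SetParameters call are illustrative.

```csharp
// Conceptual sketch: hand each layer its contiguous slice of a flat vector.
int[] layerParamCounts = { 4, 6, 2 };        // hypothetical per-layer counts
double[] allParams = new double[4 + 6 + 2];  // the full parameter vector

int offset = 0;
foreach (int count in layerParamCounts)
{
    double[] slice = new double[count];
    Array.Copy(allParams, offset, slice, 0, count);
    offset += count;
    // In real code, the layer would receive its slice here,
    // e.g. layer.SetParameters(slice).
}
```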
ValidateCustomLayers(List<ILayer<T>>)
Ensures that custom layers provided for the Transformer meet the minimum requirements.
protected override void ValidateCustomLayers(List<ILayer<T>> layers)
Parameters
layers (List<ILayer<T>>): The custom layers to validate.
Remarks
A valid Transformer must include at least one attention layer and one normalization layer. Attention layers allow the model to focus on different parts of the input sequence. Normalization layers help stabilize training by normalizing the activations.
For Beginners: This method checks if your custom layers will actually work as a Transformer.
For a Transformer to function properly, it needs at minimum:
- An attention layer (which helps the model focus on important parts of the input)
- A normalization layer (which keeps the numbers stable during training)
If either of these is missing, it's like trying to build a house without walls or a foundation - it won't work!
This method checks for these essential components and raises an error if they're missing.
Exceptions
- InvalidOperationException
Thrown when the custom layers don't include required layer types.