Class TransformerEmbeddingNetwork<T>
- Namespace
- AiDotNet.NeuralNetworks
- Assembly
- AiDotNet.dll
A customizable Transformer-based embedding network. This serves as the high-performance foundation for modern sentence and document encoders.
public class TransformerEmbeddingNetwork<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable, IEmbeddingModel<T>
Type Parameters
T
The numeric type used for calculations (typically float or double).
- Inheritance
- object → NeuralNetworkBase<T> → TransformerEmbeddingNetwork<T>
Remarks
This network provides a flexible implementation of the Transformer encoder architecture, enabling the generation of high-quality semantic embeddings. It supports multiple pooling strategies (Mean, Max, ClsToken) to aggregate token-level information.
For Beginners: This is a "universal reading brain." Transformers are the most powerful type of AI for understanding language because they can look at every word in a sentence at the same time and see how they all relate. This customizable version lets you decide how many layers of thinking the brain should have, and how it should summarize its thoughts into a final list of numbers (the embedding).
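Example
A minimal usage sketch (hedged): the NeuralNetworkArchitecture<float> instance is assumed to be configured elsewhere, and everything else uses the documented defaults.
using AiDotNet.NeuralNetworks;

// Sketch: given an already-configured architecture, create an encoder with
// the BERT-base style defaults (12 layers, 768 dimensions, mean pooling).
static void Demo(NeuralNetworkArchitecture<float> architecture)
{
    var model = new TransformerEmbeddingNetwork<float>(architecture);

    // One call turns a sentence into a normalized list of numbers.
    var embedding = model.Embed("Transformers read every word at once.");
    Console.WriteLine($"Got {model.EmbeddingDimension} numbers per sentence.");
}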
Constructors
TransformerEmbeddingNetwork(NeuralNetworkArchitecture<T>, ITokenizer?, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, int, int, int, int, int, int, PoolingStrategy, ILossFunction<T>?, double)
Initializes a new instance of the TransformerEmbeddingNetwork.
public TransformerEmbeddingNetwork(NeuralNetworkArchitecture<T> architecture, ITokenizer? tokenizer = null, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? optimizer = null, int vocabSize = 30522, int embeddingDimension = 768, int maxSequenceLength = 512, int numLayers = 12, int numHeads = 12, int feedForwardDim = 3072, TransformerEmbeddingNetwork<T>.PoolingStrategy poolingStrategy = PoolingStrategy.Mean, ILossFunction<T>? lossFunction = null, double maxGradNorm = 1)
Parameters
- architecture (NeuralNetworkArchitecture<T>): The architecture metadata configuration.
- tokenizer (ITokenizer?): Optional tokenizer for text processing.
- optimizer (IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?): Optional optimizer for training.
- vocabSize (int): The size of the vocabulary (default: 30522).
- embeddingDimension (int): The size of the output vectors (default: 768).
- maxSequenceLength (int): The maximum input sequence length (default: 512).
- numLayers (int): The number of transformer layers (default: 12).
- numHeads (int): The number of attention heads (default: 12).
- feedForwardDim (int): The feed-forward hidden dimension (default: 3072).
- poolingStrategy (TransformerEmbeddingNetwork<T>.PoolingStrategy): The pooling method (default: Mean).
- lossFunction (ILossFunction<T>?): Optional loss function.
- maxGradNorm (double): Maximum gradient norm for stability (default: 1.0).
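Example
A hedged construction sketch using only the documented parameters above; the architecture instance is assumed to be configured elsewhere.
// Sketch: a lighter six-layer encoder with 384-dimensional embeddings and
// max pooling. embeddingDimension should divide evenly by numHeads.
var small = new TransformerEmbeddingNetwork<float>(
    architecture,
    embeddingDimension: 384,
    maxSequenceLength: 256,
    numLayers: 6,
    numHeads: 6,          // 384 / 6 = 64 values per attention head
    feedForwardDim: 1536,
    poolingStrategy: TransformerEmbeddingNetwork<float>.PoolingStrategy.Max);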
Properties
EmbeddingDimension
Gets the dimensionality of the embedding vectors produced by this model.
public int EmbeddingDimension { get; }
Property Value
- int
Remarks
The embedding dimension determines the size of the vector representation. Common dimensions range from 128 to 1536, with larger dimensions typically capturing more nuanced semantic relationships at the cost of memory and computation.
For Beginners: This is how many numbers represent each piece of text.
Think of it like describing a person:
- Low dimension (128): Basic traits like height, weight, age
- High dimension (768): Detailed description including personality, preferences, habits
- Very high dimension (1536): Extremely detailed profile
More dimensions = more detailed understanding, but also more storage space needed.
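Example
A back-of-the-envelope for the storage trade-off; this is plain arithmetic with no library calls assumed.
// Storing 1,000,000 float embeddings: bytes = vectors * dimensions * 4.
long vectors = 1_000_000;
foreach (int dim in new[] { 128, 768, 1536 })
{
    long bytes = vectors * dim * sizeof(float);
    Console.WriteLine($"{dim,5} dims -> {bytes / (1024.0 * 1024 * 1024):F2} GiB");
}
// Prints roughly: 128 -> 0.48 GiB, 768 -> 2.86 GiB, 1536 -> 5.72 GiB.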
MaxTokens
Gets the maximum length of text (in tokens) that this model can process.
public int MaxTokens { get; }
Property Value
- int
Remarks
Most embedding models have a maximum context length beyond which text must be truncated. Common limits range from 512 to 8192 tokens. Implementations should handle text exceeding this limit gracefully, either by truncation or raising an exception.
For Beginners: This is the maximum amount of text the model can understand at once.
Think of it like a reader's attention span:
- Short span (512 tokens): Can read about a paragraph
- Medium span (2048 tokens): Can read a few pages
- Long span (8192 tokens): Can read a short chapter
If your text is longer, it needs to be split into chunks, as the sketch below shows. (A token is roughly a word, so 512 tokens ≈ 1-2 paragraphs.)
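Example
A hedged sketch of the chunk-and-embed workaround for long documents. Splitting on whitespace is only a rough stand-in for the model's real tokenizer, so a generous safety margin is used; longDocument and model are assumed to exist.
using System.Linq;

// Split an over-long document into word-count chunks that fit under
// MaxTokens, then embed all chunks in one batch.
static IEnumerable<string> Chunk(string text, int maxWords)
{
    var words = text.Split(' ', StringSplitOptions.RemoveEmptyEntries);
    for (int i = 0; i < words.Length; i += maxWords)
        yield return string.Join(' ', words.Skip(i).Take(maxWords));
}

// Use about half of MaxTokens per chunk to absorb subword splits; combine
// the per-chunk rows downstream (for example by averaging them).
var chunks = Chunk(longDocument, model.MaxTokens / 2);
var perChunk = model.EmbedBatch(chunks);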
Methods
CreateNewInstance()
Creates a new instance of the same type as this neural network.
protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
A new instance of the same neural network type.
Remarks
For Beginners: This creates a blank version of the same type of neural network.
It's used internally by methods like DeepCopy and Clone to create the right type of network before copying the data into it.
DeserializeNetworkSpecificData(BinaryReader)
Deserializes network-specific data that was not covered by the general deserialization process.
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
Parameters
- reader (BinaryReader): The BinaryReader to read the data from.
Remarks
This method is called at the end of the general deserialization process to allow derived classes to read any additional data specific to their implementation.
For Beginners: Think of deserialization as unpacking a suitcase (the mirror image of the packing analogy under SerializeNetworkSpecificData). After the main deserialization method has unpacked the common items (layers, parameters), this method lets each specific type of neural network unpack its own unique items that were stored during serialization.
Embed(string)
Encodes a single string into a normalized summary vector.
public virtual Vector<T> Embed(string text)
Parameters
- text (string): The text to encode.
Returns
- Vector<T>
A normalized embedding vector.
Remarks
For Beginners: This is the main use case. You give the model a sentence; it reads it with all its layers, summarizes the meaning using your chosen pooling strategy (such as taking the average meaning), and returns one final list of numbers.
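Example
Because Embed returns normalized vectors, cosine similarity reduces to a plain dot product. This sketch assumes Vector<T> supports [i] indexing.
var a = model.Embed("How do I reset my password?");
var b = model.Embed("I forgot my login credentials.");

// Embeddings are normalized, so the dot product is the cosine similarity.
float similarity = 0f;
for (int i = 0; i < model.EmbeddingDimension; i++)
    similarity += a[i] * b[i];

Console.WriteLine($"Cosine similarity: {similarity:F3}"); // nearer 1 = more similar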
EmbedAsync(string)
Asynchronously embeds a single text string into a vector representation.
public virtual Task<Vector<T>> EmbedAsync(string text)
Parameters
- text (string): The text to embed.
Returns
- Task<Vector<T>>
A task representing the async operation, with the resulting vector.
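Example
A short sketch of non-blocking use, for instance inside a web request handler; model is assumed to exist.
// Await keeps the calling thread free while the embedding is computed.
var vector = await model.EmbedAsync("non-blocking embedding");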
EmbedBatch(IEnumerable<string>)
Encodes a collection of strings into a matrix of embeddings.
public Matrix<T> EmbedBatch(IEnumerable<string> texts)
Parameters
- texts (IEnumerable<string>): The texts to encode.
Returns
- Matrix<T>
A matrix where each row is an embedding for the corresponding input string.
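Example
Batching amortizes per-call overhead. This sketch assumes Matrix<T> supports [row, col] indexing; row i always corresponds to the input at index i.
var docs = new[] { "first document", "second document", "third document" };
var matrix = model.EmbedBatch(docs);

// Row i of the matrix is the embedding of docs[i].
for (int row = 0; row < docs.Length; row++)
    Console.WriteLine($"docs[{row}] -> first value {matrix[row, 0]}");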
EmbedBatchAsync(IEnumerable<string>)
Asynchronously embeds multiple text strings into vector representations in a single batch operation.
public virtual Task<Matrix<T>> EmbedBatchAsync(IEnumerable<string> texts)
Parameters
- texts (IEnumerable<string>): The collection of texts to embed.
Returns
- Task<Matrix<T>>
A task representing the async operation, with the resulting matrix.
GetModelMetadata()
Returns metadata about the transformer network configuration.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
Metadata describing the transformer network configuration.
InitializeLayers()
Sets up the layer stack for the transformer network, including embedding, positional encoding, and transformer blocks.
protected override void InitializeLayers()
Remarks
For Beginners: This method builds the brain's internal architecture. It sets up the "translator" (embedding layer) to understand IDs, the "clock" (positional encoding) to understand word order, and multiple "thinking centers" (transformer encoder layers) to process complex context.
PoolOutput(Tensor<T>)
Applies the configured pooling strategy to convert token-level outputs into a sentence representation.
protected virtual Vector<T> PoolOutput(Tensor<T> output)
Parameters
- output (Tensor<T>): The 3D tensor of token representations [batch, seq, dim].
Returns
- Vector<T>
A single pooled vector for each sentence.
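Example
To make the Mean strategy concrete, here is what mean pooling looks like on plain arrays for a single sentence. This is a conceptual sketch, not the library's internal implementation.
// Average the seq token vectors of one sentence down to one dim-length vector.
static float[] MeanPool(float[,] tokens) // tokens: [seq, dim]
{
    int seq = tokens.GetLength(0), dim = tokens.GetLength(1);
    var pooled = new float[dim];
    for (int t = 0; t < seq; t++)
        for (int d = 0; d < dim; d++)
            pooled[d] += tokens[t, d] / seq;
    return pooled;
}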
Predict(Tensor<T>)
Makes a prediction using the neural network.
public override Tensor<T> Predict(Tensor<T> input)
Parameters
- input (Tensor<T>): The input data to process.
Returns
- Tensor<T>
The network's prediction.
Remarks
For Beginners: This is the main method you'll use to get results from your trained neural network. You provide some input data (like an image or text), and the network processes it through all its layers to produce an output (like a classification or prediction).
SerializeNetworkSpecificData(BinaryWriter)
Serializes network-specific data that is not covered by the general serialization process.
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
Parameters
- writer (BinaryWriter): The BinaryWriter to write the data to.
Remarks
This method is called at the end of the general serialization process to allow derived classes to write any additional data specific to their implementation.
For Beginners: Think of this as packing a special compartment in your suitcase. While the main serialization method packs the common items (layers, parameters), this method allows each specific type of neural network to pack its own unique items that other networks might not have.
Train(Tensor<T>, Tensor<T>)
Trains the transformer model on a single batch of data.
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)
Parameters
- input (Tensor<T>): The input batch to train on.
- expectedOutput (Tensor<T>): The target output for the batch.
UpdateParameters(Vector<T>)
Updates the network's parameters with new values.
public override void UpdateParameters(Vector<T> parameters)
Parameters
- parameters (Vector<T>): The new parameter values to set.
Remarks
For Beginners: During training, a neural network's internal values (parameters) get adjusted to improve its performance. This method allows you to update all those values at once by providing a complete set of new parameters.
This is typically used by optimization algorithms that calculate better parameter values based on training data.
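Example
A hedged sketch of manual parameter manipulation. GetParameters() and the Length/indexer members on Vector<T> are assumed here from the IParameterizable<T, ...> contract; verify the actual member names before relying on them.
// Read the current parameters, apply a toy decay, and write them back.
var parameters = model.GetParameters(); // assumed accessor
for (int i = 0; i < parameters.Length; i++)
    parameters[i] *= 0.999f; // simple illustrative weight decay
model.UpdateParameters(parameters);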