Class MatryoshkaEmbedding<T>
- Namespace: AiDotNet.NeuralNetworks
- Assembly: AiDotNet.dll
Matryoshka Representation Learning (MRL) neural network implementation. Learns nested embeddings where smaller prefixes of the full vector are valid representations.
public class MatryoshkaEmbedding<T> : TransformerEmbeddingNetwork<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable, IEmbeddingModel<T>
Type Parameters
- T: The numeric type used for calculations (typically float or double).
- Inheritance: TransformerEmbeddingNetwork<T> → MatryoshkaEmbedding<T>
Remarks
Matryoshka Representation Learning (MRL) is a technique that lets a single model adapt its embedding dimension to the requirements of the downstream task. During training it optimizes for several nested dimensions simultaneously, so truncated prefixes of the vector retain much of the accuracy of the full vector.
For Beginners: Imagine a Russian nesting doll (a Matryoshka). Inside the big doll is a smaller one, and inside that is an even smaller one. MRL works the same way: it creates a long list of numbers to describe a sentence, but it makes sure that the first few numbers are a "perfect miniature" of the whole meaning. This lets you use a tiny list for a fast search and the full list when you need total accuracy.
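The training idea in the remarks can be sketched as a weighted sum of one base loss evaluated on several leading prefixes of the same embedding pair. This is a minimal, hypothetical C# sketch over plain double[] arrays, not AiDotNet's internal code: the class name MatryoshkaLossSketch, the squared-error-on-cosine base loss, and the per-dimension weights are assumptions, and the loss actually used by MatryoshkaEmbedding<T> (for example one supplied through lossFunction) may differ.

```csharp
using System;

// Hypothetical sketch, not AiDotNet's training code: a Matryoshka-style objective
// evaluates one base loss on several leading prefixes of the same embedding pair
// and adds the results, so every prefix is pushed to be a useful representation.
public static class MatryoshkaLossSketch
{
    // Cosine similarity computed over only the first `dim` elements.
    // Returns NaN if either prefix is all zeros.
    public static double PrefixCosine(double[] a, double[] b, int dim)
    {
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < dim; i++)
        {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return dot / (Math.Sqrt(normA) * Math.Sqrt(normB));
    }

    // Weighted sum over nested dimensions of an assumed base loss:
    // squared error between the prefix cosine similarity and a target score.
    public static double Loss(double[] first, double[] second, double targetSimilarity,
                              int[] nestedDimensions, double[] weights)
    {
        if (nestedDimensions.Length != weights.Length)
            throw new ArgumentException("Need one weight per nested dimension.");

        double total = 0;
        for (int k = 0; k < nestedDimensions.Length; k++)
        {
            double diff = PrefixCosine(first, second, nestedDimensions[k]) - targetSimilarity;
            total += weights[k] * diff * diff;
        }
        return total;
    }
}
```

For example, with first = {1, 0, 1, 0}, second = {1, 0, -1, 0}, nestedDimensions = {1, 4}, equal weights, and a target similarity of 1, the 1-element prefix already matches the target while the full 4-element vectors have cosine similarity 0, so the total loss is 1. Training against such a sum is what pressures the leading dimensions to carry the most useful information.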
Constructors
MatryoshkaEmbedding(NeuralNetworkArchitecture<T>, ITokenizer?, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, int, int, int[]?, int, int, int, int, PoolingStrategy, ILossFunction<T>?, double)
Initializes a new instance of the MatryoshkaEmbedding model.
public MatryoshkaEmbedding(NeuralNetworkArchitecture<T> architecture, ITokenizer? tokenizer = null, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? optimizer = null, int vocabSize = 30522, int maxEmbeddingDimension = 1536, int[]? nestedDimensions = null, int maxSequenceLength = 512, int numLayers = 12, int numHeads = 12, int feedForwardDim = 3072, TransformerEmbeddingNetwork<T>.PoolingStrategy poolingStrategy = PoolingStrategy.ClsToken, ILossFunction<T>? lossFunction = null, double maxGradNorm = 1)
Parameters
- architecture (NeuralNetworkArchitecture<T>)
- tokenizer (ITokenizer)
- optimizer (IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>)
- vocabSize (int)
- maxEmbeddingDimension (int)
- nestedDimensions (int[])
- maxSequenceLength (int)
- numLayers (int)
- numHeads (int)
- feedForwardDim (int)
- poolingStrategy (TransformerEmbeddingNetwork<T>.PoolingStrategy)
- lossFunction (ILossFunction<T>)
- maxGradNorm (double)
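The constructor takes nestedDimensions alongside maxEmbeddingDimension. Because each nested dimension names a prefix length of the full embedding, a value outside 1..maxEmbeddingDimension cannot be a valid prefix. The helper below is a hypothetical pre-flight check under that assumption; NestedDimensionCheck is not part of AiDotNet, and this page does not say whether the constructor validates, sorts, or deduplicates the array itself, or which defaults it uses when nestedDimensions is null.

```csharp
using System;
using System.Linq;

// Hypothetical helper (not part of AiDotNet): checks the assumed MRL constraint
// that every nested dimension is a valid prefix length of the full embedding,
// then returns the values sorted ascending with duplicates removed.
public static class NestedDimensionCheck
{
    public static int[] Normalize(int[] nestedDimensions, int maxEmbeddingDimension)
    {
        if (nestedDimensions == null || nestedDimensions.Length == 0)
            throw new ArgumentException("At least one nested dimension is required.");

        foreach (int d in nestedDimensions)
        {
            if (d < 1 || d > maxEmbeddingDimension)
                throw new ArgumentOutOfRangeException(nameof(nestedDimensions),
                    $"Dimension {d} must be between 1 and {maxEmbeddingDimension}.");
        }

        // Each prefix is contained in the next larger one, so list them smallest first.
        return nestedDimensions.Distinct().OrderBy(d => d).ToArray();
    }
}
```

For example, Normalize(new[] { 256, 64, 128, 64 }, 1536) returns { 64, 128, 256 }, while an entry such as 2048 throws ArgumentOutOfRangeException.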
Methods
CreateNewInstance()
Creates a new instance of the same type as this neural network.
protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
A new instance of the same neural network type.
Remarks
For Beginners: This creates a blank version of the same type of neural network.
It's used internally by methods like DeepCopy and Clone to create the right type of network before copying the data into it.
DeserializeNetworkSpecificData(BinaryReader)
Deserializes network-specific data that was not covered by the general deserialization process.
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
Parameters
- reader (BinaryReader): The BinaryReader to read the data from.
Remarks
This method is called at the end of the general deserialization process to allow derived classes to read any additional data specific to their implementation.
For Beginners: Using the same suitcase analogy as SerializeNetworkSpecificData, this is like unpacking that special compartment. After the main deserialization method has unpacked the common items (layers, parameters), this method lets each specific type of neural network unpack the unique items it stored during serialization.
Embed(string)
Encodes a single string into a normalized summary vector.
public override Vector<T> Embed(string text)
Parameters
- text (string): The text to encode.
Returns
- Vector<T>
A normalized embedding vector.
Remarks
For Beginners: This is the main use case. You give the model a sentence, it reads it with all its layers, summarizes the meaning based on your chosen pooling strategy (like taking the average meaning), and returns one final list of numbers.
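The pooling-then-normalize step described above can be sketched over plain arrays. This is a hypothetical illustration, not AiDotNet's implementation: it uses double[][] in place of Tensor<T>/Vector<T>, assumes tokenStates[t][j] is the encoder output for token t and feature j, and does not mask padding tokens. ClsPool mirrors the constructor's default PoolingStrategy.ClsToken; MeanPool corresponds to the "average meaning" mentioned in the remarks.

```csharp
using System;

// Hypothetical sketch of two pooling options followed by L2 normalization,
// using plain arrays instead of AiDotNet's tensor and vector types.
public static class PoolingSketch
{
    // CLS-style pooling: take the first token's vector as the summary.
    public static double[] ClsPool(double[][] tokenStates) => (double[])tokenStates[0].Clone();

    // Mean pooling: average each feature across all tokens (no padding mask).
    public static double[] MeanPool(double[][] tokenStates)
    {
        int dim = tokenStates[0].Length;
        var pooled = new double[dim];
        foreach (var token in tokenStates)
            for (int j = 0; j < dim; j++)
                pooled[j] += token[j];
        for (int j = 0; j < dim; j++)
            pooled[j] /= tokenStates.Length;
        return pooled;
    }

    // Scale to unit length so a dot product between two outputs equals their cosine similarity.
    public static double[] L2Normalize(double[] v)
    {
        double sumSquares = 0;
        foreach (double x in v) sumSquares += x * x;
        double norm = Math.Sqrt(sumSquares);
        var result = new double[v.Length];
        for (int i = 0; i < v.Length; i++)
            result[i] = norm > 0 ? v[i] / norm : 0.0;
        return result;
    }
}
```

For two token vectors {1, 2} and {3, 4}, MeanPool gives {2, 3}, ClsPool gives {1, 2}, and L2Normalize rescales {2, 3} to unit length.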
EmbedAsync(string)
Asynchronously embeds a single text string into a vector representation.
public override Task<Vector<T>> EmbedAsync(string text)
Parameters
- text (string): The text to embed.
Returns
- Task<Vector<T>>
A task representing the async operation, with the resulting vector.
EmbedBatchAsync(IEnumerable<string>)
Asynchronously embeds multiple text strings into vector representations in a single batch operation.
public override Task<Matrix<T>> EmbedBatchAsync(IEnumerable<string> texts)
Parameters
- texts (IEnumerable<string>): The collection of texts to embed.
Returns
- Task<Matrix<T>>
A task representing the async operation, with the resulting matrix.
EmbedResized(string, int)
Encodes text into a truncated and re-normalized embedding of the requested dimension.
public Vector<T> EmbedResized(string text, int dimension)
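Parameters
- text (string): The text to encode.
- dimension (int): The number of leading embedding elements to keep.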
Returns
- Vector<T>
A normalized vector containing only the first 'dimension' elements.
Remarks
For Beginners: This is like picking which doll you want to use. If you want a lightning-fast search, you might take only the first 64 numbers. This method keeps the first dimension numbers of the full list and rescales them to unit length, so similarity scores computed with the shorter vector stay comparable with one another.
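The truncate-then-renormalize step described for EmbedResized can be sketched as follows. This is a hypothetical illustration over a plain double[] rather than AiDotNet's Vector<T>; ResizeSketch and TruncateAndNormalize are illustrative names, and the range check on dimension is an assumption about how out-of-range sizes should be handled.

```csharp
using System;

// Hypothetical sketch of truncating an embedding to its leading elements
// and re-normalizing the prefix to unit length.
public static class ResizeSketch
{
    public static double[] TruncateAndNormalize(double[] fullEmbedding, int dimension)
    {
        if (dimension < 1 || dimension > fullEmbedding.Length)
            throw new ArgumentOutOfRangeException(nameof(dimension));

        // Keep only the leading `dimension` values (the "smaller doll").
        var prefix = new double[dimension];
        Array.Copy(fullEmbedding, prefix, dimension);

        // A prefix of a unit vector has length at most 1 and usually less,
        // so rescale it before using dot products as cosine similarities.
        double norm = 0;
        foreach (double x in prefix) norm += x * x;
        norm = Math.Sqrt(norm);
        if (norm == 0) return prefix;

        for (int i = 0; i < dimension; i++)
            prefix[i] /= norm;
        return prefix;
    }
}
```

For example, truncating {3, 4, 12} (length 13) to 2 dimensions keeps {3, 4} and rescales it to {0.6, 0.8}.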
GetModelMetadata()
Retrieves metadata about the Matryoshka configuration.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
Metadata containing model type and nested dimension information.
InitializeLayers()
Configures the transformer encoder and projection layers for the MRL architecture.
protected override void InitializeLayers()
Remarks
For Beginners: This method builds the model's internal structure: the transformer encoder that reads the text and the projection layers that produce the final embedding. The layers do not sort anything by themselves; training with the Matryoshka objective is what teaches the model to place the most important information at the beginning of its numerical output.
SerializeNetworkSpecificData(BinaryWriter)
Serializes network-specific data that is not covered by the general serialization process.
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
Parameters
- writer (BinaryWriter): The BinaryWriter to write the data to.
Remarks
This method is called at the end of the general serialization process to allow derived classes to write any additional data specific to their implementation.
For Beginners: Think of this as packing a special compartment in your suitcase. While the main serialization method packs the common items (layers, parameters), this method allows each specific type of neural network to pack its own unique items that other networks might not have.
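The serialize and deserialize overrides must agree on the order and types of whatever they store. The sketch below shows that pairing with a hypothetical layout: the maximum dimension followed by a length-prefixed list of nested dimensions. MrlSerializationSketch and this layout are assumptions for illustration only; the data MatryoshkaEmbedding<T> actually writes is not documented on this page.

```csharp
using System;
using System.IO;

// Hypothetical layout for MRL-specific data. The point is that Read mirrors
// Write field by field, in the same order and with the same types.
public static class MrlSerializationSketch
{
    public static void Write(BinaryWriter writer, int maxEmbeddingDimension, int[] nestedDimensions)
    {
        writer.Write(maxEmbeddingDimension);
        writer.Write(nestedDimensions.Length); // count first, so the reader knows how many follow
        foreach (int d in nestedDimensions)
            writer.Write(d);
    }

    public static (int MaxDimension, int[] NestedDimensions) Read(BinaryReader reader)
    {
        int maxDimension = reader.ReadInt32();
        int count = reader.ReadInt32();
        var nested = new int[count];
        for (int i = 0; i < count; i++)
            nested[i] = reader.ReadInt32();
        return (maxDimension, nested);
    }
}
```

Writing the element count before the list lets the reader allocate the array without needing any separator or end marker.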