Class Word2Vec<T>
- Namespace
- AiDotNet.NeuralNetworks
- Assembly
- AiDotNet.dll
Word2Vec neural network implementation supporting both Skip-Gram and CBOW architectures.
public class Word2Vec<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable, IEmbeddingModel<T>
Type Parameters
T: The numeric type used for calculations (typically float or double).
- Inheritance
- NeuralNetworkBase<T> → Word2Vec<T>
Remarks
Word2Vec is a foundational technique in Natural Language Processing (NLP) that learns to map words to dense vectors of real numbers. These "embeddings" capture semantic and syntactic relationships based on the contexts in which words appear.
For Beginners: Imagine you are learning a new language by looking at thousands of newspapers. You notice that the word "bark" often appears near "dog," "tree," and "loud." You also notice that "meow" appears near "cat," "kitten," and "soft." Word2Vec is an AI that does exactly this—it builds a "map" of words where words with similar meanings or contexts are placed close together. In this map, "dog" and "cat" might be neighbors, while "dog" and "spaceship" are on opposite ends.
This implementation supports two main styles:
- Skip-Gram: Tries to guess the surrounding "context" words when given a single target word.
- CBOW: Tries to guess a single "target" word when given a group of surrounding context words.
Constructors
Word2Vec(NeuralNetworkArchitecture<T>, ITokenizer?, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, int, int, int, int, Word2VecType, ILossFunction<T>?, double)
Initializes a new instance of the Word2Vec model.
public Word2Vec(NeuralNetworkArchitecture<T> architecture, ITokenizer? tokenizer = null, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? optimizer = null, int vocabSize = 10000, int embeddingDimension = 100, int windowSize = 5, int maxTokens = 512, Word2VecType type = Word2VecType.SkipGram, ILossFunction<T>? lossFunction = null, double maxGradNorm = 1)
Parameters
architecture (NeuralNetworkArchitecture<T>): The architecture configuration defining the neural network's metadata.
tokenizer (ITokenizer?): Optional tokenizer for text processing. Defaults to a standard BERT-style tokenizer.
optimizer (IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?): Optional optimizer for training. Defaults to the Adam optimizer.
vocabSize (int): The size of the vocabulary (default: 10000).
embeddingDimension (int): The dimension of the embedding vectors (default: 100).
windowSize (int): The context window size (default: 5).
maxTokens (int): The maximum number of tokens to process per input (default: 512).
type (Word2VecType): The Word2Vec architecture type (default: SkipGram).
lossFunction (ILossFunction<T>?): Optional loss function. Defaults to binary cross-entropy.
maxGradNorm (double): Maximum gradient norm for stability (default: 1.0).
Remarks
For Beginners: This sets up the model's brain. You can decide how many words it should know (vocabSize), how detailed its "mental map" should be (embeddingDimension), and which learning strategy it should use (type).
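Example
A minimal construction sketch. The NeuralNetworkArchitecture<T> setup is a placeholder (its constructor arguments are not documented on this page) and the variable names are illustrative; the Word2Vec<T> arguments follow the signature above.
// NOTE: the NeuralNetworkArchitecture<double> construction below is a placeholder;
// consult its own documentation for the required arguments.
var architecture = new NeuralNetworkArchitecture<double>();

// Skip-Gram model with a 50,000-word vocabulary and 300-dimensional embeddings.
// The remaining settings keep their documented defaults (Adam optimizer,
// binary cross-entropy loss, window size 5, 512 max tokens).
var word2vec = new Word2Vec<double>(
    architecture,
    vocabSize: 50000,
    embeddingDimension: 300,
    type: Word2VecType.SkipGram);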
Properties
EmbeddingDimension
Gets the dimensionality of the embedding vectors produced by this model.
public int EmbeddingDimension { get; }
Property Value
- int
Remarks
The embedding dimension determines the size of the vector representation. Common dimensions range from 128 to 1536, with larger dimensions typically capturing more nuanced semantic relationships at the cost of memory and computation.
For Beginners: This is how many numbers represent each piece of text.
Think of it like describing a person:
- Low dimension (128): Basic traits like height, weight, age
- High dimension (768): Detailed description including personality, preferences, habits
- Very high dimension (1536): Extremely detailed profile
More dimensions = more detailed understanding, but also more storage space needed.
MaxTokens
Gets the maximum length of text (in tokens) that this model can process.
public int MaxTokens { get; }
Property Value
- int
Remarks
Most embedding models have a maximum context length beyond which text must be truncated. Common limits range from 512 to 8192 tokens. Implementations should handle text that exceeds this limit gracefully, either by truncating it or by raising an exception.
For Beginners: This is the maximum amount of text the model can understand at once.
Think of it like a reader's attention span:
- Short span (512 tokens): Can read about a paragraph
- Medium span (2048 tokens): Can read a few pages
- Long span (8192 tokens): Can read a short chapter
If your text is longer, it needs to be split into chunks. (A token is roughly a word, so 512 tokens ≈ 1-2 paragraphs)
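Example
Continuing the hypothetical word2vec instance from the constructor example above, both properties simply report the configured sizes.
// Reflects the constructor arguments: 300 embedding dimensions,
// and the default maximum of 512 tokens per input.
Console.WriteLine(word2vec.EmbeddingDimension); // 300
Console.WriteLine(word2vec.MaxTokens);          // 512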
Methods
Backward(Tensor<T>)
Performs a backward pass through the network to calculate gradients for training.
public Tensor<T> Backward(Tensor<T> outputGradient)
Parameters
outputGradient (Tensor<T>): The error gradient from the loss function.
Returns
- Tensor<T>
The gradient calculated for the input.
Remarks
For Beginners: After the model makes a guess, we tell it how wrong it was. This method takes that "wrongness" and works backward through the model's brain to figure out exactly which neurons need to be adjusted to make a better guess next time.
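Example
A rough sketch of one manual forward/backward step. The Tensor<T> construction, the Shape property, and the all-zero gradient are assumptions used only to show the call order; in real training the output gradient comes from the loss function, and Train handles this whole cycle for you.
// ASSUMPTION: this shape-based Tensor<double> constructor and the Shape property
// are illustrative; the actual AiDotNet Tensor<T> API may differ.
var tokenIds = new Tensor<double>(new[] { 1, 4 }); // one sample of four token IDs

// Forward pass: produce the model's output for these tokens.
var output = word2vec.Forward(tokenIds);

// ASSUMPTION: a real gradient would come from the loss function comparing
// `output` against the expected context tokens; zeros are used as a stand-in.
var outputGradient = new Tensor<double>(output.Shape);

// Backward pass: propagate the error back through the layers and obtain
// the gradient with respect to the input.
var inputGradient = word2vec.Backward(outputGradient);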
CreateNewInstance()
Creates a new instance of the same type as this neural network.
protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
A new instance of the same neural network type.
Remarks
For Beginners: This creates a blank version of the same type of neural network.
It's used internally by methods like DeepCopy and Clone to create the right type of network before copying the data into it.
DeserializeNetworkSpecificData(BinaryReader)
Deserializes network-specific data that was not covered by the general deserialization process.
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
Parameters
reader (BinaryReader): The BinaryReader to read the data from.
Remarks
This method is called at the end of the general deserialization process to allow derived classes to read any additional data specific to their implementation.
For Beginners: Continuing the suitcase analogy from SerializeNetworkSpecificData, this is like unpacking that special compartment. After the main deserialization method has unpacked the common items (layers, parameters), this method lets each specific type of neural network unpack its own unique items that were stored during serialization.
Embed(string)
Encodes a single string into a normalized embedding vector by averaging its word vectors.
public Vector<T> Embed(string text)
Parameters
text (string): The text to encode.
Returns
- Vector<T>
The normalized embedding vector for the input text.
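Example
A minimal usage sketch, assuming a trained instance named word2vec (the name and input text are illustrative).
// Average the word vectors of the text into a single normalized embedding.
Vector<double> embedding = word2vec.Embed("the dog barked loudly");
// The vector has EmbeddingDimension entries (300 in the constructor example above).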
EmbedAsync(string)
Asynchronously embeds a single text string into a vector representation.
public Task<Vector<T>> EmbedAsync(string text)
Parameters
text (string): The text to embed.
Returns
- Task<Vector<T>>
A task representing the async operation, with the resulting vector.
EmbedBatch(IEnumerable<string>)
Encodes a collection of strings into a matrix of embedding vectors.
public Matrix<T> EmbedBatch(IEnumerable<string> texts)
Parameters
texts (IEnumerable<string>): The collection of texts to encode.
Returns
- Matrix<T>
A matrix where each row is the embedding for one input text.
Remarks
For Beginners: This is a faster way to process a whole list of sentences at once. It gives you a "list of lists," where each sentence gets its own summary vector.
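Example
A minimal batch sketch using the same hypothetical word2vec instance.
// Each row of the returned matrix is the embedding of the corresponding text.
var texts = new[] { "the dog barked", "the cat meowed", "the rocket launched" };
Matrix<double> embeddings = word2vec.EmbedBatch(texts);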
EmbedBatchAsync(IEnumerable<string>)
Asynchronously embeds multiple text strings into vector representations in a single batch operation.
public Task<Matrix<T>> EmbedBatchAsync(IEnumerable<string> texts)
Parameters
texts (IEnumerable<string>): The collection of texts to embed.
Returns
- Task<Matrix<T>>
A task representing the async operation, with the resulting matrix.
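Example
The asynchronous variants return the same results as Embed and EmbedBatch, wrapped in a Task; a minimal sketch, to be awaited from an async context.
// Await a single embedding and a batch of embeddings without blocking the caller.
Vector<double> single = await word2vec.EmbedAsync("the dog barked");
Matrix<double> batch = await word2vec.EmbedBatchAsync(new[] { "dog", "cat" });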
Forward(Tensor<T>)
Performs a forward pass through the network, typically to retrieve an embedding for given token IDs.
public Tensor<T> Forward(Tensor<T> input)
Parameters
input (Tensor<T>): A tensor containing token indices.
Returns
- Tensor<T>
A tensor containing the resulting embeddings.
Remarks
For Beginners: This is the process of looking up a word's "address" on the model's mental map. You give it the word's ID number, and it returns the list of coordinates (embedding) for that word.
GetModelMetadata()
Retrieves detailed metadata about the Word2Vec model.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
A ModelMetadata object containing the model's configuration and complexity.
Remarks
For Beginners: This is like a "technical spec sheet" for the model. It tells you exactly how many words it knows, how complex its map is, and what strategy it used to learn.
InitializeLayers()
Initializes the layers of the Word2Vec network based on the provided architecture or standard research defaults.
protected override void InitializeLayers()
Remarks
This method sets up the "U" and "V" matrices described in the original Word2Vec paper. The "U" matrix (the first layer) acts as the input lookup table, while the "V" matrix (the second layer) acts as the context prediction head.
For Beginners: This method builds the actual structure of the model. It either uses a custom setup you provide or, by default, creates a two-step process: first, look up the numbers for a word, and second, use those numbers to try and guess other words.
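For reference, in the standard Skip-Gram formulation these two matrices define the probability of a context word appearing near a target word (this is the general formulation from the Word2Vec literature, not a detail specific to this class):
$$
P(c \mid w) = \frac{\exp\left(\mathbf{v}_c^{\top}\mathbf{u}_w\right)}{\sum_{c' \in \mathcal{V}} \exp\left(\mathbf{v}_{c'}^{\top}\mathbf{u}_w\right)}
$$
where $\mathbf{u}_w$ is the row of the "U" (input) matrix for the target word $w$, $\mathbf{v}_c$ is the row of the "V" (context) matrix for a candidate context word $c$, and $\mathcal{V}$ is the vocabulary.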
Predict(Tensor<T>)
Makes a prediction using the neural network.
public override Tensor<T> Predict(Tensor<T> input)
Parameters
input (Tensor<T>): The input data to process.
Returns
- Tensor<T>
The network's prediction.
Remarks
For Beginners: This is the main method you'll use to get results from your trained neural network. You provide input data (for Word2Vec, a tensor of token IDs), and the network processes it through all its layers to produce its prediction.
SerializeNetworkSpecificData(BinaryWriter)
Serializes network-specific data that is not covered by the general serialization process.
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
Parameters
writer (BinaryWriter): The BinaryWriter to write the data to.
Remarks
This method is called at the end of the general serialization process to allow derived classes to write any additional data specific to their implementation.
For Beginners: Think of this as packing a special compartment in your suitcase. While the main serialization method packs the common items (layers, parameters), this method allows each specific type of neural network to pack its own unique items that other networks might not have.
Train(Tensor<T>, Tensor<T>)
Trains the Word2Vec model on a single batch of target and context pairs.
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)
Parameters
input (Tensor<T>): The input tokens (targets or context, depending on the architecture).
expectedOutput (Tensor<T>): The expected tokens the model should predict.
Remarks
For Beginners: This is how the model learns. You show it a "puzzle"—a word and its correct neighbor—and the model adjusts its internal map so that these two words are placed closer together in its memory.
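Example
A rough sketch of a single training call. The Tensor<T> construction is an assumption used only to show the shape of the call; how you build the target/context pairs from a corpus is up to your data pipeline.
// ASSUMPTION: this shape-based Tensor<double> constructor is illustrative;
// the actual AiDotNet Tensor<T> API may differ.
// For Skip-Gram, `targets` holds target-word token IDs and `contexts` holds
// the context-word IDs the model should learn to predict for each target.
var targets = new Tensor<double>(new[] { 4 });  // four target token IDs
var contexts = new Tensor<double>(new[] { 4 }); // four matching context token IDs

// One training step: adjusts the embeddings so each target/context pair
// ends up closer together on the model's internal map.
word2vec.Train(targets, contexts);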
UpdateParameters(Vector<T>)
Updates the parameters of all layers in the network based on a provided update vector.
public override void UpdateParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): A vector containing the updated weights and biases.
Remarks
For Beginners: After we figure out how to improve (the backward pass), this method actually changes the model's settings. It's like turning the knobs on a machine to fine-tune its performance.