Class FlamingoNeuralNetwork<T>
- Namespace: AiDotNet.NeuralNetworks
- Assembly: AiDotNet.dll
Flamingo neural network for in-context visual learning and few-shot tasks.
public class FlamingoNeuralNetwork<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable, IFlamingoModel<T>, IMultimodalEmbedding<T>
Type Parameters
- T: The numeric type used for calculations.
- Inheritance: NeuralNetworkBase<T> → FlamingoNeuralNetwork<T>
Remarks
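Flamingo is a vision-language model designed for few-shot learning: it can pick up a new task from a handful of interleaved image-text examples supplied in the prompt. It uses a Perceiver Resampler to compress visual features into a fixed number of tokens and gated cross-attention layers to inject that visual information into a frozen language model.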
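In the original Flamingo design, a gated cross-attention layer adds its output to the language model's hidden state through a tanh gate whose scalar parameter starts at zero, so the frozen language model behaves exactly as before at initialization. The sketch below shows only that residual-gating idea with plain arrays and a single attention head. It is an illustration under simplifying assumptions (no learned query/key/value projections, no feed-forward sublayer), not AiDotNet's implementation.

```csharp
using System;

public static class GatedCrossAttentionSketch
{
    // Single-head scaled dot-product attention: each text position attends over the visual tokens.
    // Simplification: raw vectors are used directly instead of learned Q/K/V projections.
    public static double[][] CrossAttend(double[][] textHidden, double[][] visualTokens)
    {
        int d = textHidden[0].Length;
        double scale = 1.0 / Math.Sqrt(d);
        var output = new double[textHidden.Length][];
        for (int i = 0; i < textHidden.Length; i++)
        {
            // Attention logits between text position i and every visual token.
            var logits = new double[visualTokens.Length];
            double max = double.NegativeInfinity;
            for (int j = 0; j < visualTokens.Length; j++)
            {
                double dot = 0;
                for (int k = 0; k < d; k++) dot += textHidden[i][k] * visualTokens[j][k];
                logits[j] = dot * scale;
                max = Math.Max(max, logits[j]);
            }
            // Numerically stable softmax over the visual tokens.
            double sum = 0;
            for (int j = 0; j < logits.Length; j++) { logits[j] = Math.Exp(logits[j] - max); sum += logits[j]; }
            // Weighted average of the visual tokens.
            output[i] = new double[d];
            for (int j = 0; j < visualTokens.Length; j++)
                for (int k = 0; k < d; k++)
                    output[i][k] += logits[j] / sum * visualTokens[j][k];
        }
        return output;
    }

    // y = x + tanh(alpha) * CrossAttend(x, visual). With alpha = 0 the gate is closed and y == x.
    public static double[][] GatedResidual(double[][] textHidden, double[][] visualTokens, double alpha)
    {
        var attended = CrossAttend(textHidden, visualTokens);
        double gate = Math.Tanh(alpha);
        var y = new double[textHidden.Length][];
        for (int i = 0; i < textHidden.Length; i++)
        {
            y[i] = new double[textHidden[i].Length];
            for (int k = 0; k < y[i].Length; k++)
                y[i][k] = textHidden[i][k] + gate * attended[i][k];
        }
        return y;
    }
}
```

Because the derivative of tanh(α) is 1 at α = 0, the gate still receives a gradient at initialization and can open gradually during training.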
Constructors
FlamingoNeuralNetwork(NeuralNetworkArchitecture<T>, int, int, int, int, int, int, int, int, int, int, int, int, LanguageModelBackbone, int, ITokenizer?, IOptimizer<T, Tensor<T>, Tensor<T>>?, ILossFunction<T>?)
Initializes a new instance using native layers.
public FlamingoNeuralNetwork(NeuralNetworkArchitecture<T> architecture, int embeddingDimension = 768, int maxSequenceLength = 2048, int imageSize = 224, int channels = 3, int numPerceiverTokens = 64, int maxImagesInContext = 5, int visionHiddenDim = 1024, int lmHiddenDim = 2048, int numVisionLayers = 24, int numLmLayers = 32, int numHeads = 16, int vocabularySize = 32000, LanguageModelBackbone languageModelBackbone = LanguageModelBackbone.Chinchilla, int numPerceiverLayers = 6, ITokenizer? tokenizer = null, IOptimizer<T, Tensor<T>, Tensor<T>>? optimizer = null, ILossFunction<T>? lossFunction = null)
Parameters
- architecture (NeuralNetworkArchitecture<T>)
- embeddingDimension (int)
- maxSequenceLength (int)
- imageSize (int)
- channels (int)
- numPerceiverTokens (int)
- maxImagesInContext (int)
- visionHiddenDim (int)
- lmHiddenDim (int)
- numVisionLayers (int)
- numLmLayers (int)
- numHeads (int)
- vocabularySize (int)
- languageModelBackbone (LanguageModelBackbone)
- numPerceiverLayers (int)
- tokenizer (ITokenizer)
- optimizer (IOptimizer<T, Tensor<T>, Tensor<T>>)
- lossFunction (ILossFunction<T>)
FlamingoNeuralNetwork(NeuralNetworkArchitecture<T>, string, string, ITokenizer, int, int, int, int, int, IOptimizer<T, Tensor<T>, Tensor<T>>?, ILossFunction<T>?)
Initializes a new instance using ONNX models.
public FlamingoNeuralNetwork(NeuralNetworkArchitecture<T> architecture, string visionEncoderPath, string languageModelPath, ITokenizer tokenizer, int embeddingDimension = 768, int maxSequenceLength = 2048, int imageSize = 224, int numPerceiverTokens = 64, int maxImagesInContext = 5, IOptimizer<T, Tensor<T>, Tensor<T>>? optimizer = null, ILossFunction<T>? lossFunction = null)
Parameters
- architecture (NeuralNetworkArchitecture<T>)
- visionEncoderPath (string)
- languageModelPath (string)
- tokenizer (ITokenizer)
- embeddingDimension (int)
- maxSequenceLength (int)
- imageSize (int)
- numPerceiverTokens (int)
- maxImagesInContext (int)
- optimizer (IOptimizer<T, Tensor<T>, Tensor<T>>)
- lossFunction (ILossFunction<T>)
Properties
EmbeddingDimension
Gets the dimensionality of the embedding space.
public int EmbeddingDimension { get; }
Property Value
- int
ImageSize
Gets the expected image size (square images: ImageSize x ImageSize pixels).
public int ImageSize { get; }
Property Value
- int
LanguageModelBackbone
Gets the language model backbone used for generation.
public LanguageModelBackbone LanguageModelBackbone { get; }
Property Value
- LanguageModelBackbone
Remarks
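The original Flamingo models use Chinchilla as their frozen language model, and LanguageModelBackbone.Chinchilla is the default value of the native-layer constructor's languageModelBackbone parameter.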
MaxImagesInContext
Gets the maximum number of images that can be processed in a single context.
public int MaxImagesInContext { get; }
Property Value
- int
MaxSequenceLength
Gets the maximum sequence length for text input.
public int MaxSequenceLength { get; }
Property Value
- int
NumPerceiverTokens
Gets the number of visual tokens per image after the Perceiver Resampler.
public int NumPerceiverTokens { get; }
Property Value
- int
Remarks
The Perceiver Resampler compresses visual features to a fixed number of tokens (64 by default) regardless of the input image size, so every image in the context costs the same number of tokens. This enables efficient processing of multiple images in context.
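To see why a fixed token count matters, compare how many visual tokens the language model would attend to with and without resampling. The sketch assumes a patch-based encoder that emits one feature per non-overlapping square patch; the patch size of 14 is a hypothetical value chosen for illustration, while 224, 64, and 5 are this class's constructor defaults for imageSize, numPerceiverTokens, and maxImagesInContext.

```csharp
using System;

public static class VisualTokenBudget
{
    // Features per image from a patch-based encoder (assumption: non-overlapping square patches).
    public static int PatchesPerImage(int imageSize, int patchSize)
    {
        if (imageSize % patchSize != 0)
            throw new ArgumentException("imageSize must be a multiple of patchSize.");
        int perSide = imageSize / patchSize;
        return perSide * perSide;
    }

    // Visual tokens in context without resampling: every patch feature is kept.
    public static int WithoutResampler(int images, int imageSize, int patchSize)
        => images * PatchesPerImage(imageSize, patchSize);

    // Visual tokens in context with resampling: a fixed count per image.
    public static int WithResampler(int images, int numPerceiverTokens)
        => images * numPerceiverTokens;
}
```

With these numbers, five images cost 320 visual tokens after resampling instead of 1,280, and the resampled count stays the same if imageSize grows.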
ParameterCount
Gets the total number of parameters in the model.
public override int ParameterCount { get; }
Property Value
- int
Remarks
For Beginners: This tells you how many adjustable values (weights and biases) your neural network has. More complex networks typically have more parameters and can learn more complex patterns, but also require more data to train effectively. This is part of the IFullModel interface for consistency with other model types.
Performance: This property uses caching to avoid recomputing the sum on every access. The cache is invalidated when layers are modified.
Methods
AnswerQuestion(Tensor<T>, string, int)
public string AnswerQuestion(Tensor<T> image, string question, int maxLength = 64)
Parameters
- image (Tensor<T>)
- question (string)
- maxLength (int)
Returns
- string
Backward(Tensor<T>)
Performs the backward pass through the Perceiver Resampler and gated cross-attention layers.
public Tensor<T> Backward(Tensor<T> gradient)
Parameters
- gradient (Tensor<T>)
Returns
- Tensor<T>
ComputeImageTextSimilarity(Tensor<T>, string)
public T ComputeImageTextSimilarity(Tensor<T> image, string text)
Parameters
- image (Tensor<T>)
- text (string)
Returns
- T
ComputeSimilarity(Vector<T>, Vector<T>)
Computes similarity between two embeddings.
public T ComputeSimilarity(Vector<T> embedding1, Vector<T> embedding2)
Parameters
- embedding1 (Vector<T>): The first embedding.
- embedding2 (Vector<T>): The second embedding.
Returns
- T
Similarity score (cosine similarity for normalized embeddings).
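Cosine similarity is the dot product of two vectors divided by the product of their lengths; when both embeddings are already unit-normalized (as EncodeImage and EncodeText state for their outputs), it reduces to a plain dot product. A minimal sketch over double[] arrays, used here instead of Vector<T> so the example stays self-contained:

```csharp
using System;

public static class SimilaritySketch
{
    // dot(a, b) / (|a| * |b|)
    public static double Cosine(double[] a, double[] b)
    {
        if (a.Length != b.Length) throw new ArgumentException("Embeddings must have the same length.");
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.Length; i++)
        {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        if (normA == 0 || normB == 0) return 0; // convention for zero vectors (assumption)
        return dot / (Math.Sqrt(normA) * Math.Sqrt(normB));
    }

    // Scales a vector to unit length so that a dot product equals its cosine similarity.
    public static double[] Normalize(double[] v)
    {
        double norm = 0;
        foreach (var x in v) norm += x * x;
        norm = Math.Sqrt(norm);
        var result = new double[v.Length];
        for (int i = 0; i < v.Length; i++) result[i] = norm == 0 ? 0 : v[i] / norm;
        return result;
    }
}
```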
CreateNewInstance()
Creates a new instance of the same type as this neural network.
protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
A new instance of the same neural network type.
Remarks
For Beginners: This creates a blank version of the same type of neural network.
It's used internally by methods like DeepCopy and Clone to create the right type of network before copying the data into it.
DescribeVideo(IEnumerable<Tensor<T>>, string?, int)
Generates captions for a video represented as a sequence of frames.
public string DescribeVideo(IEnumerable<Tensor<T>> frames, string? prompt = null, int maxLength = 256)
Parameters
- frames (IEnumerable<Tensor<T>>): Sequence of video frame tensors.
- prompt (string): Optional prompt to guide generation.
- maxLength (int): Maximum tokens to generate.
Returns
- string
Generated video description.
Remarks
Flamingo can process multiple frames as separate images interleaved in context, enabling basic video understanding.
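Because frames are treated as separate in-context images, a long clip can exceed MaxImagesInContext. Whether DescribeVideo subsamples frames itself is not stated here; one option on the caller's side is to keep evenly spaced frames first. The FrameSampling helper below is hypothetical and only decides which frames to keep.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class FrameSampling
{
    // Returns up to maxFrames evenly spaced indices in [0, totalFrames).
    public static int[] UniformIndices(int totalFrames, int maxFrames)
    {
        if (totalFrames <= 0 || maxFrames <= 0) return Array.Empty<int>();
        if (totalFrames <= maxFrames) return Enumerable.Range(0, totalFrames).ToArray();
        var indices = new int[maxFrames];
        for (int i = 0; i < maxFrames; i++)
            indices[i] = (int)((long)i * totalFrames / maxFrames); // integer math; strictly increasing here
        return indices;
    }

    // Applies the indices to any frame sequence (for example, a list of image tensors).
    public static List<TFrame> Sample<TFrame>(IReadOnlyList<TFrame> frames, int maxFrames)
        => UniformIndices(frames.Count, maxFrames).Select(i => frames[i]).ToList();
}
```

The result of FrameSampling.Sample(frames, model.MaxImagesInContext) could then be passed as the frames argument.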
DeserializeNetworkSpecificData(BinaryReader)
Deserializes network-specific data that was not covered by the general deserialization process.
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
Parameters
- reader (BinaryReader): The BinaryReader to read the data from.
Remarks
This method is called at the end of the general deserialization process to allow derived classes to read any additional data specific to their implementation.
For Beginners: Continuing the suitcase analogy, this is like unpacking that special compartment. After the main deserialization method has unpacked the common items (layers, parameters), this method allows each specific type of neural network to unpack its own unique items that were stored during serialization.
EncodeImage(double[])
Encodes an image into an embedding vector.
public Vector<T> EncodeImage(double[] imageData)
Parameters
- imageData (double[]): The preprocessed image data as a flattened array in CHW (channel, height, width) order.
Returns
- Vector<T>
A normalized embedding vector.
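CHW order stores every value of channel 0, then channel 1, and so on, with each channel laid out row by row, so the value at (c, y, x) sits at index c·H·W + y·W + x. Image decoders commonly deliver interleaved HWC pixels instead, so a conversion is needed. The sketch assumes 8-bit HWC input and scales values to [0, 1]; any further normalization the model expects (such as per-channel mean and standard deviation) is not specified here and is left out.

```csharp
using System;

public static class ImageLayout
{
    // Converts interleaved HWC bytes (R,G,B,R,G,B,...) into a planar CHW double array in [0, 1].
    public static double[] HwcBytesToChw(byte[] hwc, int height, int width, int channels)
    {
        if (hwc.Length != height * width * channels)
            throw new ArgumentException("Buffer size does not match height * width * channels.");
        var chw = new double[channels * height * width];
        for (int y = 0; y < height; y++)
            for (int x = 0; x < width; x++)
                for (int c = 0; c < channels; c++)
                {
                    int src = (y * width + x) * channels + c;    // HWC index
                    int dst = c * height * width + y * width + x; // CHW index
                    chw[dst] = hwc[src] / 255.0;
                }
        return chw;
    }
}
```

With the defaults imageSize = 224 and channels = 3, the resulting array would have 3 × 224 × 224 = 150,528 elements.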
EncodeImageBatch(IEnumerable<double[]>)
Encodes multiple images into embedding vectors in a batch.
public Matrix<T> EncodeImageBatch(IEnumerable<double[]> imageDataBatch)
Parameters
- imageDataBatch (IEnumerable<double[]>): The preprocessed images as flattened arrays.
Returns
- Matrix<T>
A matrix where each row is an embedding for the corresponding image.
EncodeText(string)
Encodes text into an embedding vector.
public Vector<T> EncodeText(string text)
Parameters
- text (string): The text to encode.
Returns
- Vector<T>
A normalized embedding vector.
EncodeTextBatch(IEnumerable<string>)
Encodes multiple texts into embedding vectors in a batch.
public Matrix<T> EncodeTextBatch(IEnumerable<string> texts)
Parameters
- texts (IEnumerable<string>): The texts to encode.
Returns
- Matrix<T>
A matrix where each row is an embedding for the corresponding text.
ExtractPerceiverFeatures(Tensor<T>)
Extracts visual features using the Perceiver Resampler.
public Tensor<T> ExtractPerceiverFeatures(Tensor<T> image)
Parameters
- image (Tensor<T>): The preprocessed image tensor.
Returns
- Tensor<T>
Resampled visual tokens with shape [numPerceiverTokens, hiddenDim].
Remarks
The Perceiver Resampler uses cross-attention with learnable queries to compress variable-length visual features into a fixed number of tokens.
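The core idea fits in a single cross-attention step: M learned latent vectors act as queries over N visual features, and the output always has M rows whatever N is. This sketch omits the learned projections, multiple heads, feed-forward blocks, and stacked layers (numPerceiverLayers) a real resampler would have, as well as the original design's trick of appending the latents to the keys and values; the latent values passed in stand in for trained parameters.

```csharp
using System;

public static class PerceiverResamplerSketch
{
    // latents: [M, d] learned queries; features: [N, d] visual features. Returns [M, d].
    public static double[][] Resample(double[][] latents, double[][] features)
    {
        int d = latents[0].Length;
        double scale = 1.0 / Math.Sqrt(d);
        var output = new double[latents.Length][];
        for (int m = 0; m < latents.Length; m++)
        {
            // Scaled dot-product scores between latent m and every visual feature.
            var weights = new double[features.Length];
            double max = double.NegativeInfinity;
            for (int n = 0; n < features.Length; n++)
            {
                double dot = 0;
                for (int k = 0; k < d; k++) dot += latents[m][k] * features[n][k];
                weights[n] = dot * scale;
                if (weights[n] > max) max = weights[n];
            }
            // Stable softmax over the N features.
            double sum = 0;
            for (int n = 0; n < weights.Length; n++) { weights[n] = Math.Exp(weights[n] - max); sum += weights[n]; }
            // Each latent becomes a convex combination of the visual features.
            output[m] = new double[d];
            for (int n = 0; n < features.Length; n++)
                for (int k = 0; k < d; k++)
                    output[m][k] += weights[n] / sum * features[n][k];
        }
        return output;
    }
}
```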
FewShotGenerate(IEnumerable<(Tensor<T> Image, string Text)>, Tensor<T>, string?, int)
Performs few-shot visual learning with interleaved image-text examples.
public string FewShotGenerate(IEnumerable<(Tensor<T> Image, string Text)> examples, Tensor<T> queryImage, string? queryPrompt = null, int maxLength = 256)
Parameters
- examples (IEnumerable<(Tensor<T> Image, string Text)>): Few-shot examples as (image, text) pairs.
- queryImage (Tensor<T>): The new image to process.
- queryPrompt (string): Optional prompt for the query (e.g., "What is this?").
- maxLength (int): Maximum tokens to generate.
Returns
- string
The generated response, following the pattern set by the examples.
Remarks
For Beginners: Learn a task from examples, then apply it!
Example: learning to identify dog breeds.
Examples:
- [image of labrador] "This is a Labrador Retriever"
- [image of poodle] "This is a Poodle"
- [image of beagle] "This is a Beagle"
Query: [image of golden retriever] "This is a..."
Response: "Golden Retriever"
Flamingo picks up the pattern from the examples in the prompt alone; none of its weights are updated.
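Few-shot prompting amounts to laying the examples out as one interleaved sequence and leaving the query's text open for the model to complete. The sketch builds such a sequence as a string. The <image> placeholder matches the token described under GenerateWithMultipleImages; the <|endofchunk|> separator follows the OpenFlamingo convention and, like the exact spacing, is an assumption. FewShotGenerate builds its own context internally, so this only illustrates the layout.

```csharp
using System;
using System.Collections.Generic;
using System.Text;

public static class FewShotPrompt
{
    public const string ImageToken = "<image>";
    public const string EndOfChunk = "<|endofchunk|>"; // assumed separator (OpenFlamingo style)

    // Each example contributes "<image>text<|endofchunk|>"; the query ends with an open prompt.
    public static string Build(IEnumerable<string> exampleTexts, string queryPrompt)
    {
        var sb = new StringBuilder();
        foreach (var text in exampleTexts)
            sb.Append(ImageToken).Append(text).Append(EndOfChunk);
        sb.Append(ImageToken).Append(queryPrompt);
        return sb.ToString();
    }
}
```

The images themselves are supplied separately, in the same order as the <image> placeholders.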
FewShotImageRetrieval(IEnumerable<Tensor<T>>, string?, IEnumerable<Tensor<T>>, int)
Retrieves the most similar images from a database using few-shot context.
public IEnumerable<(int Index, T Score)> FewShotImageRetrieval(IEnumerable<Tensor<T>> queryExamples, string? queryDescription, IEnumerable<Tensor<T>> candidateImages, int topK = 10)
Parameters
- queryExamples (IEnumerable<Tensor<T>>): Example images representing what you're looking for.
- queryDescription (string): Optional text description of desired images.
- candidateImages (IEnumerable<Tensor<T>>): Database of images to search.
- topK (int): Number of results to return.
Returns
- IEnumerable<(int Index, T Score)>
Indices of most similar images with scores.
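How FewShotImageRetrieval combines the example images and the optional description is not documented here. One simple approach, sketched below as an assumption, is to average the normalized example embeddings into a single prototype, score every candidate by cosine similarity to it, and keep the topK best. Embeddings are plain double[] arrays standing in for Vector<T>.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class PrototypeRetrieval
{
    // Scales a vector to unit length (zero vectors are returned unchanged).
    static double[] Unit(double[] v)
    {
        double n = Math.Sqrt(v.Sum(x => x * x));
        return n == 0 ? v.ToArray() : v.Select(x => x / n).ToArray();
    }

    // Mean of the unit-length example embeddings, re-normalized.
    public static double[] Prototype(IEnumerable<double[]> examples)
    {
        var units = examples.Select(Unit).ToList();
        if (units.Count == 0) throw new ArgumentException("At least one example is required.");
        var mean = new double[units[0].Length];
        foreach (var u in units)
            for (int i = 0; i < mean.Length; i++) mean[i] += u[i] / units.Count;
        return Unit(mean);
    }

    // Ranks candidates by cosine similarity to the prototype; returns (Index, Score) pairs.
    public static List<(int Index, double Score)> TopK(
        IEnumerable<double[]> examples, IReadOnlyList<double[]> candidates, int topK)
    {
        var proto = Prototype(examples);
        return candidates
            .Select((c, i) => (Index: i, Score: Unit(c).Zip(proto, (a, b) => a * b).Sum()))
            .OrderByDescending(r => r.Score)
            .Take(topK)
            .ToList();
    }
}
```

Ranking against a single text embedding instead of a prototype gives the same shape of result as RetrieveImages.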
FewShotVQA(IEnumerable<(Tensor<T> Image, string Question, string Answer)>, Tensor<T>, string)
Performs visual question answering with few-shot examples.
public string FewShotVQA(IEnumerable<(Tensor<T> Image, string Question, string Answer)> examples, Tensor<T> queryImage, string question)
Parameters
- examples (IEnumerable<(Tensor<T> Image, string Question, string Answer)>): Example (image, question, answer) tuples.
- queryImage (Tensor<T>): The image to ask about.
- question (string): The question to answer.
Returns
- string
The generated answer.
GenerateCaption(Tensor<T>, int)
public string GenerateCaption(Tensor<T> image, int maxLength = 77)
Parameters
- image (Tensor<T>)
- maxLength (int)
Returns
- string
GenerateWithMultipleImages(IEnumerable<Tensor<T>>, string, int)
Generates text for multiple images interleaved in a single context.
public string GenerateWithMultipleImages(IEnumerable<Tensor<T>> images, string prompt, int maxLength = 512)
Parameters
- images (IEnumerable<Tensor<T>>): Sequence of images to process.
- prompt (string): Prompt that may reference images using special tokens.
- maxLength (int): Maximum tokens to generate.
Returns
- string
Generated text response.
Remarks
Supports prompts such as "<image> shows a cat and <image> shows a dog. Compare them.", where each <image> token is replaced with the features of the corresponding image.
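Because each <image> placeholder is paired with one of the supplied images, a mismatch between the placeholder count and the number of images is an easy mistake to make. The hypothetical check below counts placeholders and compares them with the image count and with MaxImagesInContext; how GenerateWithMultipleImages itself handles a mismatch (error, truncation, or padding) is not specified here.

```csharp
using System;

public static class ImagePromptCheck
{
    public const string ImageToken = "<image>";

    // Counts non-overlapping occurrences of the placeholder token.
    public static int CountImageTokens(string prompt)
    {
        int count = 0, pos = 0;
        while ((pos = prompt.IndexOf(ImageToken, pos, StringComparison.Ordinal)) >= 0)
        {
            count++;
            pos += ImageToken.Length;
        }
        return count;
    }

    // Throws if the prompt and image list disagree, or if the context limit is exceeded.
    public static void Validate(string prompt, int imageCount, int maxImagesInContext)
    {
        if (imageCount > maxImagesInContext)
            throw new ArgumentException($"{imageCount} images exceed the limit of {maxImagesInContext}.");
        int tokens = CountImageTokens(prompt);
        if (tokens != imageCount)
            throw new ArgumentException($"Prompt has {tokens} {ImageToken} tokens but {imageCount} images were supplied.");
    }
}
```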
GetImageEmbedding(Tensor<T>)
public Vector<T> GetImageEmbedding(Tensor<T> image)
Parameters
- image (Tensor<T>)
Returns
- Vector<T>
GetImageEmbeddings(IEnumerable<Tensor<T>>)
public IEnumerable<Vector<T>> GetImageEmbeddings(IEnumerable<Tensor<T>> images)
Parameters
- images (IEnumerable<Tensor<T>>)
Returns
- IEnumerable<Vector<T>>
GetModelMetadata()
Gets the metadata for this neural network model.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
A ModelMetadata<T> object containing information about the model.
GetParameters()
Gets all trainable parameters of the network as a single vector.
public override Vector<T> GetParameters()
Returns
- Vector<T>
A vector containing all parameters of the network.
Remarks
For Beginners: Neural networks learn by adjusting their "parameters" (also called weights and biases). This method collects all those adjustable values into a single list so they can be updated during training.
GetTextEmbedding(string)
public Vector<T> GetTextEmbedding(string text)
Parameters
- text (string)
Returns
- Vector<T>
GetTextEmbeddings(IEnumerable<string>)
public IEnumerable<Vector<T>> GetTextEmbeddings(IEnumerable<string> texts)
Parameters
- texts (IEnumerable<string>)
Returns
- IEnumerable<Vector<T>>
InContextClassify(IEnumerable<(Tensor<T> Image, string Label)>, Tensor<T>)
Performs in-context visual classification from a few labeled examples, without training a dedicated classifier.
public Dictionary<string, T> InContextClassify(IEnumerable<(Tensor<T> Image, string Label)> labeledExamples, Tensor<T> queryImage)
Parameters
- labeledExamples (IEnumerable<(Tensor<T> Image, string Label)>): Examples with (image, label) pairs.
- queryImage (Tensor<T>): The image to classify.
Returns
- Dictionary<string, T>
Dictionary mapping labels to confidence scores.
Remarks
For Beginners: Classify images using just a few examples!
Instead of training a classifier on thousands of images:
- Show a few examples per class
- Flamingo learns the categories
- It can now classify new images
This is "few-shot classification": the categories are whatever your examples define, so no fixed label set is required.
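How InContextClassify computes its scores is not documented here. To make the input and output shapes concrete, the sketch below uses a different, simpler technique often used as a baseline, nearest-class-mean (prototype) classification over embeddings: average the embeddings per label, compare the query with each average by cosine similarity, and turn the similarities into a distribution with a softmax. The temperature of 0.1 is an arbitrary assumption.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class PrototypeClassifier
{
    static double Cosine(double[] a, double[] b)
    {
        double dot = 0, na = 0, nb = 0;
        for (int i = 0; i < a.Length; i++) { dot += a[i] * b[i]; na += a[i] * a[i]; nb += b[i] * b[i]; }
        return (na == 0 || nb == 0) ? 0 : dot / Math.Sqrt(na * nb);
    }

    // labeled: (embedding, label) pairs; returns label -> probability (values sum to 1).
    public static Dictionary<string, double> Classify(
        IEnumerable<(double[] Embedding, string Label)> labeled, double[] query, double temperature = 0.1)
    {
        // Mean embedding per label (the class "prototype").
        var prototypes = labeled
            .GroupBy(e => e.Label)
            .ToDictionary(
                g => g.Key,
                g => Enumerable.Range(0, query.Length).Select(i => g.Average(e => e.Embedding[i])).ToArray());

        // Softmax over cosine similarity / temperature, subtracting the max for numerical stability.
        var logits = prototypes.ToDictionary(p => p.Key, p => Cosine(p.Value, query) / temperature);
        double max = logits.Values.Max();
        var exp = logits.ToDictionary(l => l.Key, l => Math.Exp(l.Value - max));
        double sum = exp.Values.Sum();
        return exp.ToDictionary(e => e.Key, e => e.Value / sum);
    }
}
```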
InitializeLayers()
Initializes the layers of the neural network based on the architecture.
protected override void InitializeLayers()
Remarks
For Beginners: This method sets up all the layers in your neural network according to the architecture you've defined. It's like assembling the parts of your network before you can use it.
Predict(Tensor<T>)
Makes a prediction using the neural network.
public override Tensor<T> Predict(Tensor<T> input)
Parameters
- input (Tensor<T>): The input data to process.
Returns
- Tensor<T>
The network's prediction.
Remarks
For Beginners: This is the main method you'll use to get results from your trained neural network. You provide some input data (like an image or text), and the network processes it through all its layers to produce an output (like a classification or prediction).
RetrieveImages(string, IEnumerable<Vector<T>>, int)
public IEnumerable<(int Index, T Score)> RetrieveImages(string query, IEnumerable<Vector<T>> imageEmbeddings, int topK = 10)
Parameters
- query (string)
- imageEmbeddings (IEnumerable<Vector<T>>)
- topK (int)
Returns
- IEnumerable<(int Index, T Score)>
RetrieveTexts(Tensor<T>, IEnumerable<string>, int)
public IEnumerable<(int Index, T Score)> RetrieveTexts(Tensor<T> image, IEnumerable<string> texts, int topK = 10)
Parameters
- image (Tensor<T>)
- texts (IEnumerable<string>)
- topK (int)
Returns
- IEnumerable<(int Index, T Score)>
ScoreImageText(Tensor<T>, string)
Computes the log probability of a given text completion for an image.
public T ScoreImageText(Tensor<T> image, string text)
Parameters
- image (Tensor<T>): The preprocessed image tensor.
- text (string): The text to score.
Returns
- T
Log probability of the text given the image.
Remarks
Useful for ranking candidate captions or performing discriminative tasks with a generative model.
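When several candidate texts are scored against the same image, the raw log probabilities can be compared directly (higher is better) or normalized into a distribution over the candidates with a softmax. Computing that softmax through log-sum-exp avoids underflow, because whole-sentence log probabilities are often large negative numbers. The sketch takes the scores as plain doubles. Each additional token adds a non-positive term to a total log probability, so longer candidates are penalized; some applications divide by token count before comparing (not shown).

```csharp
using System;
using System.Linq;

public static class CandidateRanking
{
    // log(sum(exp(x_i))) computed stably by factoring out the maximum.
    public static double LogSumExp(double[] logProbs)
    {
        double max = logProbs.Max();
        return max + Math.Log(logProbs.Sum(x => Math.Exp(x - max)));
    }

    // Probability of each candidate relative to the others: exp(x_i - logsumexp(x)).
    public static double[] NormalizeOverCandidates(double[] logProbs)
    {
        double lse = LogSumExp(logProbs);
        return logProbs.Select(x => Math.Exp(x - lse)).ToArray();
    }

    // Index of the highest-scoring candidate.
    public static int Best(double[] logProbs)
        => Array.IndexOf(logProbs, logProbs.Max());
}
```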
SerializeNetworkSpecificData(BinaryWriter)
Serializes network-specific data that is not covered by the general serialization process.
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
Parameters
- writer (BinaryWriter): The BinaryWriter to write the data to.
Remarks
This method is called at the end of the general serialization process to allow derived classes to write any additional data specific to their implementation.
For Beginners: Think of this as packing a special compartment in your suitcase. While the main serialization method packs the common items (layers, parameters), this method allows each specific type of neural network to pack its own unique items that other networks might not have.
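The one rule for this pair of methods is that the reader must consume fields in exactly the order and with exactly the types the writer produced. What FlamingoNeuralNetwork actually writes is not listed here; the FlamingoSettingsSketch fields below are hypothetical examples modeled on its configuration properties, shown with a MemoryStream round trip.

```csharp
using System;
using System.IO;

public sealed class FlamingoSettingsSketch
{
    // Hypothetical subset of configuration that a derived network might persist.
    public int EmbeddingDimension { get; set; }
    public int NumPerceiverTokens { get; set; }
    public int MaxImagesInContext { get; set; }
    public string Backbone { get; set; } = "";

    public void Write(BinaryWriter writer)
    {
        writer.Write(EmbeddingDimension);
        writer.Write(NumPerceiverTokens);
        writer.Write(MaxImagesInContext);
        writer.Write(Backbone); // length-prefixed string
    }

    // Must mirror Write: same fields, same order, same types.
    // Object initializers run in textual order, so the reads happen in this sequence.
    public static FlamingoSettingsSketch Read(BinaryReader reader) => new FlamingoSettingsSketch
    {
        EmbeddingDimension = reader.ReadInt32(),
        NumPerceiverTokens = reader.ReadInt32(),
        MaxImagesInContext = reader.ReadInt32(),
        Backbone = reader.ReadString(),
    };
}
```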
Train(Tensor<T>, Tensor<T>)
Trains the neural network on a single input-output pair.
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)
Parameters
- input (Tensor<T>): The input data.
- expectedOutput (Tensor<T>): The expected output for the given input.
Remarks
This method performs one training step on the neural network using the provided input and expected output. It updates the network's parameters to reduce the error between the network's prediction and the expected output.
For Beginners: This is how your neural network learns. You provide:
- An input (what the network should process)
- The expected output (what the correct answer should be)
The network then:
- Makes a prediction based on the input
- Compares its prediction to the expected output
- Calculates how wrong it was (the loss)
- Adjusts its internal values to do better next time
After training, you can get the loss value using the GetLastLoss() method to see how well the network is learning.
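The four steps above are the standard predict, compare, compute-loss, update loop. A toy version with a one-parameter linear model, mean squared error, and plain gradient descent shows the mechanics. The constructors accept an optional IOptimizer, so the update rule and learning rate used in practice may differ from the illustrative ones here.

```csharp
using System;

public static class TrainStepSketch
{
    // One gradient-descent step for y = w * x with mean squared error. Returns the loss before the update.
    public static double Step(ref double w, double[] inputs, double[] targets, double learningRate)
    {
        int n = inputs.Length;
        double loss = 0, grad = 0;
        for (int i = 0; i < n; i++)
        {
            double prediction = w * inputs[i];      // 1. predict
            double error = prediction - targets[i]; // 2. compare with the expected output
            loss += error * error / n;              // 3. mean squared error
            grad += 2 * error * inputs[i] / n;      //    dLoss/dw
        }
        w -= learningRate * grad;                   // 4. adjust the parameter
        return loss;
    }
}
```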
UpdateParameters(Vector<T>)
Updates the network's parameters with new values.
public override void UpdateParameters(Vector<T> parameters)
Parameters
- parameters (Vector<T>): The new parameter values to set.
Remarks
For Beginners: During training, a neural network's internal values (parameters) get adjusted to improve its performance. This method allows you to update all those values at once by providing a complete set of new parameters.
This is typically used by optimization algorithms that calculate better parameter values based on training data.
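GetParameters and UpdateParameters work as a pair: one concatenates every layer's parameters into a single vector, and the other slices a vector of the same length back into the layers in the same order. The sketch uses one double[] per layer as a hypothetical stand-in for the network's layers. The invariant it shows is that the vector length must equal the total parameter count (ParameterCount) and the layer order must not change between the two calls.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class ParameterPacking
{
    // Concatenates each layer's parameters in layer order.
    public static double[] Flatten(IReadOnlyList<double[]> layers)
        => layers.SelectMany(layer => layer).ToArray();

    // Copies consecutive slices of 'parameters' back into the layers, in the same order.
    public static void Unflatten(IReadOnlyList<double[]> layers, double[] parameters)
    {
        int total = layers.Sum(l => l.Length);
        if (parameters.Length != total)
            throw new ArgumentException($"Expected {total} parameters but got {parameters.Length}.");
        int offset = 0;
        foreach (var layer in layers)
        {
            Array.Copy(parameters, offset, layer, 0, layer.Length);
            offset += layer.Length;
        }
    }
}
```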
ZeroShotClassify(Tensor<T>, IEnumerable<string>)
public Dictionary<string, T> ZeroShotClassify(Tensor<T> image, IEnumerable<string> classLabels)
Parameters
- image (Tensor<T>)
- classLabels (IEnumerable<string>)
Returns
- Dictionary<string, T>
ZeroShotClassify(double[], IEnumerable<string>)
Performs zero-shot classification of an image against text labels.
public Dictionary<string, T> ZeroShotClassify(double[] imageData, IEnumerable<string> labels)
Parameters
- imageData (double[]): The preprocessed image data.
- labels (IEnumerable<string>): The candidate class labels.
Returns
- Dictionary<string, T>
A dictionary mapping each label to its probability score.