Class MatchingNetworksAlgorithm<T, TInput, TOutput>
- Namespace
- AiDotNet.MetaLearning.Algorithms
- Assembly
- AiDotNet.dll
Implementation of Matching Networks for few-shot learning.
public class MatchingNetworksAlgorithm<T, TInput, TOutput> : MetaLearnerBase<T, TInput, TOutput>, IMetaLearner<T, TInput, TOutput>
Type Parameters
T: The numeric type used for calculations (e.g., float, double).
TInput: The input data type (e.g., Matrix&lt;T&gt;, Tensor&lt;T&gt;).
TOutput: The output data type (e.g., Vector&lt;T&gt;, Tensor&lt;T&gt;).
- Inheritance
-
MetaLearnerBase<T, TInput, TOutput> → MatchingNetworksAlgorithm<T, TInput, TOutput>
- Implements
-
IMetaLearner<T, TInput, TOutput>
Remarks
Matching Networks use attention mechanisms over the support set to classify query examples. It computes a weighted sum of support labels where weights are determined by an attention function that measures similarity between examples.
For Beginners: Matching Networks learn to pay attention to similar examples:
How it works:
- Encode all examples (support and query) with a shared encoder
- For each query, compute attention weights with all support examples
- Use cosine similarity or learned attention for weights
- Predict weighted sum of support labels (soft nearest neighbor)
Key insight: The network learns how to compare examples during encoding, making the similarity measure task-aware.
Algorithm - Matching Networks:
# Shared encoder that learns to produce comparable embeddings
encoder = NeuralNetwork()   # maps x -> embedding(x)
attention = cosine_similarity or learned_attention

# Episodic training
for each episode:
    # Sample an N-way K-shot task
    support_set = {K examples from each of N classes}
    query_set   = {held-out examples from the same N classes}

    # Encode all examples with the shared encoder
    support_embeddings = encoder(support_set)
    query_embeddings   = encoder(query_set)

    for each query embedding q in query_embeddings:
        # Compute attention weights over all support examples
        scores  = [attention(q, s) for s in support_embeddings]
        weights = softmax(scores)
        # Soft nearest neighbor: weighted sum of support labels
        prediction = sum(w * one_hot(y) for w, y in zip(weights, support_labels))

    # Train the encoder end-to-end with cross-entropy loss
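The prediction step above can be made concrete with a short runnable sketch. The following pure-Python illustration (the library itself is C#; `matching_predict`, the 2-D toy "embeddings", and the hard-coded labels are all invented for the example) assumes cosine-similarity attention:

```python
import math

def cosine(u, v):
    # Cosine similarity between two embedding vectors.
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def softmax(scores):
    # Numerically stable softmax.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def matching_predict(query_emb, support_embs, support_labels, num_classes):
    # Attention weights over every support example.
    weights = softmax([cosine(query_emb, s) for s in support_embs])
    # Weighted sum of one-hot support labels -> class distribution.
    probs = [0.0] * num_classes
    for w, y in zip(weights, support_labels):
        probs[y] += w
    return probs

# Toy 2-way 2-shot episode with hand-crafted 2-D "embeddings".
support = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
labels = [0, 0, 1, 1]
probs = matching_predict([0.95, 0.05], support, labels, num_classes=2)
print(max(range(2), key=lambda c: probs[c]))  # predicted class: 0
```

Because the weights come from a softmax, the output is already a valid class distribution, which is what makes the cross-entropy loss in the last step directly applicable.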
Key Insights:
Task-Aware Embeddings: The encoder learns to produce embeddings that are meaningful for the specific classification task at hand.
Differentiable Attention: The attention mechanism is fully differentiable, allowing end-to-end training of the encoder.
No Adaptation Needed: At test time, simply encode new examples and apply the same attention mechanism - no gradient updates required.
Reference: Vinyals, O., Blundell, C., Lillicrap, T., Kavukcuoglu, K., & Wierstra, D. (2016). Matching Networks for One Shot Learning. NeurIPS.
Constructors
MatchingNetworksAlgorithm(MatchingNetworksOptions<T, TInput, TOutput>)
Initializes a new instance of the MatchingNetworksAlgorithm class.
public MatchingNetworksAlgorithm(MatchingNetworksOptions<T, TInput, TOutput> options)
Parameters
options (MatchingNetworksOptions<T, TInput, TOutput>): The configuration options for Matching Networks.
Remarks
For Beginners: This creates a Matching Network ready for few-shot learning.
What Matching Networks need:
- encoder: Neural network that embeds examples into comparable space
- attentionFunction: How to measure similarity (cosine, learned)
- useFullContext: Whether to use all examples when encoding each example
What makes it different from ProtoNets:
- ProtoNets: Uses fixed distance (Euclidean) to class prototypes (mean)
- Matching Nets: Uses learnable attention to individual support examples
- No explicit class representatives - each example votes
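The contrast in the bullets above can be shown with a toy side-by-side sketch, assuming 2-D embeddings and pure Python (`proto_predict` and `matching_predict` are illustrative names, not AiDotNet APIs):

```python
import math

def euclidean(u, v):
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def proto_predict(query, support, labels, num_classes):
    # ProtoNets: collapse each class to its mean embedding (prototype),
    # then assign the query to the nearest prototype.
    prototypes = []
    for c in range(num_classes):
        members = [s for s, y in zip(support, labels) if y == c]
        prototypes.append([sum(dim) / len(members) for dim in zip(*members)])
    return min(range(num_classes), key=lambda c: euclidean(query, prototypes[c]))

def matching_predict(query, support, labels, num_classes):
    # Matching Nets: every support example votes individually,
    # weighted by softmax-normalized similarity to the query.
    scores = [math.exp(cosine(query, s)) for s in support]
    total = sum(scores)
    votes = [0.0] * num_classes
    for score, y in zip(scores, labels):
        votes[y] += score / total
    return max(range(num_classes), key=lambda c: votes[c])

support = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
labels = [0, 0, 1, 1]
print(proto_predict([0.95, 0.05], support, labels, 2))     # 0
print(matching_predict([0.95, 0.05], support, labels, 2))  # 0
```

Both agree on this easy query; they diverge when a class has outlier support examples, since the prototype mean dilutes outliers while per-example attention lets each one vote separately.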
Exceptions
- ArgumentNullException
Thrown when options or required components are null.
- ArgumentException
Thrown when configuration validation fails.
Properties
AlgorithmType
Gets the algorithm type identifier for this meta-learner.
public override MetaLearningAlgorithmType AlgorithmType { get; }
Property Value
- MetaLearningAlgorithmType
Methods
Adapt(IMetaLearningTask<T, TInput, TOutput>)
Adapts to a new task by caching support embeddings.
public override IModel<TInput, TOutput, ModelMetadata<T>> Adapt(IMetaLearningTask<T, TInput, TOutput> task)
Parameters
task (IMetaLearningTask<T, TInput, TOutput>): The new task containing support set examples.
Returns
- IModel<TInput, TOutput, ModelMetadata<T>>
A MatchingNetworksModel that classifies using attention over support examples.
Remarks
Matching Networks adaptation is very fast - just encode support examples and cache them.
For Beginners: After meta-training, when you have a new task with labeled examples, call this method. The returned model can classify new examples by comparing them to all support examples using learned attention.
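Conceptually, the model returned by Adapt just caches encoded support examples. A minimal pure-Python sketch of that idea (AdaptedMatchingModel is a hypothetical stand-in for the actual MatchingNetworksModel; the identity encoder and dot-product attention are assumptions for the example):

```python
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

class AdaptedMatchingModel:
    """Hypothetical stand-in for the model returned by Adapt():
    support embeddings are computed once and cached, and prediction
    needs no gradient updates."""

    def __init__(self, encoder, support_x, support_y, num_classes):
        self.encoder = encoder
        self.num_classes = num_classes
        self.support_y = support_y
        # Adaptation step: encode and cache the support set.
        self.cache = [encoder(x) for x in support_x]

    def predict(self, x):
        q = self.encoder(x)
        sims = [sum(a * b for a, b in zip(q, s)) for s in self.cache]
        weights = softmax(sims)
        probs = [0.0] * self.num_classes
        for w, y in zip(weights, self.support_y):
            probs[y] += w
        return max(range(self.num_classes), key=lambda c: probs[c])

encoder = lambda x: x  # identity encoder, just for the toy example
model = AdaptedMatchingModel(encoder, [[1, 0], [0, 1]], [0, 1], 2)
print(model.predict([0.9, 0.2]))  # 0
```

This is why adaptation is cheap: the only work is one encoder forward pass per support example, after which every prediction reuses the cache.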
Exceptions
- ArgumentNullException
Thrown when task is null.
MetaTrain(TaskBatch<T, TInput, TOutput>)
Performs one meta-training step using Matching Networks' episodic training.
public override T MetaTrain(TaskBatch<T, TInput, TOutput> taskBatch)
Parameters
taskBatch (TaskBatch<T, TInput, TOutput>): A batch of tasks to meta-train on.
Returns
- T
The average meta-loss across all tasks in the batch.
Remarks
Matching Networks training computes attention-based predictions:
For each task in the batch:
1. Encode all support and query examples
2. For each query, compute attention weights with all support examples
3. Predict the weighted sum of support labels
4. Compute the cross-entropy loss
5. Update the encoder with averaged gradients
Exceptions
- ArgumentException
Thrown when the task batch is null or empty.