Class ProtoNetsAlgorithm<T, TInput, TOutput>
- Namespace: AiDotNet.MetaLearning.Algorithms
- Assembly: AiDotNet.dll
Implementation of Prototypical Networks (ProtoNets) algorithm for few-shot learning.
public class ProtoNetsAlgorithm<T, TInput, TOutput> : MetaLearnerBase<T, TInput, TOutput>, IMetaLearner<T, TInput, TOutput>
Type Parameters
T: The numeric type used for calculations (e.g., float, double).
TInput: The input data type (e.g., Matrix<T>, Tensor<T>).
TOutput: The output data type (e.g., Vector<T>, Tensor<T>).
- Inheritance: MetaLearnerBase<T, TInput, TOutput> → ProtoNetsAlgorithm<T, TInput, TOutput>
- Implements: IMetaLearner<T, TInput, TOutput>
Remarks
Prototypical Networks learn a metric space in which classification is performed by computing distances to a prototype representation of each class. Each prototype is the mean of the embedded support-set examples of that class.
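For reference, the two quantities just described can be written in the notation of Snell et al. (2017), where f_φ is the embedding network, S_k is the set of support examples labeled k, and d is a distance function (the paper uses squared Euclidean distance; the DistanceFunction option on this class presumably plays the role of d):

$$
\mathbf{c}_k = \frac{1}{|S_k|} \sum_{(\mathbf{x}_i, y_i) \in S_k} f_{\phi}(\mathbf{x}_i)
$$

$$
p_{\phi}(y = k \mid \mathbf{x}) = \frac{\exp\big(-d(f_{\phi}(\mathbf{x}), \mathbf{c}_k)\big)}{\sum_{k'} \exp\big(-d(f_{\phi}(\mathbf{x}), \mathbf{c}_{k'})\big)}
$$

Training minimizes the negative log-probability of the true class, $J(\phi) = -\log p_{\phi}(y = k \mid \mathbf{x})$, averaged over the query examples.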
For Beginners: ProtoNets learns to recognize new classes from just a few examples:
How it works:
- For each new class, create a "prototype" (the average of that class's example embeddings)
- To classify a new example, find which prototype is closest
- Distance is measured in a learned feature space
- Classifies with a soft nearest-neighbor rule: a softmax over negative distances to the prototypes
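The classification step in the list above can be sketched in plain Python. This is a library-independent illustration, not the AiDotNet implementation: it assumes squared Euclidean distance (the metric used in the original paper) and uses made-up prototypes that are already in embedding space.

```python
import math

def squared_euclidean(a, b):
    """Squared Euclidean distance between two embeddings."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def class_probabilities(query_embedding, prototypes):
    """Softmax over negative distances: the closer a prototype, the higher its probability."""
    logits = [-squared_euclidean(query_embedding, p) for p in prototypes]
    m = max(logits)  # subtract the largest logit for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Two hypothetical prototypes and one query, all already in embedding space.
prototypes = [[0.0, 0.0], [4.0, 0.0]]
probs = class_probabilities([1.0, 0.0], prototypes)
predicted = max(range(len(probs)), key=probs.__getitem__)
print(predicted, [round(p, 4) for p in probs])  # 0 [0.9997, 0.0003]
```

The query is at distance 1 from the first prototype and 9 from the second, so nearly all of the probability mass goes to class 0.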
Simple example:
- Support set: 3 images of each of 5 animal species (a 5-way 3-shot task, 15 images total)
- Create a prototype for each species by averaging the features of its images
- Query image: classify it by finding the nearest species prototype
- Learning: train the encoder so that images of the same species cluster together
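The prototype and nearest-prototype steps of this example can be sketched in plain Python. To keep it short, the sketch assumes a 2-way 3-shot task (instead of 5-way) and uses hypothetical 2-D feature vectors standing in for the encoder's output; `compute_prototypes` is an illustrative name, not an AiDotNet member.

```python
from collections import defaultdict

def compute_prototypes(embeddings, labels):
    """Group embeddings by label and average each group dimension by dimension."""
    grouped = defaultdict(list)
    for emb, label in zip(embeddings, labels):
        grouped[label].append(emb)
    return {label: [sum(dim) / len(embs) for dim in zip(*embs)]
            for label, embs in grouped.items()}

# Hypothetical, already-encoded 2-D features for a 2-way 3-shot support set.
support = [[1.0, 2.0], [2.0, 2.0], [3.0, 2.0],    # three "cat" images
           [8.0, 9.0], [9.0, 9.0], [10.0, 9.0]]   # three "dog" images
labels = ["cat"] * 3 + ["dog"] * 3
protos = compute_prototypes(support, labels)
print(protos)  # {'cat': [2.0, 2.0], 'dog': [9.0, 9.0]}

# Classify a query feature vector by its nearest prototype (squared Euclidean distance).
query = [3.0, 3.0]
nearest = min(protos, key=lambda c: sum((q - p) ** 2 for q, p in zip(query, protos[c])))
print(nearest)  # cat
```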
Algorithm - Prototypical Networks:
    # Encoding phase (learnable)
    feature_encoder = NeuralNetwork()  # maps x -> embedding(x)

    # Episodic training
    for each episode:
        # Sample an N-way K-shot task
        support_set = {K labeled examples from each of N sampled classes}
        query_set = {further examples from the same N classes}

        # Compute class prototypes (non-parametric)
        for each class c:
            prototype_c = mean(embedding(x) for x in support examples of class c)

        # Classify queries by distance and accumulate the loss
        loss = 0
        for each query example (x, y) in query_set:
            distances = [distance(embedding(x), prototype_c) for c in classes]
            probabilities = softmax(-distances)
            loss += cross_entropy(probabilities, y)
        loss = loss / size(query_set)

        # Update the encoder (prototypes are recomputed every episode, never stored)
        backpropagate(loss)
        update(feature_encoder.parameters)
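The "Sample an N-way K-shot task" step of the algorithm above can be sketched as follows. This is an illustrative sampler, not AiDotNet's episode generator: the function name, the `q_queries` parameter, and the toy dataset are hypothetical, and relabeling the sampled classes as 0..N-1 is a common convention that this page does not specify.

```python
import random
from collections import defaultdict

def sample_episode(dataset, n_way, k_shot, q_queries, rng):
    """Sample an N-way K-shot episode from (example, label) pairs.

    Returns (support, query) lists of (example, episode_label) pairs. Episode labels
    are 0..n_way-1, so the task does not depend on the dataset's global class ids.
    """
    by_class = defaultdict(list)
    for x, y in dataset:
        by_class[y].append(x)
    # Only classes with enough examples for both the support and the query split qualify.
    eligible = sorted(c for c, xs in by_class.items() if len(xs) >= k_shot + q_queries)
    classes = rng.sample(eligible, n_way)
    support, query = [], []
    for episode_label, c in enumerate(classes):
        picked = rng.sample(by_class[c], k_shot + q_queries)  # drawn without replacement
        support += [(x, episode_label) for x in picked[:k_shot]]
        query += [(x, episode_label) for x in picked[k_shot:]]
    return support, query

# Hypothetical toy dataset: 10 classes with 20 examples each (examples are just ids).
data = [((c, i), c) for c in range(10) for i in range(20)]
support, query = sample_episode(data, n_way=5, k_shot=3, q_queries=4, rng=random.Random(0))
print(len(support), len(query))  # 15 20
```

Because each class's examples are drawn without replacement, no query example also appears in the support set.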
Key Insights:
Non-parametric Classification: There are no classifier parameters to learn; only the feature encoder is trained. Prototypes are computed on the fly.
Metric Learning: The encoder learns to cluster same-class examples and separate different classes in the feature space.
Efficient Adaptation: To adapt to new classes, just compute new prototypes - no gradient updates needed!
Interpretable: Prototypes provide an intuitive representation of each class as the "average example".
Reference: Snell, J., Swersky, K., & Zemel, R. (2017). Prototypical Networks for Few-shot Learning.
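The non-parametric classification insight has a concrete form in the cited paper: with squared Euclidean distance, the prototype rule is equivalent to a linear classifier whose weights are produced from the prototypes (w_c = 2c_c, b_c = -||c_c||²), because the remaining -||f(x)||² term is identical for every class and cancels in the softmax. The following plain-Python check uses hypothetical prototypes and a hypothetical query embedding; it illustrates the paper's argument and is not AiDotNet code.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the largest logit for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def distance_logits(f, prototypes):
    """ProtoNets logits: negative squared Euclidean distance to each prototype."""
    return [-sum((a - b) ** 2 for a, b in zip(f, c)) for c in prototypes]

def linear_logits(f, prototypes):
    """Linear-model logits with w_c = 2 * c and b_c = -||c||^2 (the -||f||^2 term is dropped)."""
    return [sum(2 * b * a for a, b in zip(f, c)) - sum(b * b for b in c) for c in prototypes]

prototypes = [[1.0, 0.0], [0.0, 2.0], [-1.0, -1.0]]
f = [0.5, 0.7]
p_dist = softmax(distance_logits(f, prototypes))
p_lin = softmax(linear_logits(f, prototypes))
print(all(abs(a - b) < 1e-12 for a, b in zip(p_dist, p_lin)))  # True
```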
Constructors
ProtoNetsAlgorithm(ProtoNetsOptions<T, TInput, TOutput>)
Initializes a new instance of the ProtoNetsAlgorithm class.
public ProtoNetsAlgorithm(ProtoNetsOptions<T, TInput, TOutput> options)
Parameters
options (ProtoNetsOptions<T, TInput, TOutput>): The configuration options for ProtoNets.
Remarks
For Beginners: This creates a ProtoNets model ready for few-shot learning.
What ProtoNets needs:
- MetaModel: the neural network that maps inputs to feature embeddings (e.g., a CNN for images)
- DistanceFunction: how to measure the distance between embeddings (e.g., Euclidean or cosine)
What happens during training:
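- Sample episodes with N classes and K support examples per class
- Compute prototypes by averaging the support features of each class
- Score query examples from the same classes by their distance to each prototype
- Train the encoder so that same-class features end up close together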
What happens during testing:
- Get K examples of each new class
- Compute prototypes (no training needed!)
- Classify new examples by nearest prototype
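The choice of DistanceFunction can change which prototype is nearest. The plain-Python sketch below is library-independent (the helper names are hypothetical, not AiDotNet members) and assumes "Euclidean" means squared Euclidean distance and "cosine" means 1 minus cosine similarity. Cosine distance ignores embedding magnitude, so a long prototype pointing in the query's direction wins under cosine but loses under Euclidean distance.

```python
import math

def squared_euclidean(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))

def cosine_distance(a, b):
    """1 - cosine similarity; depends only on direction, not on vector length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / norm

def nearest(query, prototypes, distance):
    """Index of the prototype closest to the query under the given distance."""
    return min(range(len(prototypes)), key=lambda i: distance(query, prototypes[i]))

prototypes = [[10.0, 0.0], [0.5, 0.5]]
query = [1.0, 0.0]
print(nearest(query, prototypes, squared_euclidean))  # 1
print(nearest(query, prototypes, cosine_distance))    # 0
```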
Exceptions
- ArgumentNullException
Thrown when options or required components are null.
- ArgumentException
Thrown when configuration validation fails.
Properties
AlgorithmType
Gets the algorithm type identifier for this meta-learner.
public override MetaLearningAlgorithmType AlgorithmType { get; }
Property Value
- MetaLearningAlgorithmType
Returns ProtoNets.
Methods
Adapt(IMetaLearningTask<T, TInput, TOutput>)
Adapts to a new task by computing class prototypes from the support set.
public override IModel<TInput, TOutput, ModelMetadata<T>> Adapt(IMetaLearningTask<T, TInput, TOutput> task)
Parameters
task (IMetaLearningTask<T, TInput, TOutput>): The new task containing the support set examples.
Returns
- IModel<TInput, TOutput, ModelMetadata<T>>
A PrototypicalModel that classifies by nearest prototype.
Remarks
This is where ProtoNets shines: adaptation is nearly instantaneous. Unlike MAML, which requires gradient descent steps, ProtoNets only runs the encoder over the support set and computes prototypes from the resulting embeddings.
Adaptation Process:
1. Encode all support set examples using the trained feature encoder.
2. Group the embeddings by class.
3. Compute a prototype for each class (the mean of its embeddings).
4. Return a model that classifies by distance to the prototypes.
For Beginners: After meta-training, when you have a new task with labeled examples, call this method. The returned model can immediately classify new examples by finding the nearest class prototype - no additional training needed!
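The four adaptation steps can be sketched as a small Python class. `PrototypeClassifier` is a hypothetical analogue of the returned PrototypicalModel, not its actual implementation; it assumes squared Euclidean distance and takes the encoder as a plain function. Adaptation happens entirely in the constructor, with no gradient step.

```python
from collections import defaultdict

class PrototypeClassifier:
    """Hypothetical analogue of the model returned by Adapt: one stored prototype per class."""

    def __init__(self, encoder, support_inputs, support_labels):
        # Steps 1-2: encode every support example and group the embeddings by class.
        grouped = defaultdict(list)
        for x, y in zip(support_inputs, support_labels):
            grouped[y].append(encoder(x))
        self.encoder = encoder
        # Step 3: each prototype is the mean embedding of its class.
        self.prototypes = {y: [sum(d) / len(embs) for d in zip(*embs)]
                           for y, embs in grouped.items()}

    def predict(self, x):
        """Step 4: return the label of the nearest prototype (squared Euclidean distance)."""
        e = self.encoder(x)
        return min(self.prototypes,
                   key=lambda y: sum((a - b) ** 2 for a, b in zip(e, self.prototypes[y])))

# Usage with a hypothetical encoder that simply doubles each feature.
def double(x):
    return [2.0 * v for v in x]

model = PrototypeClassifier(double, [[0, 0], [0, 1], [5, 5], [5, 6]], ["x", "x", "y", "y"])
print(model.predict([1, 1]), model.predict([4, 5]))  # x y
```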
Exceptions
- ArgumentNullException
Thrown when task is null.
MetaTrain(TaskBatch<T, TInput, TOutput>)
Performs one meta-training step using ProtoNets' episodic training.
public override T MetaTrain(TaskBatch<T, TInput, TOutput> taskBatch)
Parameters
taskBatch (TaskBatch<T, TInput, TOutput>): A batch of tasks to meta-train on.
Returns
- T
The average meta-loss across all tasks in the batch.
Remarks
ProtoNets training is simpler than MAML's because there is no inner-loop gradient computation.
For each task in the batch:
1. Encode the support set examples to get feature embeddings.
2. Compute class prototypes (the mean of each class's embeddings).
3. Encode the query set examples.
4. Compute distances from the query embeddings to the prototypes.
5. Apply a softmax over the negative distances to get class probabilities.
6. Compute the cross-entropy loss.
Meta-update: Average losses across all tasks and backpropagate to update the feature encoder.
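The batch-averaged meta-update can be illustrated end to end in plain Python. This is a toy sketch under explicit assumptions, not AiDotNet's training code: the encoder is a hypothetical per-feature scaling f(x) = w * x, the distance is squared Euclidean, and central finite differences stand in for backpropagation (a real implementation would use automatic differentiation).

```python
import math
from collections import defaultdict

def task_loss(w, task):
    """Episode loss for one task under a hypothetical diagonal encoder f(x) = w * x."""
    support, query = task

    def encode(x):
        return [wi * xi for wi, xi in zip(w, x)]

    grouped = defaultdict(list)
    for x, y in support:
        grouped[y].append(encode(x))
    classes = sorted(grouped)
    protos = [[sum(d) / len(grouped[c]) for d in zip(*grouped[c])] for c in classes]
    total = 0.0
    for x, y in query:
        e = encode(x)
        logits = [-sum((a - b) ** 2 for a, b in zip(e, p)) for p in protos]
        m = max(logits)
        log_z = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_z - logits[classes.index(y)]  # cross-entropy: -log p(y | x)
    return total / len(query)

def batch_loss(w, batch):
    """Meta-loss: the average of the per-task losses."""
    return sum(task_loss(w, task) for task in batch) / len(batch)

def meta_train_step(w, batch, lr=0.1, eps=1e-5):
    """One meta-update; central finite differences stand in for backpropagation."""
    grad = []
    for i in range(len(w)):
        up = w[:i] + [w[i] + eps] + w[i + 1:]
        down = w[:i] + [w[i] - eps] + w[i + 1:]
        grad.append((batch_loss(up, batch) - batch_loss(down, batch)) / (2 * eps))
    return [wi - lr * g for wi, g in zip(w, grad)]

# Hypothetical 2-way 1-shot tasks: feature 0 decides the class, feature 1 is noise.
batch = [
    ([([0, 0], "a"), ([1, 1], "b")], [([0, 1], "a"), ([1, 0], "b")]),
    ([([0, 1], "a"), ([1, 0], "b")], [([0, 0], "a"), ([1, 1], "b")]),
]
w = [1.0, 1.0]
print("before:", round(batch_loss(w, batch), 4))  # before: 0.6931
for _ in range(50):
    w = meta_train_step(w, batch)
print("after:", round(batch_loss(w, batch), 4), "weights:", [round(v, 3) for v in w])
```

With w = [1, 1] every query is equidistant from both prototypes, so the loss starts at ln 2. The updates then scale up the informative feature and shrink the noisy one, which lowers the averaged loss.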
Exceptions
- ArgumentException
Thrown when the task batch is null or empty.