Class ProtoNetsAlgorithm<T, TInput, TOutput>

Namespace
AiDotNet.MetaLearning.Algorithms
Assembly
AiDotNet.dll

Implementation of Prototypical Networks (ProtoNets) algorithm for few-shot learning.

public class ProtoNetsAlgorithm<T, TInput, TOutput> : MetaLearnerBase<T, TInput, TOutput>, IMetaLearner<T, TInput, TOutput>

Type Parameters

T

The numeric type used for calculations (e.g., float, double).

TInput

The input data type (e.g., Matrix<T>, Tensor<T>).

TOutput

The output data type (e.g., Vector<T>, Tensor<T>).

Inheritance
MetaLearnerBase<T, TInput, TOutput>
ProtoNetsAlgorithm<T, TInput, TOutput>
Implements
IMetaLearner<T, TInput, TOutput>

Remarks

Prototypical Networks learn a metric space where classification can be performed by computing distances to prototype representations of each class. Each prototype is the mean vector of the support set examples for that class.

For Beginners: ProtoNets learns to recognize new classes from just a few examples:

How it works:

  1. For each new class, create a "prototype" by averaging the features of its support examples
  2. To classify a new example, find which prototype is closest
  3. Distance is measured in a feature space learned by the encoder
  4. A soft nearest-neighbor rule (softmax over negative distances) turns those distances into class probabilities, as formalized below
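
In the paper's notation (Snell et al., 2017), with encoder $f_\phi$, support set $S_k$ for class $k$, and distance $d$ (squared Euclidean in the paper), the prototype and the query-classification rule are:

$$c_k = \frac{1}{|S_k|} \sum_{(x_i,\, y_i) \in S_k} f_\phi(x_i), \qquad p_\phi(y = k \mid x) = \frac{\exp\left(-d\left(f_\phi(x),\, c_k\right)\right)}{\sum_{k'} \exp\left(-d\left(f_\phi(x),\, c_{k'}\right)\right)}$$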

Simple example:

  • Support set: 3 images each of 5 different animal species (15 images total)
  • Create a prototype for each species by averaging their features (a toy calculation follows this list)
  • Query image: classify by finding nearest animal prototype
  • Learning: train encoder to make same-species images cluster together
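
As a toy instance of the averaging step (made-up 2-D feature vectors for one species):

$$c_{\text{species}} = \frac{1}{3}\left[(1, 2) + (3, 4) + (5, 6)\right] = (3, 4)$$

A query whose features land near $(3, 4)$ would be assigned to this species.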

Algorithm - Prototypical Networks:

# Runnable PyTorch-style sketch of episodic ProtoNets training; the encoder
# architecture and the episode sampler are placeholders you supply.
import torch
import torch.nn.functional as F

feature_encoder = make_encoder()  # any nn.Module mapping x -> embedding(x)
optimizer = torch.optim.Adam(feature_encoder.parameters())

for support_x, support_y, query_x, query_y in episodes:  # N-way K-shot tasks
    # Labels are assumed remapped to 0..N-1 within each episode
    n_way = int(support_y.max()) + 1

    # Encode support and query examples (the only learnable part)
    z_support = feature_encoder(support_x)   # shape [N*K, D]
    z_query = feature_encoder(query_x)       # shape [Q, D]

    # Compute class prototypes (non-parametric): mean embedding per class
    prototypes = torch.stack(
        [z_support[support_y == c].mean(dim=0) for c in range(n_way)])

    # Classification by squared Euclidean distance to each prototype
    distances = torch.cdist(z_query, prototypes) ** 2   # shape [Q, N]
    log_probs = F.log_softmax(-distances, dim=1)
    loss = F.nll_loss(log_probs, query_y)

    # Update the encoder (no prototypes to store!)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Key Insights:

  1. Non-parametric Classification: there are no classifier parameters to learn; only the feature encoder is trained. Prototypes are computed on the fly.

  2. Metric Learning: The encoder learns to cluster same-class examples and separate different classes in the feature space.

  3. Efficient Adaptation: To adapt to new classes, just compute new prototypes - no gradient updates needed!

  4. Interpretable: Prototypes provide an intuitive representation of each class as the "average example".

Reference: Snell, J., Swersky, K., & Zemel, R. (2017). Prototypical Networks for Few-shot Learning. Advances in Neural Information Processing Systems 30 (NIPS 2017).

Constructors

ProtoNetsAlgorithm(ProtoNetsOptions<T, TInput, TOutput>)

Initializes a new instance of the ProtoNetsAlgorithm class.

public ProtoNetsAlgorithm(ProtoNetsOptions<T, TInput, TOutput> options)

Parameters

options ProtoNetsOptions<T, TInput, TOutput>

The configuration options for ProtoNets.

Remarks

For Beginners: This creates a ProtoNets model ready for few-shot learning.

What ProtoNets needs (both come together in the construction sketch after the lists below):

  • MetaModel: a neural network that maps inputs to feature embeddings (e.g., a CNN for images)
  • DistanceFunction: how to measure similarity (Euclidean, cosine, etc.)

What happens during training:

  1. Sample episodes with N classes, K examples each
  2. Compute prototypes by averaging features
  3. Train encoder to make same-class features close
  4. Test on query set from same classes

What happens during testing:

  1. Get K examples of each new class
  2. Compute prototypes (no training needed!)
  3. Classify new examples by nearest prototype
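
A minimal construction sketch; everything except the documented MetaModel and DistanceFunction options (the encoder instance, the generic arguments, and the Euclidean value) is an illustrative assumption:

// Hypothetical setup: the exact shape of ProtoNetsOptions members is assumed.
var options = new ProtoNetsOptions<double, Matrix<double>, Vector<double>>
{
    MetaModel = featureEncoder,                     // documented: maps inputs to features
    DistanceFunction = DistanceFunctions.Euclidean  // assumed name for the metric choice
};

// Throws ArgumentNullException if options (or a required component) is null.
var protoNets = new ProtoNetsAlgorithm<double, Matrix<double>, Vector<double>>(options);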

Exceptions

ArgumentNullException

Thrown when options or required components are null.

ArgumentException

Thrown when configuration validation fails.

Properties

AlgorithmType

Gets the algorithm type identifier for this meta-learner.

public override MetaLearningAlgorithmType AlgorithmType { get; }

Property Value

MetaLearningAlgorithmType

Returns ProtoNets.

Methods

Adapt(IMetaLearningTask<T, TInput, TOutput>)

Adapts to a new task by computing class prototypes from the support set.

public override IModel<TInput, TOutput, ModelMetadata<T>> Adapt(IMetaLearningTask<T, TInput, TOutput> task)

Parameters

task IMetaLearningTask<T, TInput, TOutput>

The new task containing support set examples.

Returns

IModel<TInput, TOutput, ModelMetadata<T>>

A PrototypicalModel that classifies by nearest prototype.

Remarks

This is where ProtoNets shines: adaptation is instantaneous. Unlike MAML, which requires gradient descent steps, ProtoNets just computes prototypes from the support set.

Adaptation Process:

  1. Encode all support set examples using the trained feature encoder
  2. Group the embeddings by class
  3. Compute the prototype for each class (mean of the class embeddings)
  4. Return a model that classifies by distance to the prototypes

For Beginners: After meta-training, when you have a new task with labeled examples, call this method. The returned model can immediately classify new examples by finding the nearest class prototype - no additional training is needed, as the sketch below shows.
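
A hedged usage sketch; SampleNewTask, queryExample, and the Predict call on the returned model are assumptions about the surrounding API rather than documented members:

// A new few-shot task from your episodic data pipeline (hypothetical helper).
IMetaLearningTask<double, Matrix<double>, Vector<double>> newTask = SampleNewTask();

// Instant adaptation: prototypes are computed from the support set,
// with no gradient updates.
IModel<Matrix<double>, Vector<double>, ModelMetadata<double>> adapted = protoNets.Adapt(newTask);

// Classify a query by its nearest prototype (assumes IModel exposes Predict).
Vector<double> prediction = adapted.Predict(queryExample);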

Exceptions

ArgumentNullException

Thrown when task is null.

MetaTrain(TaskBatch<T, TInput, TOutput>)

Performs one meta-training step using ProtoNets' episodic training.

public override T MetaTrain(TaskBatch<T, TInput, TOutput> taskBatch)

Parameters

taskBatch TaskBatch<T, TInput, TOutput>

A batch of tasks to meta-train on.

Returns

T

The average meta-loss across all tasks in the batch.

Remarks

ProtoNets training is simpler than MAML's because there is no inner-loop gradient computation. For each task in the batch:

  1. Encode the support set examples to get feature embeddings
  2. Compute class prototypes (mean of each class's embeddings)
  3. Encode the query set examples
  4. Compute distances from query embeddings to prototypes
  5. Apply softmax to get class probabilities
  6. Compute the cross-entropy loss

Meta-update: average the losses across all tasks and backpropagate to update the feature encoder, as sketched below.
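
A sketch of the outer training loop; SampleTaskBatch is a hypothetical episodic sampler, not part of the documented API:

for (int step = 0; step < 10_000; step++)
{
    // Each batch holds several N-way K-shot episodes (sampler is hypothetical).
    TaskBatch<double, Matrix<double>, Vector<double>> batch = SampleTaskBatch();

    // One meta-training step: per task, compute prototypes from the support
    // set, evaluate the query loss, then update the feature encoder.
    double metaLoss = protoNets.MetaTrain(batch);

    if (step % 100 == 0)
        Console.WriteLine($"step {step}: meta-loss = {metaLoss}");
}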

Exceptions

ArgumentException

Thrown when the task batch is null or empty.