Class RelationNetworkAlgorithm<T, TInput, TOutput>

Namespace
AiDotNet.MetaLearning.Algorithms
Assembly
AiDotNet.dll

Implementation of Relation Networks algorithm for few-shot learning.

public class RelationNetworkAlgorithm<T, TInput, TOutput> : MetaLearnerBase<T, TInput, TOutput>, IMetaLearner<T, TInput, TOutput>

Type Parameters

T

The numeric type used for calculations (e.g., float, double).

TInput

The input data type (e.g., Matrix<T>, Tensor<T>).

TOutput

The output data type (e.g., Vector<T>, Tensor<T>).

Inheritance
MetaLearnerBase<T, TInput, TOutput>
RelationNetworkAlgorithm<T, TInput, TOutput>
Implements
IMetaLearner<T, TInput, TOutput>

Remarks

Relation Networks learn to compare query examples with class examples by learning a relation function that measures similarity. Unlike metric learning approaches that use fixed distance functions, Relation Networks learn the relation function end-to-end.

For Beginners: Relation Networks learn how to compare examples.

How it works:

  1. Encode all examples (support and query) with a feature encoder
  2. For each query, concatenate with each support example's features
  3. Pass concatenated features through a relation module (neural network)
  4. The relation module outputs a similarity score
  5. Apply softmax to get class probabilities

Key insight: Instead of using predefined distances (like Euclidean), it learns a neural network to measure "how related" two examples are.

Algorithm - Relation Networks:

# Learn two networks
feature_encoder = CNN()        # Maps x -> phi(x)
relation_module = MLP()        # Maps [phi(x_i), phi(x_j)] -> similarity

# Episode training
for each episode:
    # Sample N-way K-shot task
    support_set = {examples_from_N_classes, K_examples_each}
    query_set = {examples_from_same_N_classes}

    # Encode all examples
    support_features = [feature_encoder(x) for x in support_set]
    query_features = [feature_encoder(x) for x in query_set]

    # Compute relation scores
    loss = 0
    for each query example q:
        scores = []
        for each class c:
            class_score = 0
            for each support example s in class c:
                # Concatenate and compute relation
                combined = concatenate(phi(q), phi(s))
                relation_score = relation_module(combined)
                class_score += relation_score
            scores.append(class_score / K)   # average over the K shots
        probabilities = softmax(scores)
        loss += cross_entropy(probabilities, true_label(q))

    # Update both networks
    backpropagate(loss)
    update(feature_encoder, relation_module)
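
To make steps 2-5 concrete in C#, here is a minimal, self-contained sketch using plain arrays rather than the AiDotNet tensor types. It scores one query against N classes with K support examples each; the RelationModule function is a hand-written placeholder standing in for the trained relation network, not part of the library.

// Minimal sketch, not the AiDotNet API: score one query against N classes
// with K support examples each, then turn the scores into probabilities.
using System;
using System.Linq;

static class RelationScoringSketch
{
    // Placeholder for the learned relation MLP: the score decays as the two
    // halves of the concatenated vector move apart. A real relation module
    // is a neural network whose weights are learned end-to-end.
    static float RelationModule(float[] combined)
    {
        int half = combined.Length / 2;
        float sq = 0f;
        for (int i = 0; i < half; i++)
        {
            float d = combined[i] - combined[i + half];
            sq += d * d;
        }
        return 1f / (1f + sq);
    }

    static float[] ConcatFeatures(float[] a, float[] b) => a.Concat(b).ToArray();

    static float[] Softmax(float[] scores)
    {
        float max = scores.Max();
        float[] exps = scores.Select(s => (float)Math.Exp(s - max)).ToArray();
        float sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }

    static void Main()
    {
        float[] queryFeatures = { 0.2f, 0.8f };

        // supportFeatures[c][k] = encoded k-th support example of class c (2-way, 2-shot).
        float[][][] supportFeatures =
        {
            new[] { new float[] { 0.1f, 0.9f }, new float[] { 0.3f, 0.7f } }, // class 0
            new[] { new float[] { 0.9f, 0.1f }, new float[] { 0.8f, 0.2f } }, // class 1
        };

        var classScores = new float[supportFeatures.Length];
        for (int c = 0; c < supportFeatures.Length; c++)
        {
            // Average the relation score over the K shots of class c (class_score / K).
            classScores[c] = supportFeatures[c]
                .Select(s => RelationModule(ConcatFeatures(queryFeatures, s)))
                .Average();
        }

        float[] probabilities = Softmax(classScores);
        Console.WriteLine(string.Join(", ", probabilities)); // highest probability on class 0
    }
}

Averaging the per-shot scores before the softmax mirrors the class_score / K step in the pseudocode above.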

Key Insights:

  1. Learnable Relation Function: Instead of a fixed distance, the algorithm learns a neural network to measure similarity, which can capture complex, non-linear relations (see the snippet after this list).

  2. End-to-End Training: Both feature encoder and relation module are trained jointly, optimizing for the final classification task.

  3. Flexible Relations: The relation module can learn to attend to specific features, ignore noise, and detect subtle patterns.

  4. Scalable Complexity: More powerful relation modules can handle more complex tasks at the cost of computation.
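
As a small illustration of insight 1, the snippet below contrasts a fixed metric with a learnable relation: the Euclidean score has no parameters, while the relation score is a parameterized function of the concatenated pair. In a real Relation Network those parameters are the weights of an MLP trained by backpropagation; the weights shown here are illustrative stand-ins.

// Illustrative contrast, not library code: fixed metric vs. learnable relation.
using System;
using System.Linq;

class FixedVsLearnedRelation
{
    static void Main()
    {
        float[] q = { 0.2f, 0.8f };   // encoded query
        float[] s = { 0.1f, 0.9f };   // encoded support example

        // Fixed metric: negative squared Euclidean distance, nothing to learn.
        float fixedScore = -q.Zip(s, (x, y) => (x - y) * (x - y)).Sum();

        // Learnable relation: a parameterized function of the concatenated pair.
        // In Relation Networks this is an MLP; these weights are stand-ins for
        // parameters that would be learned end-to-end.
        float[] weights = { 0.5f, -0.3f, 0.4f, 0.2f };
        float learnedScore = q.Concat(s).Zip(weights, (x, w) => x * w).Sum();

        Console.WriteLine($"fixed metric score:     {fixedScore}");
        Console.WriteLine($"learned relation score: {learnedScore}");
    }
}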

Constructors

RelationNetworkAlgorithm(RelationNetworkOptions<T, TInput, TOutput>)

Initializes a new instance of the RelationNetworkAlgorithm class.

public RelationNetworkAlgorithm(RelationNetworkOptions<T, TInput, TOutput> options)

Parameters

options RelationNetworkOptions<T, TInput, TOutput>

The configuration options for Relation Networks.

Exceptions

ArgumentNullException

Thrown when options or required components are null.

ArgumentException

Thrown when configuration validation fails.
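
Example

A minimal construction sketch, assuming RelationNetworkOptions<T, TInput, TOutput> can be created and populated before being passed in. Only the constructor signature shown above is taken from this page; the commented option members are hypothetical placeholders, not actual property names.

// Construction sketch. The option members below are hypothetical placeholders;
// consult RelationNetworkOptions<T, TInput, TOutput> for the real configuration surface.
var options = new RelationNetworkOptions<double, Matrix<double>, Vector<double>>
{
    // NWay = 5,                     // hypothetical: classes per episode
    // KShot = 1,                    // hypothetical: support examples per class
    // FeatureEncoder = myEncoder,   // hypothetical: required component
    // RelationModule = myRelation,  // hypothetical: required component
};

// Throws ArgumentNullException if options (or a required component) is null,
// and ArgumentException if configuration validation fails.
var relationNet = new RelationNetworkAlgorithm<double, Matrix<double>, Vector<double>>(options);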

Properties

AlgorithmType

Gets the type of meta-learning algorithm.

public override MetaLearningAlgorithmType AlgorithmType { get; }

Property Value

MetaLearningAlgorithmType

Methods

Adapt(IMetaLearningTask<T, TInput, TOutput>)

Adapts the model to a new task using its support set.

public override IModel<TInput, TOutput, ModelMetadata<T>> Adapt(IMetaLearningTask<T, TInput, TOutput> task)

Parameters

task IMetaLearningTask<T, TInput, TOutput>

The task to adapt to.

Returns

IModel<TInput, TOutput, ModelMetadata<T>>

A new model instance adapted to the task.

Remarks

For Beginners: This is where the "quick learning" happens. Given a new task with just a few examples (the support set), this method creates a new model that's specialized for that specific task.
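
Example

A hedged usage sketch: relationNet is an already constructed (and ideally meta-trained) instance, and the task object comes from wherever your episodes are sampled. Building IMetaLearningTask instances is outside this class's API, so the helper below is hypothetical.

// 'GetNextEvaluationTask' is a hypothetical helper that produces an
// IMetaLearningTask<double, Matrix<double>, Vector<double>> from your data;
// it is not part of AiDotNet's documented API.
IMetaLearningTask<double, Matrix<double>, Vector<double>> task = GetNextEvaluationTask();

// Adapt consumes the task's support set (the few labeled examples)
// and returns a model specialized to that task.
IModel<Matrix<double>, Vector<double>, ModelMetadata<double>> adapted = relationNet.Adapt(task);

// The adapted model can then be used to predict on the task's query examples.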

MetaTrain(TaskBatch<T, TInput, TOutput>)

Performs one meta-training step on a batch of tasks.

public override T MetaTrain(TaskBatch<T, TInput, TOutput> taskBatch)

Parameters

taskBatch TaskBatch<T, TInput, TOutput>

The batch of tasks to train on.

Returns

T

The meta-training loss for this batch.

Remarks

For Beginners: This method updates the model by training on multiple tasks at once. Each task teaches the model something about how to learn quickly. The returned loss value indicates how well the model is doing - lower is better.
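
Example

A hedged sketch of an outer training loop, assuming the caller provides a way to assemble task batches. SampleTaskBatch below is a hypothetical helper, not part of the documented API.

// 'SampleTaskBatch' is a hypothetical helper that assembles a
// TaskBatch<double, Matrix<double>, Vector<double>> of N-way K-shot tasks
// from your dataset; it is not part of RelationNetworkAlgorithm.
for (int episode = 0; episode < 10_000; episode++)
{
    TaskBatch<double, Matrix<double>, Vector<double>> batch = SampleTaskBatch();

    // One meta-training step: the feature encoder and relation module are
    // updated jointly; the return value is the meta-loss for this batch.
    double loss = relationNet.MetaTrain(batch);

    if (episode % 100 == 0)
    {
        Console.WriteLine($"episode {episode}: meta-loss = {loss}");
    }
}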