Class RelationNetworkAlgorithm<T, TInput, TOutput>
Namespace: AiDotNet.MetaLearning.Algorithms
Assembly: AiDotNet.dll
Implementation of the Relation Networks algorithm for few-shot learning.
public class RelationNetworkAlgorithm<T, TInput, TOutput> : MetaLearnerBase<T, TInput, TOutput>, IMetaLearner<T, TInput, TOutput>
Type Parameters
T: The numeric type used for calculations (e.g., float, double).
TInput: The input data type (e.g., Matrix<T>, Tensor<T>).
TOutput: The output data type (e.g., Vector<T>, Tensor<T>).
Inheritance: MetaLearnerBase<T, TInput, TOutput> → RelationNetworkAlgorithm<T, TInput, TOutput>
Implements: IMetaLearner<T, TInput, TOutput>
Remarks
Relation Networks classify a query example by comparing it with the labeled support examples of each class, using a learned relation function that scores their similarity. Unlike metric-learning approaches that rely on a fixed distance function, Relation Networks learn the relation function end-to-end.
For Beginners: A Relation Network learns how to compare examples.
How it works:
- Encode all examples (support and query) with a feature encoder
- For each query, concatenate its features with the features of each support example
- Pass each concatenated pair through a relation module (a small neural network)
- The relation module outputs a similarity (relation) score for the pair
- Average the scores per class and apply softmax to get class probabilities
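The steps above can be sketched for a single query in plain Python. This is a toy under stated assumptions: `encode`, `relation_module`, and `classify` are hypothetical names, not AiDotNet APIs, and the weights are hand-set rather than trained. The relation module's hidden units compute `relu(a_k - b_k)` and `relu(b_k - a_k)`, so with these particular weights the MLP reproduces a negative L1 distance. That shows a fixed distance is one special case a relation module can represent, and training is free to move away from it.

```python
import math

def encode(x):
    # Hypothetical feature encoder phi: a fixed linear map standing in for a trained CNN.
    return [x[0] + x[1], x[0] - x[1]]

def relu(v):
    return max(0.0, v)

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def relation_module(combined, dim=2, bias=2.0):
    # Toy MLP over concat(phi(q), phi(s)). Each hidden pair computes
    # relu(a_k - b_k) and relu(b_k - a_k); their sum is |a_k - b_k|.
    # Weights are hand-set for illustration; training would learn them.
    a, b = combined[:dim], combined[dim:]
    hidden = []
    for k in range(dim):
        hidden.append(relu(a[k] - b[k]))
        hidden.append(relu(b[k] - a[k]))
    # Output layer: weight -1 on every hidden unit, then sigmoid -> score in (0, 1).
    return sigmoid(bias - sum(hidden))

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def classify(query, support):
    # support: dict mapping class label -> list of K raw examples.
    q = encode(query)
    labels = sorted(support)
    scores = []
    for c in labels:
        # List concatenation q + phi(s) is the "concatenate" step.
        pair_scores = [relation_module(q + encode(s)) for s in support[c]]
        scores.append(sum(pair_scores) / len(pair_scores))  # average over K shots
    return labels, softmax(scores)

support = {"cat": [[1.0, 0.0], [0.9, 0.1]], "dog": [[0.0, 1.0], [0.1, 0.9]]}
labels, probs = classify([0.95, 0.05], support)
print(dict(zip(labels, [round(p, 3) for p in probs])))
```

Because the scores lie in (0, 1), the softmax stays fairly soft. The query is still assigned most of the probability for the class whose support examples it most resembles.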
Key insight: Instead of using predefined distances (like Euclidean), it learns a neural network to measure "how related" two examples are.
Algorithm - Relation Networks:
# Learn two networks
feature_encoder = CNN()   # Maps x -> phi(x)
relation_module = MLP()   # Maps concatenate(phi(x_i), phi(x_j)) -> relation score
# Episode training
for each episode:
    # Sample an N-way K-shot task
    support_set = {K examples from each of N classes}
    query_set = {examples from the same N classes}
    # Encode all examples
    support_features = [feature_encoder(x) for x in support_set]
    query_features = [feature_encoder(x) for x in query_set]
    # Compute relation scores and accumulate the loss over all queries
    loss = 0
    for each query example q with true label y_q:
        scores = []
        for each class c:
            class_score = 0
            for each support example s in class c:
                # Concatenate the two embeddings and score the pair
                combined = concatenate(phi(q), phi(s))
                relation_score = relation_module(combined)
                class_score += relation_score
            scores.append(class_score / K)   # average over the K shots of class c
        probabilities = softmax(scores)
        loss += cross_entropy(probabilities, y_q)
    # Update both networks jointly
    backpropagate(loss)
    update(feature_encoder, relation_module)
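The outline above can be turned into a small runnable Python sketch. It is a toy built on several assumptions:

- A linear encoder (`phi`) stands in for the CNN.
- A one-hidden-layer tanh MLP with a sigmoid output stands in for the relation module.
- Central finite differences stand in for backpropagation.
- The same hypothetical 2-way 2-shot episode is reused at every step instead of sampling a new one.

All names (`episode_loss`, `numerical_grad`, `SUPPORT`, `QUERIES`) are illustrative, not AiDotNet APIs.

```python
import math
import random

# Hypothetical toy episode: 2-way 2-shot, 2-d inputs, two queries per class.
SUPPORT = {0: [[1.0, 0.2], [0.8, 0.0]], 1: [[0.0, 1.0], [0.2, 0.8]]}
QUERIES = [([0.9, 0.1], 0), ([1.0, 0.3], 0), ([0.1, 0.9], 1), ([0.3, 1.0], 1)]
D, H = 2, 4  # embedding size and relation-module hidden size (assumed)

def unpack(p):
    # Split one flat parameter list into encoder and relation-module weights.
    enc = [p[0:2], p[2:4]]                                # 2x2 encoder matrix
    w1 = [p[4 + 4 * h: 8 + 4 * h] for h in range(H)]      # H x (2*D) hidden layer
    w2 = p[4 + 4 * H: 4 + 5 * H]                          # hidden -> score
    b = p[4 + 5 * H]
    return enc, w1, w2, b

def phi(enc, x):
    # Feature encoder: a linear map standing in for the CNN.
    return [enc[i][0] * x[0] + enc[i][1] * x[1] for i in range(D)]

def relation(w1, w2, b, combined):
    # Relation module: tanh hidden layer, sigmoid output score in (0, 1).
    hidden = [math.tanh(sum(w * c for w, c in zip(row, combined))) for row in w1]
    z = b + sum(w * h for w, h in zip(w2, hidden))
    return 1.0 / (1.0 + math.exp(-z))

def episode_loss(p):
    enc, w1, w2, b = unpack(p)
    total = 0.0
    for x, label in QUERIES:
        q = phi(enc, x)
        scores = []
        for c in sorted(SUPPORT):
            pair = [relation(w1, w2, b, q + phi(enc, s)) for s in SUPPORT[c]]
            scores.append(sum(pair) / len(pair))          # average over K shots
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_z - scores[label]                    # -log softmax(scores)[label]
    return total / len(QUERIES)

def numerical_grad(f, p, eps=1e-5):
    # Central finite differences; a stand-in for backpropagation.
    grad = []
    for i in range(len(p)):
        up, down = list(p), list(p)
        up[i] += eps
        down[i] -= eps
        grad.append((f(up) - f(down)) / (2 * eps))
    return grad

rng = random.Random(0)
params = [rng.uniform(-0.5, 0.5) for _ in range(4 + 5 * H + 1)]
history = [episode_loss(params)]
for _ in range(40):
    g = numerical_grad(episode_loss, params)
    # One joint gradient step on encoder and relation-module weights.
    params = [w - 0.5 * gw for w, gw in zip(params, g)]
    history.append(episode_loss(params))
print(f"loss: {history[0]:.4f} -> {history[-1]:.4f}")
```

A single flat parameter list holds both the encoder weights and the relation-module weights. As a result, each gradient step updates the two networks jointly, which is the role of `update(feature_encoder, relation_module)` in the outline.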
Key Insights:
Learnable Relation Function: Instead of using a fixed distance, a neural network learns to measure similarity, so it can capture complex, non-linear relations.
End-to-End Training: The feature encoder and the relation module are trained jointly, optimizing directly for the final classification task.
Flexible Relations: The relation module can learn to attend to specific features, ignore noise, and detect subtle patterns.
Scalable Complexity: More expressive relation modules can handle more complex tasks. The cost is extra computation, since the module runs once for every query-support pair.
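To illustrate "ignore noise", the following sketch compares a fixed Euclidean distance with a relation function whose per-feature weights are hand-set to values that training could plausibly learn. This weighted score is a simplified stand-in for a trained relation module, which an MLP over the concatenated features can approximate. The embeddings and `LEARNED_W` are hypothetical. Feature 2 is noise that happens to be large for both the query and class A's support example.

```python
import math

# Hypothetical 3-feature embeddings: features 0-1 carry class information,
# feature 2 is noise that happens to be large for the query and for class A.
support = {"A": [1.0, 0.0, 2.5], "B": [0.0, 1.0, 0.0]}
query = [0.1, 0.9, 2.0]

def euclidean(a, b):
    # Fixed distance: every feature counts equally, including noise.
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

LEARNED_W = [1.0, 1.0, 0.0]  # hypothetical "learned" weights: ignore feature 2

def relation_score(a, b):
    # Higher means more related; features with weight 0 are ignored entirely.
    return math.exp(-sum(w * (x - y) ** 2 for w, x, y in zip(LEARNED_W, a, b)))

fixed_pick = min(support, key=lambda c: euclidean(query, support[c]))
learned_pick = max(support, key=lambda c: relation_score(query, support[c]))
print(fixed_pick, learned_pick)
```

The Euclidean distance is dominated by the shared noise value and favors A. The weighted relation ignores feature 2 and favors B, which is where the informative features point.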
Constructors
RelationNetworkAlgorithm(RelationNetworkOptions<T, TInput, TOutput>)
Initializes a new instance of the RelationNetworkAlgorithm class.
public RelationNetworkAlgorithm(RelationNetworkOptions<T, TInput, TOutput> options)
Parameters
options (RelationNetworkOptions<T, TInput, TOutput>): The configuration options for Relation Networks.
Exceptions
- ArgumentNullException
Thrown when options or required components are null.
- ArgumentException
Thrown when configuration validation fails.
Properties
AlgorithmType
Gets the type of meta-learning algorithm.
public override MetaLearningAlgorithmType AlgorithmType { get; }
Property Value
MetaLearningAlgorithmType
Methods
Adapt(IMetaLearningTask<T, TInput, TOutput>)
Adapts the model to a new task using its support set.
public override IModel<TInput, TOutput, ModelMetadata<T>> Adapt(IMetaLearningTask<T, TInput, TOutput> task)
Parameters
task (IMetaLearningTask<T, TInput, TOutput>): The task to adapt to.
Returns
- IModel<TInput, TOutput, ModelMetadata<T>>
A new model instance adapted to the task.
Remarks
For Beginners: This is where the "quick learning" happens. Given a new task with just a few examples (the support set), this method creates a new model that's specialized for that specific task.
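For a metric-based method such as Relation Networks, adaptation usually needs no gradient steps. The trained encoder and relation module are reused, and the new task's support set is simply embedded and kept for comparison. This page does not say whether AiDotNet's `Adapt` fine-tunes anything, so the following Python sketch assumes it does not. `adapt`, `AdaptedRelationModel`, `encode`, and `relation` are hypothetical names.

```python
import math

def encode(x):
    # Hypothetical trained feature encoder (frozen after meta-training).
    return [2.0 * v for v in x]

def relation(a, b):
    # Hypothetical trained relation function; returns a score in (0, 1].
    return math.exp(-sum((x - y) ** 2 for x, y in zip(a, b)))

class AdaptedRelationModel:
    """Illustrative result of adapting to one task: no weights change;
    the support set is embedded once and cached per class."""

    def __init__(self, support):
        self.class_features = {c: [encode(x) for x in xs] for c, xs in support.items()}

    def predict(self, x):
        q = encode(x)

        def class_score(c):
            feats = self.class_features[c]
            return sum(relation(q, f) for f in feats) / len(feats)

        return max(self.class_features, key=class_score)

def adapt(support):
    # Plays the role of Adapt(task): build a task-specific model from the support set.
    return AdaptedRelationModel(support)

model = adapt({"circle": [[0.0, 0.0], [0.1, 0.0]], "square": [[1.0, 1.0], [0.9, 1.0]]})
print(model.predict([0.05, 0.1]), model.predict([0.95, 0.9]))
```

Caching the support embeddings at adaptation time means each later prediction only encodes the query and runs the relation function against the stored features.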
MetaTrain(TaskBatch<T, TInput, TOutput>)
Performs one meta-training step on a batch of tasks.
public override T MetaTrain(TaskBatch<T, TInput, TOutput> taskBatch)
Parameters
taskBatch (TaskBatch<T, TInput, TOutput>): The batch of tasks to train on.
Returns
- T
The meta-training loss for this batch.
Remarks
For Beginners: This method updates the model by training on multiple tasks at once. Each task teaches the model something about how to learn quickly. The returned loss value indicates how well the model is doing; lower is better.
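"Lower is better" has a useful reference point, assuming the returned value is the mean cross-entropy from the algorithm outline (this page does not confirm the exact loss). A model that predicts uniformly over N classes has cross-entropy -log(1/N) = ln N, so a loss below ln N means the model beats random guessing. The helper names are hypothetical.

```python
import math

def chance_level_loss(n_way):
    # Cross-entropy of a uniform prediction over N classes: -log(1/N) = log(N).
    return math.log(n_way)

def better_than_chance(meta_loss, n_way):
    # True when the reported loss is below the uniform-guess baseline.
    return meta_loss < chance_level_loss(n_way)

print(round(chance_level_loss(5), 3))  # 1.609
```

For a 5-way task, a loss that stays near 1.609 suggests the model has not yet learned to discriminate between classes.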