Class TADAMAlgorithm<T, TInput, TOutput>
Namespace: AiDotNet.MetaLearning.Algorithms
Assembly: AiDotNet.dll
Implementation of Task-Dependent Adaptive Metric (TADAM) algorithm for few-shot learning.
public class TADAMAlgorithm<T, TInput, TOutput> : MetaLearnerBase<T, TInput, TOutput>, IMetaLearner<T, TInput, TOutput>
Type Parameters
T: The numeric type used for calculations (e.g., float, double).
TInput: The input data type (e.g., Matrix<T>, Tensor<T>).
TOutput: The output data type (e.g., Vector<T>, Tensor<T>).
Inheritance: MetaLearnerBase<T, TInput, TOutput> → TADAMAlgorithm<T, TInput, TOutput>
Implements: IMetaLearner<T, TInput, TOutput>
Remarks
TADAM extends Prototypical Networks by incorporating:
1. Task Conditioning (TC), using FiLM layers to modulate features
2. Metric Scaling, to learn per-dimension distance weights
3. Auxiliary Co-Training, for improved feature learning
For Beginners: TADAM improves on ProtoNets by making the feature extractor "aware" of the current task:
How it works:
- Extract features from support set examples
- Compute a "task embedding" summarizing what the task is about
- Use this task embedding to adjust (condition) how features are extracted
- Compute prototypes from the conditioned features
- Classify queries using scaled distances to prototypes
Key insight: Different tasks may require focusing on different features. TADAM learns to adjust what the network pays attention to based on the task.
Algorithm - TADAM:
```
# Components
f_theta = feature_encoder with FiLM layers
g_phi   = task_encoder that produces the task embedding
alpha   = learnable per-dimension metric scaling parameters
tau     = learnable temperature

# Episode training
for each episode:
    # 1. Compute the task embedding from the support set
    support_features = f_theta(support_set)    # initial (unconditioned) features
    task_embedding = g_phi(mean(support_features))

    # 2. Apply task conditioning via FiLM
    gamma, beta = FiLM_generator(task_embedding)
    conditioned_support = gamma * support_features + beta

    # 3. Compute class prototypes from the conditioned support features
    for each class c:
        prototype_c = mean(conditioned_support[labels == c])

    # 4. Classify queries with the scaled distance
    query_features = f_theta(query_set, task_conditioning=True)   # same gamma, beta
    for each query q:
        for each class c:
            dist[q, c] = sum(alpha * (query_features[q] - prototype_c)^2)
        p(y | q) = softmax over c of (-dist[q, c] / tau)

    # 5. Compute the loss and update
    loss = cross_entropy(p, true_labels)
    if use_auxiliary:
        loss += aux_weight * auxiliary_loss
    update(f_theta, g_phi, alpha, tau)
```
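The episode loop above can be sketched in C# for a single episode. This is a minimal sketch under stated assumptions, not the AiDotNet implementation: the names TadamEpisodeSketch and EpisodeLoss are hypothetical, the encoders f_theta and g_phi are assumed to have already produced plain double[] feature vectors (so step 1 is omitted), FiLM is applied as a per-dimension affine map on those vectors, and gamma, beta, alpha and tau are fixed inputs with no gradient update.

```csharp
using System;
using System.Linq;

// Hypothetical sketch of one TADAM episode (forward pass and loss only).
// Not the AiDotNet implementation: features are plain double[] vectors and
// gamma, beta, alpha, tau are fixed inputs rather than learned parameters.
public static class TadamEpisodeSketch
{
    // FiLM conditioning: per-dimension scale (gamma) and shift (beta).
    private static double[] Film(double[] x, double[] gamma, double[] beta) =>
        x.Select((v, d) => gamma[d] * v + beta[d]).ToArray();

    // Metric-scaled squared Euclidean distance: sum_d alpha[d] * (a[d] - b[d])^2.
    private static double ScaledDistance(double[] a, double[] b, double[] alpha) =>
        a.Select((v, d) => alpha[d] * (v - b[d]) * (v - b[d])).Sum();

    // Mean cross-entropy over the query set. Assumes every class 0..numClasses-1
    // has at least one support example.
    public static double EpisodeLoss(
        double[][] support, int[] supportLabels,
        double[][] query, int[] queryLabels, int numClasses,
        double[] gamma, double[] beta, double[] alpha, double tau)
    {
        // Steps 2-3: condition the support features, then average them per class.
        var conditioned = support.Select(s => Film(s, gamma, beta)).ToArray();
        int dim = conditioned[0].Length;
        var prototypes = Enumerable.Range(0, numClasses).Select(c =>
        {
            var members = conditioned.Where((_, i) => supportLabels[i] == c).ToArray();
            return Enumerable.Range(0, dim).Select(d => members.Average(m => m[d])).ToArray();
        }).ToArray();

        double total = 0.0;
        for (int q = 0; q < query.Length; q++)
        {
            // Step 4: condition the query the same way; logits are -dist / tau.
            var f = Film(query[q], gamma, beta);
            var logits = prototypes.Select(p => -ScaledDistance(f, p, alpha) / tau).ToArray();

            // Step 5: cross-entropy = log-sum-exp(logits) - logit of the true class.
            double max = logits.Max(); // shift by the max for numerical stability
            double logSumExp = max + Math.Log(logits.Sum(l => Math.Exp(l - max)));
            total += logSumExp - logits[queryLabels[q]];
        }
        return total / query.Length;
    }
}
```

With identity conditioning (gamma = 1, beta = 0) and uniform alpha, this reduces to the Prototypical Networks loss scaled by 1 / tau, which is a convenient sanity check.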
Key Innovations:
Task Conditioning (TC): FiLM layers modulate feature maps based on the task context. The scale (gamma) and shift (beta) parameters are generated from the task embedding.
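One way to picture the FiLM generator is as two linear maps from the task embedding to gamma and beta. The sketch is hypothetical (FilmGeneratorSketch is not an AiDotNet type), and parameterizing gamma as 1 + delta, so that zero weights leave the features unchanged, is an assumption made for illustration.

```csharp
using System;
using System.Linq;

// Hypothetical FiLM generator: two linear layers map a task embedding to
// per-dimension scale (gamma) and shift (beta). Not the AiDotNet API.
public sealed class FilmGeneratorSketch
{
    private readonly double[,] _wGamma; // shape [featureDim, embeddingDim]
    private readonly double[,] _wBeta;  // shape [featureDim, embeddingDim]

    public FilmGeneratorSketch(double[,] wGamma, double[,] wBeta)
    {
        _wGamma = wGamma;
        _wBeta = wBeta;
    }

    // Matrix-vector product w * e.
    private static double[] Linear(double[,] w, double[] e) =>
        Enumerable.Range(0, w.GetLength(0))
            .Select(i => Enumerable.Range(0, e.Length).Sum(j => w[i, j] * e[j]))
            .ToArray();

    // Assumed parameterization: gamma = 1 + W_gamma * e, beta = W_beta * e.
    public (double[] Gamma, double[] Beta) Generate(double[] taskEmbedding) =>
        (Linear(_wGamma, taskEmbedding).Select(v => 1.0 + v).ToArray(),
         Linear(_wBeta, taskEmbedding));

    // conditioned[d] = gamma[d] * features[d] + beta[d]
    public static double[] Apply(double[] features, double[] gamma, double[] beta) =>
        features.Select((v, d) => gamma[d] * v + beta[d]).ToArray();
}
```

With all-zero weights the generator returns gamma = 1 and beta = 0, so Apply leaves the features as they are; non-zero weights let each task rescale and shift individual feature dimensions.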
Metric Scaling: Learns per-dimension weights (alpha) for the distance metric, allowing the model to emphasize or de-emphasize different feature dimensions.
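The effect of metric scaling can be illustrated with a weighted squared Euclidean distance. In this hypothetical sketch (MetricScalingSketch is not part of AiDotNet), alpha multiplies each dimension's squared difference, so changing alpha can change which prototype is nearest.

```csharp
using System;
using System.Linq;

// Hypothetical illustration of metric scaling: alpha weights each feature
// dimension inside the squared Euclidean distance.
public static class MetricScalingSketch
{
    // sum_d alpha[d] * (x[d] - p[d])^2
    public static double ScaledSquaredDistance(double[] x, double[] p, double[] alpha) =>
        x.Select((v, d) => alpha[d] * (v - p[d]) * (v - p[d])).Sum();

    // Index of the prototype with the smallest scaled distance to x.
    public static int Nearest(double[] x, double[][] prototypes, double[] alpha) =>
        Enumerable.Range(0, prototypes.Length)
            .OrderBy(c => ScaledSquaredDistance(x, prototypes[c], alpha))
            .First();
}
```

For a query at (0, 0) with prototypes (3, 0) and (0, 2), uniform weights alpha = (1, 1) give distances 9 and 4, so the second prototype is nearest; with alpha = (1, 10) the distances become 9 and 40, and the first prototype wins.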
Learnable Temperature: The temperature tau controls softmax sharpness and is learned along with other parameters.
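The role of tau can be seen by turning the same distances into probabilities at different temperatures. This is a hypothetical helper (TemperatureSketch), not an AiDotNet API; it computes softmax(-d / tau) with the usual max-shift for numerical stability.

```csharp
using System;
using System.Linq;

// Hypothetical illustration: class probabilities from distances with temperature tau.
public static class TemperatureSketch
{
    public static double[] Softmax(double[] distances, double tau)
    {
        var logits = distances.Select(d => -d / tau).ToArray();
        double max = logits.Max(); // shift so the largest exponent is exp(0)
        var exps = logits.Select(l => Math.Exp(l - max)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }
}
```

For distances (1, 2), tau = 1 gives the closer class a probability of about 0.731, tau = 10 flattens it to about 0.525, and tau = 0.1 sharpens it to above 0.9999.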
Auxiliary Co-Training: Optional auxiliary classification loss on base classes to improve feature learning.
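Auxiliary co-training can be expressed as a weighted sum of the episodic loss and an ordinary classification loss over the base classes. The sketch assumes the auxiliary loss is the mean cross-entropy of base-class logits produced by some separate classification head; AuxiliaryLossSketch, its parameters, and that head are hypothetical, and only use_auxiliary and aux_weight correspond to names in the pseudocode.

```csharp
using System;
using System.Linq;

// Hypothetical sketch of auxiliary co-training: the episodic loss is combined
// with a standard cross-entropy over base-class logits, weighted by auxWeight.
public static class AuxiliaryLossSketch
{
    // Cross-entropy for one example given raw logits: logSumExp(z) - z[label].
    public static double CrossEntropy(double[] logits, int label)
    {
        double max = logits.Max();
        double logSumExp = max + Math.Log(logits.Sum(l => Math.Exp(l - max)));
        return logSumExp - logits[label];
    }

    // loss = episodeLoss + auxWeight * (mean auxiliary cross-entropy)
    public static double TotalLoss(
        double episodeLoss, double[][] auxLogits, int[] auxLabels,
        bool useAuxiliary, double auxWeight)
    {
        if (!useAuxiliary || auxLogits.Length == 0) return episodeLoss;
        double aux = auxLogits.Select((z, i) => CrossEntropy(z, auxLabels[i])).Average();
        return episodeLoss + auxWeight * aux;
    }
}
```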
Constructors
TADAMAlgorithm(TADAMOptions<T, TInput, TOutput>)
Initializes a new instance of the TADAMAlgorithm class.
public TADAMAlgorithm(TADAMOptions<T, TInput, TOutput> options)
Parameters
options (TADAMOptions<T, TInput, TOutput>): The configuration options for TADAM.
Exceptions
- ArgumentNullException
Thrown when options or required components are null.
- ArgumentException
Thrown when configuration validation fails.
Properties
AlgorithmType
Gets the type of meta-learning algorithm.
public override MetaLearningAlgorithmType AlgorithmType { get; }
Property Value
MetaLearningAlgorithmType
Methods
Adapt(IMetaLearningTask<T, TInput, TOutput>)
Adapts the model to a new task using its support set.
public override IModel<TInput, TOutput, ModelMetadata<T>> Adapt(IMetaLearningTask<T, TInput, TOutput> task)
Parameters
task (IMetaLearningTask<T, TInput, TOutput>): The task to adapt to.
Returns
- IModel<TInput, TOutput, ModelMetadata<T>>
A new model instance adapted to the task.
Remarks
For Beginners: This is where the "quick learning" happens. Given a new task with just a few examples (the support set), this method creates a new model that's specialized for that specific task.
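In a prototype-based method like TADAM, adaptation does not need to fine-tune weights: it can compute the class prototypes once from the support set and keep them for classifying queries. The sketch is hypothetical (PrototypeClassifierSketch is not the model type returned by Adapt), assumes the support features are already extracted and task-conditioned, and assumes every class has at least one support example.

```csharp
using System;
using System.Linq;

// Hypothetical adapted model for a prototype-based learner: "adaptation" only
// computes class prototypes from the support set; no weights are fine-tuned.
public sealed class PrototypeClassifierSketch
{
    private readonly double[][] _prototypes;
    private readonly double[] _alpha;

    private PrototypeClassifierSketch(double[][] prototypes, double[] alpha)
    {
        _prototypes = prototypes;
        _alpha = alpha;
    }

    // Average the support features of each class 0..numClasses-1.
    public static PrototypeClassifierSketch Adapt(
        double[][] supportFeatures, int[] labels, int numClasses, double[] alpha)
    {
        int dim = supportFeatures[0].Length;
        var prototypes = Enumerable.Range(0, numClasses).Select(c =>
        {
            var members = supportFeatures.Where((_, i) => labels[i] == c).ToArray();
            return Enumerable.Range(0, dim).Select(d => members.Average(m => m[d])).ToArray();
        }).ToArray();
        return new PrototypeClassifierSketch(prototypes, alpha);
    }

    // Predict the class whose prototype is nearest under the alpha-scaled metric.
    public int Predict(double[] queryFeatures) =>
        Enumerable.Range(0, _prototypes.Length)
            .OrderBy(c => _prototypes[c]
                .Select((p, d) => _alpha[d] * (queryFeatures[d] - p) * (queryFeatures[d] - p))
                .Sum())
            .First();
}
```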
MetaTrain(TaskBatch<T, TInput, TOutput>)
Performs one meta-training step on a batch of tasks.
public override T MetaTrain(TaskBatch<T, TInput, TOutput> taskBatch)
Parameters
taskBatch (TaskBatch<T, TInput, TOutput>): The batch of tasks to train on.
Returns
- T
The meta-training loss for this batch.
Remarks
For Beginners: This method updates the model by training on multiple tasks at once. Each task teaches the model something about how to learn quickly. The returned loss indicates how well the model is doing; lower is better.
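A meta-training step can be outlined as computing each task's episodic loss, combining the results, and updating the shared parameters. This sketch assumes the batch loss is the mean of the per-task losses and leaves the actual update to a callback; MetaTrainSketch, episodeLoss and applyUpdate are hypothetical names, not part of the AiDotNet API.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical outline of one meta-training step: compute each task's episodic
// loss, average over the batch, and hand the mean to an update callback that
// would adjust f_theta, g_phi, alpha and tau.
public static class MetaTrainSketch
{
    public static double MetaTrainStep<TTask>(
        IReadOnlyList<TTask> taskBatch,
        Func<TTask, double> episodeLoss,
        Action<double> applyUpdate)
    {
        if (taskBatch.Count == 0)
            throw new ArgumentException("Task batch must not be empty.", nameof(taskBatch));

        double meanLoss = taskBatch.Select(episodeLoss).Average(); // assumed: mean over tasks
        applyUpdate(meanLoss); // e.g. backpropagate and take an optimizer step
        return meanLoss;       // the value reported for this batch
    }
}
```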