
Class PrototypicalModel<T, TInput, TOutput>

Namespace
AiDotNet.MetaLearning.Algorithms
Assembly
AiDotNet.dll

Prototypical model for few-shot classification.

public class PrototypicalModel<T, TInput, TOutput> : IModel<TInput, TOutput, ModelMetadata<T>>

Type Parameters

T

The numeric type used for calculations.

TInput

The input data type.

TOutput

The output data type.

Inheritance
PrototypicalModel<T, TInput, TOutput>
Implements
IModel<TInput, TOutput, ModelMetadata<T>>

Remarks

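This model encapsulates the ProtoNets inference mechanism with pre-computed prototypes. It is returned by the ProtoNets Adapt(IMetaLearningTask<T, TInput, TOutput>) method and classifies inputs without any gradient computation. Prediction needs only a forward pass through the feature encoder and a distance comparison against the stored prototypes.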

For Beginners: After adapting ProtoNets to a new task, you get this model. It classifies new examples without any further training by encoding each one and finding the nearest class prototype.
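The nearest-prototype rule can be illustrated with a minimal, self-contained sketch. It is not AiDotNet's implementation. The class and method names (`NearestPrototypeSketch`, `SquaredDistance`, `Classify`) are hypothetical, embeddings are plain `double[]` vectors rather than `T`/`TInput`, and squared Euclidean distance is assumed as the metric.

```csharp
using System;

// Hypothetical sketch of nearest-prototype classification; NOT the AiDotNet API.
// Embeddings and prototypes are plain double[] vectors of equal length.
public static class NearestPrototypeSketch
{
    // Squared Euclidean distance between two equal-length vectors.
    public static double SquaredDistance(double[] a, double[] b)
    {
        if (a.Length != b.Length)
            throw new ArgumentException("Vectors must have the same length.");
        double sum = 0.0;
        for (int i = 0; i < a.Length; i++)
        {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    // Returns the index of the prototype closest to the query embedding.
    public static int Classify(double[] queryEmbedding, double[][] prototypes)
    {
        if (prototypes.Length == 0)
            throw new ArgumentException("At least one prototype is required.");
        int best = -1;
        double bestDistance = double.PositiveInfinity;
        for (int k = 0; k < prototypes.Length; k++)
        {
            double dist = SquaredDistance(queryEmbedding, prototypes[k]);
            if (dist < bestDistance)
            {
                bestDistance = dist;
                best = k;
            }
        }
        return best;
    }
}
```

For example, with prototypes at (0, 0) and (5, 5), the query (4, 4.5) is assigned to class 1, because its squared distance to (5, 5) is 1.25 versus 36.25 to (0, 0).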

Constructors

PrototypicalModel(IFullModel<T, TInput, TOutput>, TInput, TOutput, ProtoNetsOptions<T, TInput, TOutput>, INumericOperations<T>)

Initializes a new instance of the PrototypicalModel.

public PrototypicalModel(IFullModel<T, TInput, TOutput> featureEncoder, TInput supportInputs, TOutput supportOutputs, ProtoNetsOptions<T, TInput, TOutput> options, INumericOperations<T> numOps)

Parameters

featureEncoder IFullModel<T, TInput, TOutput>

The trained feature encoder.

supportInputs TInput

Support set inputs for computing prototypes.

supportOutputs TOutput

Support set outputs (labels).

options ProtoNetsOptions<T, TInput, TOutput>

ProtoNets configuration options.

numOps INumericOperations<T>

Numeric operations for type T.
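In the standard ProtoNets formulation, each class prototype is the mean of the encoded support examples of that class. The following minimal sketch shows that step under two assumptions: the support inputs have already been passed through the feature encoder, and labels are integer class indices. The helper name `PrototypeSketch.ComputePrototypes` is hypothetical, and this page does not document how the constructor actually converts `TInput`/`TOutput`.

```csharp
using System;

// Hypothetical sketch, not the AiDotNet API: one prototype per class,
// computed as the mean of that class's support embeddings.
public static class PrototypeSketch
{
    // embeddings[i] is the encoded support example i; labels[i] is its class
    // index in [0, numClasses). Encoding is assumed to have happened already.
    public static double[][] ComputePrototypes(double[][] embeddings, int[] labels, int numClasses)
    {
        if (embeddings.Length != labels.Length)
            throw new ArgumentException("Each embedding needs exactly one label.");
        if (embeddings.Length == 0)
            throw new ArgumentException("The support set must not be empty.");

        int dim = embeddings[0].Length;
        var prototypes = new double[numClasses][];
        var counts = new int[numClasses];
        for (int k = 0; k < numClasses; k++)
            prototypes[k] = new double[dim];

        // Accumulate the per-class sums of the embeddings.
        for (int i = 0; i < embeddings.Length; i++)
        {
            int label = labels[i];
            if (label < 0 || label >= numClasses)
                throw new ArgumentOutOfRangeException(nameof(labels), "Label out of range.");
            if (embeddings[i].Length != dim)
                throw new ArgumentException("All embeddings must have the same length.");
            for (int j = 0; j < dim; j++)
                prototypes[label][j] += embeddings[i][j];
            counts[label]++;
        }

        // Divide each sum by its class count to obtain the class mean.
        for (int k = 0; k < numClasses; k++)
        {
            if (counts[k] == 0)
                throw new InvalidOperationException($"Class {k} has no support examples.");
            for (int j = 0; j < dim; j++)
                prototypes[k][j] /= counts[k];
        }
        return prototypes;
    }
}
```

With support embeddings (1, 1) and (3, 3) labeled 0 and (10, 0) and (12, 2) labeled 1, the prototypes are (2, 2) and (11, 1). Failing loudly on an empty class is a design choice: a class without support examples has no meaningful prototype.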

Properties

Metadata

Gets the model metadata.

public ModelMetadata<T> Metadata { get; }

Property Value

ModelMetadata<T>

Methods

GetModelMetadata()

Gets metadata about the model.

public ModelMetadata<T> GetModelMetadata()

Returns

ModelMetadata<T>

Model metadata including information about prototypes.

GetParameters()

Gets model parameters (not applicable for prototype-based models).

public Vector<T> GetParameters()

Returns

Vector<T>

Predict(TInput)

Makes predictions using prototype-based classification.

public TOutput Predict(TInput input)

Parameters

input TInput

The input to classify.

Returns

TOutput

Predicted class probabilities.
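In the original ProtoNets formulation, class probabilities are a softmax over the negative squared Euclidean distances from the encoded query to each prototype. The sketch below assumes that formulation. This page does not state whether the model uses squared Euclidean distance, another metric, or a temperature configured through `ProtoNetsOptions<T, TInput, TOutput>`. The name `PrototypeProbabilitySketch.Probabilities` is hypothetical.

```csharp
using System;

// Hypothetical sketch, not the AiDotNet API: class probabilities as
// softmax(-squaredDistance(query, prototype_k)) over all prototypes k.
public static class PrototypeProbabilitySketch
{
    public static double[] Probabilities(double[] query, double[][] prototypes)
    {
        int numClasses = prototypes.Length;
        var logits = new double[numClasses];
        for (int k = 0; k < numClasses; k++)
        {
            if (prototypes[k].Length != query.Length)
                throw new ArgumentException("Query and prototypes must have the same length.");
            double sum = 0.0;
            for (int j = 0; j < query.Length; j++)
            {
                double d = query[j] - prototypes[k][j];
                sum += d * d;
            }
            logits[k] = -sum; // A closer prototype gives a larger logit.
        }

        // Numerically stable softmax: subtract the maximum logit before exponentiating.
        double max = double.NegativeInfinity;
        foreach (double logit in logits)
            max = Math.Max(max, logit);

        var probs = new double[numClasses];
        double total = 0.0;
        for (int k = 0; k < numClasses; k++)
        {
            probs[k] = Math.Exp(logits[k] - max);
            total += probs[k];
        }
        for (int k = 0; k < numClasses; k++)
            probs[k] /= total;
        return probs;
    }
}
```

With a query at (0, 0) and prototypes at (0, 0) and (1, 0), the logits are 0 and -1, so the probabilities are about 0.731 and 0.269. Subtracting the maximum logit does not change the result, but it keeps the exponentials from all underflowing to zero when every distance is large.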

Train(TInput, TOutput)

Trains the model (not applicable for prototype-based models).

public void Train(TInput inputs, TOutput targets)

Parameters

inputs TInput
targets TOutput

UpdateParameters(Vector<T>)

Updates model parameters (not applicable for prototype-based models).

public void UpdateParameters(Vector<T> parameters)

Parameters

parameters Vector<T>