Class PrototypicalModel<T, TInput, TOutput>
- Namespace: AiDotNet.MetaLearning.Algorithms
- Assembly: AiDotNet.dll
Prototypical model for few-shot classification.
public class PrototypicalModel<T, TInput, TOutput> : IModel<TInput, TOutput, ModelMetadata<T>>
Type Parameters
- T: The numeric type used for calculations.
- TInput: The input data type.
- TOutput: The output data type.
- Inheritance: PrototypicalModel<T, TInput, TOutput>
- Implements: IModel<TInput, TOutput, ModelMetadata<T>>
Remarks
This model encapsulates the ProtoNets inference mechanism with pre-computed prototypes. It is returned by Adapt(IMetaLearningTask<T, TInput, TOutput>) and provides fast classification without any gradient computation.
For Beginners: After adapting ProtoNets to a new task, you get this model. It can classify new examples instantly by finding the nearest class prototype.
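The nearest-prototype idea described above can be sketched in a few lines of plain C#. This is an illustration only, not the AiDotNet implementation: `PrototypeSketch` and its methods are hypothetical names, the embeddings are assumed to be plain `double[]` vectors already produced by a feature encoder, and squared Euclidean distance is an assumed choice of metric.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical helper (not part of AiDotNet): nearest-prototype classification
// over embeddings that are assumed to be already encoded as double[] vectors.
public static class PrototypeSketch
{
    // Builds one prototype per class by averaging that class's support embeddings.
    public static Dictionary<int, double[]> ComputePrototypes(double[][] embeddings, int[] labels)
    {
        if (embeddings.Length != labels.Length)
            throw new ArgumentException("Each embedding needs exactly one label.");

        return embeddings
            .Zip(labels, (e, l) => (Embedding: e, Label: l))
            .GroupBy(p => p.Label)
            .ToDictionary(
                g => g.Key,
                g => Enumerable.Range(0, g.First().Embedding.Length)
                    .Select(d => g.Average(p => p.Embedding[d]))
                    .ToArray());
    }

    // Squared Euclidean distance (an assumed metric for this sketch).
    public static double SquaredEuclidean(double[] a, double[] b)
        => a.Zip(b, (x, y) => (x - y) * (x - y)).Sum();

    // Returns the label of the prototype closest to the query embedding.
    public static int NearestClass(Dictionary<int, double[]> prototypes, double[] query)
        => prototypes.OrderBy(kv => SquaredEuclidean(kv.Value, query)).First().Key;
}
```

With support embeddings (0,0) and (2,0) labeled 0, and (10,10) and (12,10) labeled 1, the prototypes are (1,0) and (11,10), so `NearestClass` returns 0 for the query (1,1). No gradients are involved at this stage, which is why classification after adaptation is fast.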
Constructors
PrototypicalModel(IFullModel<T, TInput, TOutput>, TInput, TOutput, ProtoNetsOptions<T, TInput, TOutput>, INumericOperations<T>)
Initializes a new instance of the PrototypicalModel.
public PrototypicalModel(IFullModel<T, TInput, TOutput> featureEncoder, TInput supportInputs, TOutput supportOutputs, ProtoNetsOptions<T, TInput, TOutput> options, INumericOperations<T> numOps)
Parameters
- featureEncoder (IFullModel<T, TInput, TOutput>): The trained feature encoder.
- supportInputs (TInput): Support set inputs for computing prototypes.
- supportOutputs (TOutput): Support set outputs (labels).
- options (ProtoNetsOptions<T, TInput, TOutput>): ProtoNets configuration options.
- numOps (INumericOperations<T>): Numeric operations for type T.
Properties
Metadata
Gets the model metadata.
public ModelMetadata<T> Metadata { get; }
Property Value
- ModelMetadata<T>
Methods
GetModelMetadata()
Gets metadata about the model.
public ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
Model metadata including information about prototypes.
GetParameters()
Gets model parameters (not applicable for prototype-based models).
public Vector<T> GetParameters()
Returns
- Vector<T>
Predict(TInput)
Makes predictions using prototype-based classification.
public TOutput Predict(TInput input)
Parameters
- input (TInput): The input to classify.
Returns
- TOutput
Predicted class probabilities.
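One common way to turn prototype distances into class probabilities is a softmax over the negated distances. The sketch below assumes that approach; this page does not state how `Predict` actually computes its probabilities, and the library may apply a temperature or a different metric. `PrototypeProbabilitySketch` is a hypothetical name.

```csharp
using System;
using System.Linq;

// Hypothetical helper (not part of AiDotNet): converts per-class distances
// into probabilities with a softmax over negated distances.
public static class PrototypeProbabilitySketch
{
    public static double[] ToProbabilities(double[] distances)
    {
        // Shift by the smallest distance so the largest exponent is exp(0) = 1,
        // which avoids underflow when all distances are large.
        double min = distances.Min();
        double[] weights = distances.Select(d => Math.Exp(-(d - min))).ToArray();
        double total = weights.Sum();
        return weights.Select(w => w / total).ToArray();
    }
}
```

Smaller distances yield larger probabilities, equal distances yield equal probabilities, and the result always sums to 1.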
Train(TInput, TOutput)
Trains the model (not applicable for prototype-based models).
public void Train(TInput inputs, TOutput targets)
Parameters
- inputs (TInput)
- targets (TOutput)
UpdateParameters(Vector<T>)
Updates model parameters (not applicable for prototype-based models).
public void UpdateParameters(Vector<T> parameters)
Parameters
- parameters (Vector<T>)