Class SSLFineTuningPipeline<T>
- Namespace: AiDotNet.SelfSupervisedLearning
- Assembly: AiDotNet.dll
Pipeline for fine-tuning SSL pretrained encoders on downstream tasks.
public class SSLFineTuningPipeline<T>
Type Parameters
T: The numeric type used for computations.
- Inheritance: object → SSLFineTuningPipeline<T>
Remarks
For Beginners: After SSL pretraining, you typically fine-tune the encoder on a specific downstream task using labeled data. This pipeline handles the fine-tuning process, including learning rate scheduling and evaluation.
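Fine-tuning strategies:
- Full fine-tuning: update all parameters, in both the encoder and the classifier.
- Linear probing: freeze the encoder and train only the classifier.
- Gradual unfreezing: start with the encoder frozen and unfreeze its layers progressively as training proceeds.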
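To make the three strategies concrete, here is a minimal sketch of how each one might decide which encoder layers receive updates at a given epoch. It is an illustration under stated assumptions, not AiDotNet's implementation: the strategy names are plain strings, the helper TrainableEncoderLayers is hypothetical, and the gradual rule (unfreeze one more layer from the output side every epochsPerUnfreeze epochs, with the classifier always trainable) is an assumed schedule.

```csharp
using System;
using System.Linq;

// Hypothetical helper: which encoder layers get gradient updates at `epoch`.
// Layer 0 is closest to the input. The classifier head is assumed to always train.
bool[] TrainableEncoderLayers(string strategy, int numLayers, int epoch, int epochsPerUnfreeze)
{
    var trainable = new bool[numLayers];
    switch (strategy)
    {
        case "full":
            // Full fine-tuning: every encoder layer is updated.
            for (int i = 0; i < numLayers; i++) trainable[i] = true;
            break;
        case "linear":
            // Linear probing: the encoder stays frozen; only the head trains.
            break;
        case "gradual":
            // Assumed schedule: epochs 0..epochsPerUnfreeze-1 train only the head,
            // then one more layer (starting from the output side) is unfrozen
            // every epochsPerUnfreeze epochs.
            int unfrozen = Math.Min(numLayers, epoch / epochsPerUnfreeze);
            for (int i = numLayers - unfrozen; i < numLayers; i++) trainable[i] = true;
            break;
        default:
            throw new ArgumentException($"Unknown strategy: {strategy}");
    }
    return trainable;
}

for (int epoch = 0; epoch < 8; epoch += 2)
{
    var layers = TrainableEncoderLayers("gradual", numLayers: 4, epoch: epoch, epochsPerUnfreeze: 2);
    Console.WriteLine($"epoch {epoch}: trainable layers = [{string.Join(", ", layers.Select(t => t ? 1 : 0))}]");
}
```

With four layers and epochsPerUnfreeze = 2, nothing in the encoder is trainable at epoch 0, the top layer is unfrozen at epoch 2, the top two at epoch 4, and the top three at epoch 6.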
Constructors
SSLFineTuningPipeline(INeuralNetwork<T>, int, int)
Initializes a new fine-tuning pipeline.
public SSLFineTuningPipeline(INeuralNetwork<T> encoder, int encoderOutputDim, int numClasses)
Parameters
encoder (INeuralNetwork<T>): The pretrained encoder to fine-tune.
encoderOutputDim (int): The output dimension of the encoder.
numClasses (int): The number of classes for classification.
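The constructor takes the encoder's output width and the number of classes, which suggests the pipeline places a classification head on top of the encoder. The sketch below assumes that head is a single linear layer (a numClasses × encoderOutputDim weight matrix plus one bias per class). That architecture and the helper names HeadLogits and HeadParameterCount are assumptions for illustration, not documented AiDotNet internals.

```csharp
using System;

// Hypothetical linear classification head: logits = W * features + b,
// where W is [numClasses, encoderOutputDim] and b has numClasses entries.
double[] HeadLogits(double[,] weights, double[] bias, double[] features)
{
    int numClasses = weights.GetLength(0);
    int encoderOutputDim = weights.GetLength(1);
    if (features.Length != encoderOutputDim)
        throw new ArgumentException("Feature length must equal encoderOutputDim.");

    var result = new double[numClasses];
    for (int c = 0; c < numClasses; c++)
    {
        double sum = bias[c];
        for (int d = 0; d < encoderOutputDim; d++) sum += weights[c, d] * features[d];
        result[c] = sum;
    }
    return result;
}

// Trainable parameters in such a head: one weight per (class, feature) pair plus one bias per class.
int HeadParameterCount(int encoderOutputDim, int numClasses) =>
    encoderOutputDim * numClasses + numClasses;

var exampleWeights = new double[,] { { 1, 0, 0 }, { 0, 1, 1 } }; // numClasses = 2, encoderOutputDim = 3
var exampleBias = new double[] { 0.5, -0.5 };
var exampleLogits = HeadLogits(exampleWeights, exampleBias, new double[] { 2, 1, 3 });
Console.WriteLine($"logits = [{exampleLogits[0]}, {exampleLogits[1]}]");
Console.WriteLine($"head parameters for encoderOutputDim = 512, numClasses = 10: {HeadParameterCount(512, 10)}");
```

Assuming T = double, the documented constructor would be called as `new SSLFineTuningPipeline<double>(encoder, 512, 10)`, where `encoder` is your pretrained INeuralNetwork<double>. Under the linear-head assumption, that head holds 512 × 10 + 10 = 5,130 trainable parameters, which is all that linear probing updates.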
Methods
Evaluate(Tensor<T>, int[])
Evaluates the model on test data.
public double Evaluate(Tensor<T> testData, int[] testLabels)
Parameters
testData (Tensor<T>): Test data.
testLabels (int[]): Test labels.
Returns
- double
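Evaluate returns a double, but this page does not say which metric it computes. Because FineTune reports accuracy, classification accuracy is a plausible candidate. The sketch below computes accuracy as the fraction of samples whose highest-scoring class matches the label. It works on a plain double[,] of per-class scores instead of Tensor<T>, and the helpers ArgMax and Accuracy are hypothetical.

```csharp
using System;

// Index of the largest score in row `row` (ties resolve to the lowest index).
int ArgMax(double[,] scores, int row)
{
    int best = 0;
    for (int c = 1; c < scores.GetLength(1); c++)
        if (scores[row, c] > scores[row, best]) best = c;
    return best;
}

// Accuracy = correct predictions / total samples.
double Accuracy(double[,] scores, int[] labels)
{
    int n = scores.GetLength(0);
    if (labels.Length != n) throw new ArgumentException("One label per sample is required.");
    if (n == 0) throw new ArgumentException("At least one sample is required.");
    int correct = 0;
    for (int i = 0; i < n; i++)
        if (ArgMax(scores, i) == labels[i]) correct++;
    return (double)correct / n;
}

var testScores = new double[,]
{
    { 0.9, 0.1, 0.0 },  // predicts class 0
    { 0.2, 0.7, 0.1 },  // predicts class 1
    { 0.3, 0.3, 0.4 },  // predicts class 2
    { 0.6, 0.3, 0.1 },  // predicts class 0
};
var testLabels = new[] { 0, 1, 1, 0 };
Console.WriteLine($"accuracy = {Accuracy(testScores, testLabels)}");
```

For the four samples above, three predictions match their labels, so the sketch reports 0.75.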
FineTune(Tensor<T>, int[], Tensor<T>?, int[]?)
Fine-tunes the model on labeled data.
public FineTuningResult<T> FineTune(Tensor<T> trainData, int[] trainLabels, Tensor<T>? validData = null, int[]? validLabels = null)
Parameters
trainData (Tensor<T>): Training data.
trainLabels (int[]): Training labels.
validData (Tensor<T>?): Optional validation data. Defaults to null.
validLabels (int[]?): Optional validation labels. Defaults to null.
Returns
- FineTuningResult<T>
Fine-tuning result with accuracy.
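FineTune accepts the validation set as optional, nullable arguments. The following is a minimal sketch of that shape, assuming linear probing: the frozen encoder's outputs are already available as feature vectors, and a softmax-regression head is trained on them with plain stochastic gradient descent. The function name FineTuneLinearProbe, the learning rate, and the epoch count are hypothetical, and the documented method returns a FineTuningResult<T> rather than a tuple.

```csharp
using System;

// Trains a softmax-regression head on precomputed (frozen-encoder) features with SGD.
// Returns training accuracy and, only when validation data is supplied, validation accuracy.
(double TrainAccuracy, double? ValidAccuracy) FineTuneLinearProbe(
    double[][] trainX, int[] trainY, int numClasses,
    double[][]? validX = null, int[]? validY = null,
    int epochs = 200, double learningRate = 0.1)
{
    if ((validX is null) != (validY is null))
        throw new ArgumentException("Validation data and labels must be supplied together.");

    int dim = trainX[0].Length;
    var w = new double[numClasses, dim];
    var b = new double[numClasses];

    // Softmax over the head's logits (max subtracted for numerical stability).
    double[] Probabilities(double[] x)
    {
        var z = new double[numClasses];
        double max = double.NegativeInfinity;
        for (int c = 0; c < numClasses; c++)
        {
            z[c] = b[c];
            for (int d = 0; d < dim; d++) z[c] += w[c, d] * x[d];
            max = Math.Max(max, z[c]);
        }
        double sum = 0;
        for (int c = 0; c < numClasses; c++) { z[c] = Math.Exp(z[c] - max); sum += z[c]; }
        for (int c = 0; c < numClasses; c++) z[c] /= sum;
        return z;
    }

    double AccuracyOn(double[][] xs, int[] ys)
    {
        int correct = 0;
        for (int i = 0; i < xs.Length; i++)
        {
            var p = Probabilities(xs[i]);
            int pred = 0;
            for (int c = 1; c < numClasses; c++) if (p[c] > p[pred]) pred = c;
            if (pred == ys[i]) correct++;
        }
        return (double)correct / xs.Length;
    }

    for (int epoch = 0; epoch < epochs; epoch++)
    {
        for (int i = 0; i < trainX.Length; i++)
        {
            var p = Probabilities(trainX[i]);
            for (int c = 0; c < numClasses; c++)
            {
                // Gradient of cross-entropy with respect to logit c is p[c] - 1{c == label}.
                double g = p[c] - (c == trainY[i] ? 1.0 : 0.0);
                b[c] -= learningRate * g;
                for (int d = 0; d < dim; d++) w[c, d] -= learningRate * g * trainX[i][d];
            }
        }
    }

    double? valid = validX is null ? (double?)null : AccuracyOn(validX, validY!);
    return (AccuracyOn(trainX, trainY), valid);
}

var trainFeatures = new[]
{
    new[] { 2.0, 0.1 }, new[] { 1.8, -0.2 }, new[] { 2.2, 0.3 },    // class 0
    new[] { -2.0, 0.2 }, new[] { -1.7, -0.1 }, new[] { -2.1, 0.0 }, // class 1
};
var trainLabels = new[] { 0, 0, 0, 1, 1, 1 };
var validFeatures = new[] { new[] { 1.9, 0.0 }, new[] { -1.9, 0.1 } };
var validLabels = new[] { 0, 1 };

var withValid = FineTuneLinearProbe(trainFeatures, trainLabels, 2, validFeatures, validLabels);
var withoutValid = FineTuneLinearProbe(trainFeatures, trainLabels, 2);
Console.WriteLine($"train = {withValid.TrainAccuracy}, valid = {withValid.ValidAccuracy}");
Console.WriteLine($"validation accuracy present without validation data: {withoutValid.ValidAccuracy.HasValue}");
```

The sketch rejects a call that supplies only one of the validation pair; whether FineTune does the same is not stated on this page.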
WithConfig(Action<FineTuningConfig>)
Configures fine-tuning parameters.
public SSLFineTuningPipeline<T> WithConfig(Action<FineTuningConfig> configure)
Parameters
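configure (Action<FineTuningConfig>): A callback that receives the fine-tuning configuration and sets its parameters.
Returns
- SSLFineTuningPipeline<T>
The pipeline, so that configuration calls can be chained.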
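The members of FineTuningConfig are not listed on this page. The remarks mention learning rate schedules, so a schedule is one kind of setting a configuration like this plausibly controls. As an illustration only, the sketch below implements linear warmup followed by cosine decay; the parameter names baseLr, warmupSteps, and totalSteps are hypothetical and are not known FineTuningConfig properties.

```csharp
using System;

// Hypothetical schedule: linear warmup up to baseLr over the first warmupSteps steps,
// then cosine decay from baseLr down to 0 at totalSteps.
double LearningRateAt(int step, double baseLr, int warmupSteps, int totalSteps)
{
    if (step < warmupSteps)
        return baseLr * (step + 1) / warmupSteps;                // linear ramp
    if (step >= totalSteps)
        return 0.0;
    double progress = (double)(step - warmupSteps) / Math.Max(1, totalSteps - warmupSteps);
    return 0.5 * baseLr * (1.0 + Math.Cos(Math.PI * progress)); // cosine decay
}

for (int step = 0; step <= 100; step += 20)
    Console.WriteLine($"step {step,3}: lr = {LearningRateAt(step, 1e-3, 10, 100):E2}");
```

Whatever its actual members, a WithConfig call has the form `pipeline.WithConfig(config => { /* set FineTuningConfig properties here */ })`, and because it returns SSLFineTuningPipeline<T>, it can be chained with WithStrategy.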
WithStrategy(FineTuningStrategy)
Sets the fine-tuning strategy.
public SSLFineTuningPipeline<T> WithStrategy(FineTuningStrategy strategy)
Parameters
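strategy (FineTuningStrategy): The fine-tuning strategy to use.
Returns
- SSLFineTuningPipeline<T>
The pipeline, so that configuration calls can be chained.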
Events
OnProgress
Event raised for progress updates.
public event Action<int, int, double>? OnProgress
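OnProgress is declared as Action<int, int, double>?, but this page does not name its three arguments. The sketch below assumes they are (current epoch, total epochs, loss); that reading is an assumption. It uses a local delegate variable as a stand-in for the event to show the standard pattern: subscribers attach with +=, and the publisher raises the event with the null-conditional ?.Invoke, so nothing happens when no one is subscribed.

```csharp
using System;
using System.Collections.Generic;

// Stand-in for the event field; the argument meaning (epoch, totalEpochs, loss) is assumed.
Action<int, int, double>? onProgress = null;

var received = new List<string>();
onProgress += (current, total, value) =>
    received.Add($"epoch {current}/{total}, loss {value:F3}");

// Hypothetical training loop raising the event once per epoch.
int totalEpochs = 3;
for (int epoch = 1; epoch <= totalEpochs; epoch++)
{
    double loss = 1.0 / epoch; // placeholder value, not a real loss
    onProgress?.Invoke(epoch, totalEpochs, loss);
}

foreach (var line in received) Console.WriteLine(line);
```

Against the real pipeline, subscribing takes the same form: `pipeline.OnProgress += (a, b, c) => ...`.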