Class SSLPretrainingPipeline<T>
- Namespace: AiDotNet.SelfSupervisedLearning
- Assembly: AiDotNet.dll
High-level pipeline for SSL pretraining.
public class SSLPretrainingPipeline<T>
Type Parameters
T: The numeric type used for computations.
- Inheritance: object → SSLPretrainingPipeline<T>
Remarks
For Beginners: This pipeline provides a simple, high-level interface for SSL pretraining. Just provide your encoder and data, and it handles the rest: method selection, augmentation, training loop, and evaluation.
Example usage:
var pipeline = new SSLPretrainingPipeline<double>(encoder, encoderOutputDim)
    .WithMethod(SSLMethodType.SimCLR)
    .WithConfig(config => config.PretrainingEpochs = 100);
var result = pipeline.Train(dataLoader);
Constructors
SSLPretrainingPipeline(INeuralNetwork<T>, int)
Initializes a new SSL pretraining pipeline.
public SSLPretrainingPipeline(INeuralNetwork<T> encoder, int encoderOutputDim)
Parameters
encoder (INeuralNetwork<T>): The encoder network to pretrain.
encoderOutputDim (int): Output dimension of the encoder.
Methods
Train(Func<IEnumerable<Tensor<T>>>, Tensor<T>?, int[]?)
Trains the encoder using SSL.
public SSLResult<T> Train(Func<IEnumerable<Tensor<T>>> dataLoader, Tensor<T>? validationData = null, int[]? validationLabels = null)
Parameters
dataLoader (Func<IEnumerable<Tensor<T>>>): Function that yields batches of unlabeled data.
validationData (Tensor<T>): Optional validation data for monitoring.
validationLabels (int[]): Optional validation labels for k-NN evaluation.
Returns
- SSLResult<T>
Training result with pretrained encoder.
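The dataLoader parameter is a delegate that returns a sequence of batches, not the sequence itself. A minimal sketch of one way to build such a delegate from batches that are already in memory is shown below. BatchLoader and FromList are hypothetical helper names, not part of AiDotNet, and the helper is generic so it does not assume anything about how Tensor<T> is constructed. Whether the pipeline invokes the loader once per epoch is not stated on this page.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper, not part of AiDotNet.
public static class BatchLoader
{
    // Wraps pre-built batches in a loader delegate. Every call to the
    // returned delegate starts a fresh pass over the same batches.
    public static Func<IEnumerable<TBatch>> FromList<TBatch>(IReadOnlyList<TBatch> batches)
    {
        if (batches is null) throw new ArgumentNullException(nameof(batches));

        // Local iterator: yields the batches lazily, one at a time.
        IEnumerable<TBatch> Enumerate()
        {
            foreach (var batch in batches)
                yield return batch;
        }

        return Enumerate;
    }
}
```

Assuming batches is an IReadOnlyList<Tensor<double>> prepared elsewhere, the call would look like pipeline.Train(BatchLoader.FromList(batches)). The two validation arguments can be omitted because they default to null.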
WithConfig(Action<SSLConfig>)
Configures the training parameters.
public SSLPretrainingPipeline<T> WithConfig(Action<SSLConfig> configure)
Parameters
configure (Action<SSLConfig>)
Returns
- SSLPretrainingPipeline<T>
WithEncoderCopyFactory(Func<INeuralNetwork<T>, INeuralNetwork<T>>)
Sets the function to create encoder copies (for momentum methods).
public SSLPretrainingPipeline<T> WithEncoderCopyFactory(Func<INeuralNetwork<T>, INeuralNetwork<T>> createCopy)
Parameters
createCopy (Func<INeuralNetwork<T>, INeuralNetwork<T>>)
Returns
- SSLPretrainingPipeline<T>
WithMethod(SSLMethodType)
Sets the SSL method to use.
public SSLPretrainingPipeline<T> WithMethod(SSLMethodType method)
Parameters
method (SSLMethodType)
Returns
- SSLPretrainingPipeline<T>
Events
OnProgress
Event raised during training for progress updates.
public event Action<int, int, T>? OnProgress
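A handler for this event must match Action<int, int, T>. The sketch below builds one for T = double that writes each update to a TextWriter. ProgressReporter and To are hypothetical names, not part of AiDotNet. This page does not document what the three arguments mean, so the sketch assumes they are the current step, the total number of steps, and a metric value such as the loss.

```csharp
using System;
using System.Globalization;
using System.IO;

// Hypothetical helper, not part of AiDotNet.
public static class ProgressReporter
{
    // Builds a handler compatible with Action<int, int, double>.
    // Assumption: the arguments are (current step, total steps, metric value).
    public static Action<int, int, double> To(TextWriter output)
    {
        if (output is null) throw new ArgumentNullException(nameof(output));

        return (current, total, value) =>
        {
            // Invariant culture keeps the decimal separator stable across machines.
            string formatted = value.ToString("F4", CultureInfo.InvariantCulture);
            output.WriteLine($"[{current}/{total}] {formatted}");
        };
    }
}
```

For an SSLPretrainingPipeline<double>, the handler could be attached before training with pipeline.OnProgress += ProgressReporter.To(Console.Out);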