Class SSLSession<T>
- Namespace
- AiDotNet.SelfSupervisedLearning
- Assembly
- AiDotNet.dll
Manages a self-supervised learning training session.
public class SSLSession<T>
Type Parameters
- T
The numeric type used for computations.
- Inheritance
- object
SSLSession<T>
Remarks
For Beginners: An SSL session manages the entire training lifecycle: initialization, training loop, evaluation, and checkpointing. It provides callbacks for monitoring progress and supports resuming from checkpoints.
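The lifecycle described above can be sketched as follows. This is a minimal usage sketch, not the library's reference example: it relies only on members documented on this page, and the `SSLSessionExample` class, the `Run` helper, and its `method`, `loadBatches`, and `checkpointPath` arguments are hypothetical names introduced for illustration. The meaning of the event arguments (an epoch index and a per-epoch value) is also an assumption, since this page lists only their types.

```csharp
using System;
using System.Collections.Generic;
using AiDotNet.SelfSupervisedLearning;
// Also import the namespaces that define Tensor<T> and SSLResult<T> if they
// live outside AiDotNet.SelfSupervisedLearning; this page does not show them.

public static class SSLSessionExample
{
    // Hypothetical helper: the caller supplies the SSL method, a data loader,
    // and the path where the checkpoint should be written.
    public static SSLResult<T> Run<T>(
        ISSLMethod<T> method,
        Func<IEnumerable<Tensor<T>>> loadBatches,
        string checkpointPath)
    {
        // config is omitted, so the constructor's default (null) is used.
        var session = new SSLSession<T>(method);

        // Assumption: the int argument is the epoch index and the T argument
        // is a per-epoch value such as the loss.
        session.OnEpochStart += epoch => Console.WriteLine($"Epoch {epoch} started");
        session.OnEpochEnd += (epoch, value) => Console.WriteLine($"Epoch {epoch} ended: {value}");

        // Assumption: stopping is one reasonable reaction to collapse;
        // the library does not require it.
        session.OnCollapseDetected += _ => session.Stop();

        // Validation data and labels are optional and omitted here.
        SSLResult<T> result = session.Train(loadBatches);

        session.SaveCheckpoint(checkpointPath);
        return result;
    }
}
```

Passing everything in as parameters keeps the sketch independent of any particular SSL method or data pipeline.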
Constructors
SSLSession(ISSLMethod<T>, SSLConfig?)
Initializes a new SSL training session.
public SSLSession(ISSLMethod<T> method, SSLConfig? config = null)
Parameters
- method ISSLMethod<T>
The SSL method to use.
- config SSLConfig
Training configuration.
Properties
CurrentEpoch
Gets the current epoch number.
public int CurrentEpoch { get; }
Property Value
- int
GlobalStep
Gets the global step counter.
public int GlobalStep { get; }
Property Value
- int
IsDistributed
Gets whether this session is using distributed training.
public bool IsDistributed { get; }
Property Value
- bool
IsTraining
Gets whether training is in progress.
public bool IsTraining { get; }
Property Value
- bool
Method
Gets the SSL method being used.
public ISSLMethod<T> Method { get; }
Property Value
- ISSLMethod<T>
Rank
Gets the rank of this worker in distributed training.
public int Rank { get; }
Property Value
- int
WorldSize
Gets the world size (number of workers) for distributed training.
public int WorldSize { get; }
Property Value
- int
Methods
CacheTrainingFeaturesForKNN(IEnumerable<Tensor<T>>, int[])
Caches training data for k-NN evaluation.
public void CacheTrainingFeaturesForKNN(IEnumerable<Tensor<T>> trainingData, int[] trainingLabels)
Parameters
- trainingData IEnumerable<Tensor<T>>
Training data batches.
- trainingLabels int[]
Corresponding labels.
Evaluate(Tensor<T>)
Runs evaluation on the current encoder.
public SSLMetricReport<T> Evaluate(Tensor<T> data)
Parameters
- data Tensor<T>
Returns
- SSLMetricReport<T>
FromCheckpoint(string, INeuralNetwork<T>, Func<INeuralNetwork<T>, ISSLMethod<T>>)
Creates a session from a pretrained checkpoint.
public static SSLSession<T> FromCheckpoint(string checkpointPath, INeuralNetwork<T> encoder, Func<INeuralNetwork<T>, ISSLMethod<T>> methodFactory)
Parameters
- checkpointPath string
Path to the checkpoint file.
- encoder INeuralNetwork<T>
The encoder network to load weights into.
- methodFactory Func<INeuralNetwork<T>, ISSLMethod<T>>
Factory to create the SSL method from the encoder.
Returns
- SSLSession<T>
A session restored from the checkpoint.
GetEffectiveBatchSize()
Gets the effective batch size considering distributed training.
public int GetEffectiveBatchSize()
Returns
- int
Remarks
For distributed data-parallel (DDP) training, the effective batch size is local_batch_size * world_size. This matters when scaling the learning rate, because each optimizer step then reflects the samples from every worker.
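As a worked illustration of the remark above, take hypothetical numbers: a per-worker batch size of 128 on 8 workers. The second line applies the linear learning-rate scaling rule, a common heuristic that is assumed here for illustration; this page does not say that the library applies it. The symbols are introduced for this example only: B_local is the per-worker batch size, W is the world size, and eta_ref = 0.3 is an assumed learning rate tuned for a reference batch size B_ref = 256.

```latex
\begin{align}
B_{\text{eff}} &= B_{\text{local}} \times W = 128 \times 8 = 1024 \\
\eta &= \eta_{\text{ref}} \times \frac{B_{\text{eff}}}{B_{\text{ref}}} = 0.3 \times \frac{1024}{256} = 1.2
\end{align}
```

With a single worker (W = 1) the effective batch size equals the local batch size, so under this rule only the ratio to B_ref changes the learning rate.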
GetHistory()
Gets the current training history.
public SSLTrainingHistory<T> GetHistory()
Returns
- SSLTrainingHistory<T>
Reset()
Resets the session for a new training run.
public void Reset()
SaveCheckpoint(string)
Saves a checkpoint to disk.
public void SaveCheckpoint(string checkpointPath)
Parameters
- checkpointPath string
Path to save the checkpoint.
Stop()
Stops training gracefully.
public void Stop()
SynchronizeParameters()
Synchronizes model parameters across all distributed workers.
public void SynchronizeParameters()
Remarks
Call this method to ensure all workers have identical parameters. This is useful at initialization or after loading checkpoints.
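The "after loading checkpoints" case might look like the sketch below. It uses only members documented on this page; the `ResumeExample` class and `Resume` helper are hypothetical names, and whether `FromCheckpoint` already synchronizes parameters internally is not stated here, so the explicit call is an assumption made to be safe.

```csharp
using System;
using AiDotNet.SelfSupervisedLearning;
// Also import the namespace that defines INeuralNetwork<T>;
// this page does not show it.

public static class ResumeExample
{
    // Hypothetical helper that restores a session and aligns all workers.
    public static SSLSession<T> Resume<T>(
        string checkpointPath,
        INeuralNetwork<T> encoder,
        Func<INeuralNetwork<T>, ISSLMethod<T>> methodFactory)
    {
        SSLSession<T> session =
            SSLSession<T>.FromCheckpoint(checkpointPath, encoder, methodFactory);

        // Only meaningful when more than one worker takes part in training.
        if (session.IsDistributed)
        {
            session.SynchronizeParameters();
            Console.WriteLine(
                $"Worker {session.Rank} of {session.WorldSize} synchronized at epoch {session.CurrentEpoch}");
        }

        return session;
    }
}
```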
Train(Func<IEnumerable<Tensor<T>>>, Tensor<T>?, int[]?)
Trains the SSL method for the specified number of epochs.
public SSLResult<T> Train(Func<IEnumerable<Tensor<T>>> dataLoader, Tensor<T>? validationData = null, int[]? validationLabels = null)
Parameters
- dataLoader Func<IEnumerable<Tensor<T>>>
Function that yields batches of data.
- validationData Tensor<T>
Optional validation data for k-NN evaluation.
- validationLabels int[]
Optional validation labels.
Returns
- SSLResult<T>
Training result.
Events
OnCollapseDetected
Event raised when representation collapse is detected.
public event Action<int>? OnCollapseDetected
Event Type
- Action<int>
OnEpochEnd
Event raised at the end of each epoch.
public event Action<int, T>? OnEpochEnd
Event Type
- Action<int, T>
OnEpochStart
Event raised at the start of each epoch.
public event Action<int>? OnEpochStart
Event Type
- Action<int>
OnStepComplete
Event raised after each training step.
public event Action<int, SSLStepResult<T>>? OnStepComplete
Event Type
- Action<int, SSLStepResult<T>>