Class SSLSession<T>

Namespace
AiDotNet.SelfSupervisedLearning
Assembly
AiDotNet.dll

Manages a self-supervised learning (SSL) training session.

public class SSLSession<T>

Type Parameters

T

The numeric type used for computations.

Inheritance
SSLSession<T>
Inherited Members

Remarks

For Beginners: An SSL session manages the entire training lifecycle: initialization, training loop, evaluation, and checkpointing. It provides callbacks for monitoring progress and supports resuming from checkpoints.

Constructors

SSLSession(ISSLMethod<T>, SSLConfig?)

Initializes a new SSL training session.

public SSLSession(ISSLMethod<T> method, SSLConfig? config = null)

Parameters

method ISSLMethod<T>

The SSL method to use.

config SSLConfig

Training configuration. Optional; defaults to null.

Properties

CurrentEpoch

Gets the current epoch number.

public int CurrentEpoch { get; }

Property Value

int

GlobalStep

Gets the global step counter.

public int GlobalStep { get; }

Property Value

int

IsDistributed

Gets whether this session is using distributed training.

public bool IsDistributed { get; }

Property Value

bool

IsTraining

Gets whether training is in progress.

public bool IsTraining { get; }

Property Value

bool

Method

Gets the SSL method being used.

public ISSLMethod<T> Method { get; }

Property Value

ISSLMethod<T>

Rank

Gets the rank of this worker in distributed training.

public int Rank { get; }

Property Value

int

WorldSize

Gets the world size (number of workers) for distributed training.

public int WorldSize { get; }

Property Value

int

Methods

CacheTrainingFeaturesForKNN(IEnumerable<Tensor<T>>, int[])

Caches training data for k-NN evaluation.

public void CacheTrainingFeaturesForKNN(IEnumerable<Tensor<T>> trainingData, int[] trainingLabels)

Parameters

trainingData IEnumerable<Tensor<T>>

Training data batches.

trainingLabels int[]

Corresponding labels.

Evaluate(Tensor<T>)

Runs evaluation on the current encoder.

public SSLMetricReport<T> Evaluate(Tensor<T> data)

Parameters

data Tensor<T>

The data to evaluate the current encoder on.

Returns

SSLMetricReport<T>

FromCheckpoint(string, INeuralNetwork<T>, Func<INeuralNetwork<T>, ISSLMethod<T>>)

Creates a session from a pretrained checkpoint.

public static SSLSession<T> FromCheckpoint(string checkpointPath, INeuralNetwork<T> encoder, Func<INeuralNetwork<T>, ISSLMethod<T>> methodFactory)

Parameters

checkpointPath string

Path to the checkpoint file.

encoder INeuralNetwork<T>

The encoder network to load weights into.

methodFactory Func<INeuralNetwork<T>, ISSLMethod<T>>

Factory to create the SSL method from the encoder.

Returns

SSLSession<T>

A session restored from the checkpoint.
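A checkpoint round trip can be sketched as follows. This is a minimal illustration that uses only the signatures listed on this page; the helper name SaveAndRestore and its parameters are hypothetical, and the namespaces that define INeuralNetwork<T> and ISSLMethod<T> are assumptions, since this page does not show them.

```csharp
using System;
using AiDotNet.SelfSupervisedLearning;
// Assumption: also import the namespace(s) that define INeuralNetwork<T>
// (and ISSLMethod<T>, if it lives elsewhere); they are not shown on this page.

public static class CheckpointSketch
{
    // Hypothetical helper: saves the given session, then builds a new session
    // from the saved file using a caller-supplied encoder and method factory.
    public static SSLSession<float> SaveAndRestore(
        SSLSession<float> session,
        string checkpointPath,
        INeuralNetwork<float> encoder,
        Func<INeuralNetwork<float>, ISSLMethod<float>> methodFactory)
    {
        session.SaveCheckpoint(checkpointPath);

        SSLSession<float> restored =
            SSLSession<float>.FromCheckpoint(checkpointPath, encoder, methodFactory);

        // The SynchronizeParameters remarks describe the call as useful after
        // loading checkpoints, so it is made here for distributed sessions.
        if (restored.IsDistributed)
        {
            restored.SynchronizeParameters();
        }

        return restored;
    }
}
```

Whether the restored session also recovers CurrentEpoch and GlobalStep is not stated on this page, so the sketch makes no assumption about it.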

GetEffectiveBatchSize()

Gets the effective batch size considering distributed training.

public int GetEffectiveBatchSize()

Returns

int

Remarks

For distributed data parallel (DDP) training, the effective batch size is local_batch_size * world_size. This value matters when scaling the learning rate to the total number of samples processed per step.
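The arithmetic in the remark above can be illustrated with a small standalone sketch. It does not call the library: the class BatchScaling and both of its methods are hypothetical, and the linear learning-rate scaling rule shown is a common convention assumed here for illustration, not something this page says SSLSession<T> applies.

```csharp
using System;

public static class BatchScaling
{
    // Effective batch size under data-parallel training: each of the
    // worldSize workers processes localBatchSize samples per step.
    public static int EffectiveBatchSize(int localBatchSize, int worldSize)
    {
        if (localBatchSize <= 0)
            throw new ArgumentOutOfRangeException(nameof(localBatchSize));
        if (worldSize <= 0)
            throw new ArgumentOutOfRangeException(nameof(worldSize));

        return checked(localBatchSize * worldSize);
    }

    // Linear scaling rule (an assumption, not taken from this API):
    // scale the base learning rate by effectiveBatchSize / referenceBatchSize.
    public static double ScaleLearningRate(
        double baseLearningRate, int referenceBatchSize, int effectiveBatchSize)
    {
        if (referenceBatchSize <= 0)
            throw new ArgumentOutOfRangeException(nameof(referenceBatchSize));

        return baseLearningRate * effectiveBatchSize / referenceBatchSize;
    }
}
```

For example, a local batch size of 64 on 8 workers gives an effective batch size of 512; under the assumed linear rule, a base rate tuned for a batch of 256 would be doubled.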

GetHistory()

Gets the current training history.

public SSLTrainingHistory<T> GetHistory()

Returns

SSLTrainingHistory<T>

Reset()

Resets the session for a new training run.

public void Reset()

SaveCheckpoint(string)

Saves a checkpoint to disk.

public void SaveCheckpoint(string checkpointPath)

Parameters

checkpointPath string

Path to save the checkpoint.

Stop()

Stops training gracefully.

public void Stop()

SynchronizeParameters()

Synchronizes model parameters across all distributed workers.

public void SynchronizeParameters()

Remarks

Call this method to ensure all workers have identical parameters. This is useful at initialization or after loading checkpoints.

Train(Func<IEnumerable<Tensor<T>>>, Tensor<T>?, int[]?)

Trains the SSL method for the configured number of epochs.

public SSLResult<T> Train(Func<IEnumerable<Tensor<T>>> dataLoader, Tensor<T>? validationData = null, int[]? validationLabels = null)

Parameters

dataLoader Func<IEnumerable<Tensor<T>>>

Function that yields batches of data.

validationData Tensor<T>

Optional validation data for k-NN evaluation.

validationLabels int[]

Optional validation labels.

Returns

SSLResult<T>

Training result.
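A typical call sequence might look like the sketch below, which subscribes to the session events and then starts training. It relies only on the constructor, event, and method signatures listed on this page. The helper name TrainWithLogging is hypothetical, the namespace of Tensor<T> is an assumption, and the meaning of the event arguments is not documented here, so they are only printed.

```csharp
using System;
using System.Collections.Generic;
using AiDotNet.SelfSupervisedLearning;
// Assumption: also import the namespace that defines Tensor<T>;
// it is not shown on this page.

public static class SslSessionUsageSketch
{
    // Hypothetical helper: the caller supplies the SSL method and a data loader.
    public static SSLResult<float> TrainWithLogging(
        ISSLMethod<float> method,
        Func<IEnumerable<Tensor<float>>> dataLoader)
    {
        // config is omitted, so the default (null) is passed.
        var session = new SSLSession<float>(method);

        session.OnEpochStart += epoch =>
            Console.WriteLine($"Epoch {epoch} started");

        // The second argument has type T; its meaning is not documented here.
        session.OnEpochEnd += (epoch, value) =>
            Console.WriteLine($"Epoch {epoch} ended, reported value: {value}");

        // Assumption: calling Stop() from inside a handler is allowed.
        session.OnCollapseDetected += argument =>
        {
            Console.WriteLine($"Collapse detected (event argument: {argument})");
            session.Stop();
        };

        // Validation data and labels are optional and left out here.
        return session.Train(dataLoader);
    }
}
```

Stopping on collapse is one possible policy; a handler could instead just log the event and let training continue.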

Events

OnCollapseDetected

Event raised when representation collapse is detected.

public event Action<int>? OnCollapseDetected

Event Type

Action<int>

OnEpochEnd

Event raised at the end of each epoch.

public event Action<int, T>? OnEpochEnd

Event Type

Action<int, T>

OnEpochStart

Event raised at the start of each epoch.

public event Action<int>? OnEpochStart

Event Type

Action<int>

OnStepComplete

Event raised after each training step.

public event Action<int, SSLStepResult<T>>? OnStepComplete

Event Type

Action<int, SSLStepResult<T>>