
Class SSLMethodBase<T>

Namespace
AiDotNet.SelfSupervisedLearning
Assembly
AiDotNet.dll

Abstract base class for self-supervised learning methods.

public abstract class SSLMethodBase<T> : ISSLMethod<T>

Type Parameters

T

The numeric type used for computations (typically float or double).

Inheritance
SSLMethodBase<T>
Implements
ISSLMethod<T>
Derived
Inherited Members

Remarks

For Beginners: This base class provides common functionality shared by all SSL methods, including parameter management, training mode control, and configuration handling.

Derived classes (SimCLR, MoCo, BYOL, etc.) implement the specific training logic in the TrainStepCore(Tensor<T>, SSLAugmentationContext<T>?) method.

Constructors

SSLMethodBase(INeuralNetwork<T>, IProjectorHead<T>?, SSLConfig?)

Initializes a new instance of the SSLMethodBase<T> class.

protected SSLMethodBase(INeuralNetwork<T> encoder, IProjectorHead<T>? projector, SSLConfig? config)

Parameters

encoder INeuralNetwork<T>

The encoder neural network.

projector IProjectorHead<T>

Optional projection head; may be null.

config SSLConfig

Optional SSL configuration; may be null.

Fields

NumOps

Numeric operations for generic type T.

protected static readonly INumericOperations<T> NumOps

Field Value

INumericOperations<T>

_config

The SSL configuration.

protected readonly SSLConfig _config

Field Value

SSLConfig

_currentEpoch

Current epoch counter.

protected int _currentEpoch

Field Value

int

_currentStep

Current training step counter.

protected int _currentStep

Field Value

int

_encoder

The main encoder neural network.

protected readonly INeuralNetwork<T> _encoder

Field Value

INeuralNetwork<T>

_isTraining

Whether the method is in training mode.

protected bool _isTraining

Field Value

bool

_projector

The projection head for SSL embeddings.

protected readonly IProjectorHead<T>? _projector

Field Value

IProjectorHead<T>

Properties

Category

Gets the category of this SSL method.

public abstract SSLMethodCategory Category { get; }

Property Value

SSLMethodCategory

Remarks

Categories include Contrastive, NonContrastive, Generative, and SelfDistillation.

Engine

Gets the global execution engine for vector operations and GPU/CPU acceleration.

protected IEngine Engine { get; }

Property Value

IEngine

Remarks

For Beginners: The engine handles hardware-accelerated computations. It automatically selects the best available hardware (GPU if available, otherwise CPU) for matrix operations, which speeds up SSL training.

Name

Gets the name of this SSL method.

public abstract string Name { get; }

Property Value

string

Remarks

Examples: "SimCLR", "MoCo v2", "BYOL", "DINO", "MAE"

ParameterCount

Gets the total number of trainable parameters.

public int ParameterCount { get; }

Property Value

int

RequiresMemoryBank

Indicates whether this method requires a memory bank for negative samples.

public abstract bool RequiresMemoryBank { get; }

Property Value

bool

Remarks

For Beginners: Memory banks store embeddings from previous batches to use as negative samples in contrastive learning. MoCo uses one; SimCLR does not.
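As a rough illustration of the memory bank idea, the sketch below keeps a fixed-size first-in, first-out queue of embeddings. `FifoMemoryBank` and its members are hypothetical names invented for this example, and plain `double[]` arrays stand in for `Tensor<T>`; it does not show how AiDotNet stores negatives.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical illustration only; not AiDotNet's memory bank implementation.
public sealed class FifoMemoryBank
{
    private readonly Queue<double[]> _items = new Queue<double[]>();
    private readonly int _capacity;

    public FifoMemoryBank(int capacity)
    {
        if (capacity <= 0) throw new ArgumentOutOfRangeException(nameof(capacity));
        _capacity = capacity;
    }

    public int Count => _items.Count;

    // Adds the embeddings from one batch, evicting the oldest entries when full.
    public void Enqueue(IEnumerable<double[]> batchEmbeddings)
    {
        foreach (var embedding in batchEmbeddings)
        {
            if (_items.Count == _capacity) _items.Dequeue();
            _items.Enqueue((double[])embedding.Clone());
        }
    }

    // Returns the stored embeddings, oldest first, for use as negative samples.
    public double[][] GetNegatives() => _items.ToArray();
}
```

A queue is used here because older embeddings come from older encoder weights, so evicting the oldest entries first keeps the stored negatives as fresh as possible.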

UsesMomentumEncoder

Indicates whether this method uses a momentum-updated encoder.

public abstract bool UsesMomentumEncoder { get; }

Property Value

bool

Remarks

For Beginners: A momentum encoder is a slowly-updated copy of the main encoder. Methods like MoCo, BYOL, and DINO use this to provide stable targets.
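The "slowly-updated copy" is usually maintained as an exponential moving average of the main encoder's weights. The sketch below shows that update on flat `double[]` parameter arrays; `MomentumUpdateSketch` is a hypothetical name, and the way AiDotNet applies the update to its own encoders is not shown here.

```csharp
using System;

// Hypothetical illustration of an exponential-moving-average (momentum) update.
public static class MomentumUpdateSketch
{
    // In place: target <- momentum * target + (1 - momentum) * online.
    public static void Update(double[] target, double[] online, double momentum)
    {
        if (target.Length != online.Length)
            throw new ArgumentException("Parameter vectors must have the same length.");
        if (momentum < 0.0 || momentum > 1.0)
            throw new ArgumentOutOfRangeException(nameof(momentum));

        for (int i = 0; i < target.Length; i++)
            target[i] = momentum * target[i] + (1.0 - momentum) * online[i];
    }
}
```

With a momentum close to 1, the target moves only a small fraction of the way toward the online weights on each step, which is what keeps its outputs stable.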

Methods

ComputePairwiseDistances(Tensor<T>)

Computes the pairwise squared distances between embeddings.

protected virtual Tensor<T> ComputePairwiseDistances(Tensor<T> embeddings)

Parameters

embeddings Tensor<T>

Embeddings [N, D].

Returns

Tensor<T>

Distance matrix [N, N].
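A minimal sketch of the [N, D] to [N, N] computation follows, assuming "squared distances" means squared Euclidean distance (the page does not state the metric). `PairwiseDistanceSketch` is a hypothetical name and a `double[,]` array stands in for `Tensor<T>`.

```csharp
using System;

// Hypothetical illustration; assumes squared Euclidean distance.
public static class PairwiseDistanceSketch
{
    // Input: x with shape [N, D]. Output: symmetric [N, N] matrix of squared distances.
    public static double[,] SquaredDistances(double[,] x)
    {
        int n = x.GetLength(0), d = x.GetLength(1);
        var result = new double[n, n]; // diagonal stays 0

        for (int i = 0; i < n; i++)
        {
            for (int j = i + 1; j < n; j++)
            {
                double sum = 0.0;
                for (int k = 0; k < d; k++)
                {
                    double diff = x[i, k] - x[j, k];
                    sum += diff * diff;
                }
                result[i, j] = sum;
                result[j, i] = sum; // distance is symmetric
            }
        }
        return result;
    }
}
```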

ComputeSimilarityMatrix(Tensor<T>, Tensor<T>, bool)

Computes the similarity matrix between two sets of embeddings.

protected virtual Tensor<T> ComputeSimilarityMatrix(Tensor<T> embeddings1, Tensor<T> embeddings2, bool normalize = true)

Parameters

embeddings1 Tensor<T>

First set of embeddings [N, D].

embeddings2 Tensor<T>

Second set of embeddings [M, D].

normalize bool

Whether to L2-normalize before computing similarity.

Returns

Tensor<T>

Similarity matrix [N, M].
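The shapes above can be sketched as a dot product between every row of the first set and every row of the second. The example assumes the similarity is a dot product (cosine similarity when `normalize` is true); `SimilarityMatrixSketch` is a hypothetical name and `double[,]` stands in for `Tensor<T>`.

```csharp
using System;

// Hypothetical illustration; assumes dot-product similarity.
public static class SimilarityMatrixSketch
{
    // a: [N, D], b: [M, D]. Returns [N, M] where entry (i, j) is the dot product of
    // row i of a and row j of b, with rows scaled to unit L2 length when normalize is true.
    public static double[,] Compute(double[,] a, double[,] b, bool normalize = true)
    {
        int n = a.GetLength(0), m = b.GetLength(0), d = a.GetLength(1);
        if (b.GetLength(1) != d)
            throw new ArgumentException("Embedding dimensions must match.");

        double[] scaleA = RowScales(a, normalize);
        double[] scaleB = RowScales(b, normalize);
        var result = new double[n, m];

        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < m; j++)
            {
                double dot = 0.0;
                for (int k = 0; k < d; k++) dot += a[i, k] * b[j, k];
                result[i, j] = dot * scaleA[i] * scaleB[j];
            }
        }
        return result;
    }

    // Per-row factor 1 / ||row|| (guarded against zero norms), or 1 when normalization is off.
    private static double[] RowScales(double[,] x, bool normalize)
    {
        int rows = x.GetLength(0), cols = x.GetLength(1);
        var scales = new double[rows];
        for (int i = 0; i < rows; i++)
        {
            if (!normalize) { scales[i] = 1.0; continue; }
            double sumSq = 0.0;
            for (int k = 0; k < cols; k++) sumSq += x[i, k] * x[i, k];
            scales[i] = 1.0 / Math.Max(Math.Sqrt(sumSq), 1e-12);
        }
        return scales;
    }
}
```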

CosineSimilarity(Tensor<T>, Tensor<T>)

Computes cosine similarity between two tensors.

protected virtual Tensor<T> CosineSimilarity(Tensor<T> a, Tensor<T> b)

Parameters

a Tensor<T>

First tensor [batch, dim].

b Tensor<T>

Second tensor [batch, dim].

Returns

Tensor<T>

Cosine similarity values [batch].
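Unlike the similarity matrix, the shapes here suggest a row-by-row comparison: row i of `a` against row i of `b`, giving one value per batch entry. The sketch below illustrates that reading on `double[,]` arrays; `CosineSimilaritySketch` is a hypothetical name and the zero-norm guard is an assumption of this example.

```csharp
using System;

// Hypothetical illustration of row-wise cosine similarity.
public static class CosineSimilaritySketch
{
    // a and b: [batch, dim]. Returns [batch], one cosine value per row pair.
    public static double[] RowWise(double[,] a, double[,] b)
    {
        int batch = a.GetLength(0), dim = a.GetLength(1);
        if (b.GetLength(0) != batch || b.GetLength(1) != dim)
            throw new ArgumentException("Inputs must have the same shape.");

        var result = new double[batch];
        for (int i = 0; i < batch; i++)
        {
            double dot = 0.0, sumSqA = 0.0, sumSqB = 0.0;
            for (int k = 0; k < dim; k++)
            {
                dot += a[i, k] * b[i, k];
                sumSqA += a[i, k] * a[i, k];
                sumSqB += b[i, k] * b[i, k];
            }
            // Guard against division by zero for all-zero rows.
            double denom = Math.Max(Math.Sqrt(sumSqA) * Math.Sqrt(sumSqB), 1e-12);
            result[i] = dot / denom;
        }
        return result;
    }
}
```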

CreateStepResult(T)

Creates a default step result with common metrics.

protected SSLStepResult<T> CreateStepResult(T loss)

Parameters

loss T

The loss value.

Returns

SSLStepResult<T>

A step result with populated common fields.

Encode(Tensor<T>)

Encodes input data into learned representations.

public virtual Tensor<T> Encode(Tensor<T> input)

Parameters

input Tensor<T>

The input tensor to encode.

Returns

Tensor<T>

The encoded representation tensor.

Remarks

For Beginners: After pretraining, use this to get representations for downstream tasks. The output embeddings can be used for classification, clustering, or similarity search.

EncodeAndProject(Tensor<T>)

Encodes input and projects it to the SSL embedding space.

protected virtual Tensor<T> EncodeAndProject(Tensor<T> input)

Parameters

input Tensor<T>

The input tensor.

Returns

Tensor<T>

The projected embedding.

GetAdditionalParameterCount()

Gets the number of additional parameters specific to this SSL method.

protected virtual int GetAdditionalParameterCount()

Returns

int

The number of additional parameters.

GetAdditionalParameters()

Gets additional parameters specific to this SSL method.

protected virtual Vector<T>? GetAdditionalParameters()

Returns

Vector<T>

Additional parameters, or null if none.

GetEffectiveLearningRate()

Gets the effective learning rate based on configuration and scheduling.

public virtual double GetEffectiveLearningRate()

Returns

double

The current learning rate.

GetEffectiveTemperature()

Gets the effective temperature based on configuration and scheduling.

protected virtual double GetEffectiveTemperature()

Returns

double

The current temperature value.

GetEncoder()

Gets the underlying encoder neural network.

public INeuralNetwork<T> GetEncoder()

Returns

INeuralNetwork<T>

The encoder network that produces representations.

Remarks

For Beginners: The encoder is the neural network that transforms raw inputs (like images) into learned representations. This is what you keep after pretraining.

GetParameters()

Gets the current parameters of the SSL method for serialization.

public virtual Vector<T> GetParameters()

Returns

Vector<T>

A vector containing all trainable parameters.

L2Normalize(Tensor<T>)

L2-normalizes a tensor along the last dimension.

protected virtual Tensor<T> L2Normalize(Tensor<T> tensor)

Parameters

tensor Tensor<T>

The tensor to normalize [batch, dim].

Returns

Tensor<T>

The normalized tensor.
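L2 normalization along the last dimension divides each row by its Euclidean length. A minimal sketch on a `double[,]` array follows; `L2NormalizeSketch` and the `epsilon` guard for all-zero rows are assumptions of this example, not documented behavior.

```csharp
using System;

// Hypothetical illustration of row-wise L2 normalization.
public static class L2NormalizeSketch
{
    // x: [batch, dim]. Returns a new array where each non-zero row has unit L2 norm.
    public static double[,] Normalize(double[,] x, double epsilon = 1e-12)
    {
        int batch = x.GetLength(0), dim = x.GetLength(1);
        var result = new double[batch, dim];

        for (int i = 0; i < batch; i++)
        {
            double sumSq = 0.0;
            for (int k = 0; k < dim; k++) sumSq += x[i, k] * x[i, k];

            // epsilon keeps all-zero rows at zero instead of producing NaN.
            double norm = Math.Max(Math.Sqrt(sumSq), epsilon);
            for (int k = 0; k < dim; k++) result[i, k] = x[i, k] / norm;
        }
        return result;
    }
}
```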

MatMul(Tensor<T>, Tensor<T>)

Computes matrix multiplication with engine-accelerated dot products.

protected virtual Tensor<T> MatMul(Tensor<T> a, Tensor<T> b)

Parameters

a Tensor<T>

First matrix [M, K].

b Tensor<T>

Second matrix [K, N].

Returns

Tensor<T>

Result matrix [M, N].

OnEpochEnd(int)

Signals the end of an epoch.

public virtual void OnEpochEnd(int epochNumber)

Parameters

epochNumber int

The completed epoch number.

OnEpochStart(int)

Signals the start of a new epoch.

public virtual void OnEpochStart(int epochNumber)

Parameters

epochNumber int

The current epoch number.

Reset()

Resets the SSL method to its initial state.

public virtual void Reset()

Remarks

This clears any accumulated state, such as memory banks and running statistics, and resets the momentum encoder if present.

SetAdditionalParameters(Vector<T>, ref int)

Sets additional parameters specific to this SSL method.

protected virtual void SetAdditionalParameters(Vector<T> parameters, ref int offset)

Parameters

parameters Vector<T>

The full parameter vector.

offset int

The current offset into the parameter vector, passed by reference.
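A `ref int offset` parameter is a common way to unpack several blocks from one flat vector: each reader consumes its slice and moves the offset forward for the next one. The sketch below shows that pattern under the assumption that overrides are expected to advance the offset, which this page does not confirm; `ParameterOffsetSketch` and `ReadBlock` are hypothetical names.

```csharp
using System;

// Hypothetical illustration of the ref-offset unpacking pattern.
public static class ParameterOffsetSketch
{
    // Copies destination.Length values from parameters, starting at offset,
    // then advances offset past the values it consumed.
    public static void ReadBlock(double[] parameters, ref int offset, double[] destination)
    {
        if (offset < 0 || offset + destination.Length > parameters.Length)
            throw new ArgumentException("Not enough values left in the parameter vector.");

        Array.Copy(parameters, offset, destination, 0, destination.Length);
        offset += destination.Length;
    }
}
```

Calling `ReadBlock` once per component, with the same `offset` variable, splits the vector in order without each component needing to know where the previous one ended.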

SetParameters(Vector<T>)

Sets the parameters of the SSL method from a serialized vector.

public virtual void SetParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

The parameter vector to load.

SetTrainingMode(bool)

Sets the training mode for the SSL method.

public virtual void SetTrainingMode(bool isTraining)

Parameters

isTraining bool

True for training mode, false for evaluation.

TrainStep(Tensor<T>, SSLAugmentationContext<T>?)

Performs a single training step on a batch of data.

public SSLStepResult<T> TrainStep(Tensor<T> batch, SSLAugmentationContext<T>? augmentationContext = null)

Parameters

batch Tensor<T>

The input batch tensor (e.g., images).

augmentationContext SSLAugmentationContext<T>

Optional augmentation context; the method may handle augmentation internally.

Returns

SSLStepResult<T>

The result of the training step including loss and metrics.

Remarks

For Beginners: This is the main training loop step. It:

  1. Creates augmented views of the input
  2. Passes views through the encoder
  3. Computes the SSL loss
  4. Updates model parameters

TrainStepCore(Tensor<T>, SSLAugmentationContext<T>?)

Implementation-specific training step logic.

protected abstract SSLStepResult<T> TrainStepCore(Tensor<T> batch, SSLAugmentationContext<T>? augmentationContext)

Parameters

batch Tensor<T>

The input batch tensor.

augmentationContext SSLAugmentationContext<T>

Optional augmentation context.

Returns

SSLStepResult<T>

The result of the training step.