Class SSLMethodBase<T>
- Namespace: AiDotNet.SelfSupervisedLearning
- Assembly: AiDotNet.dll
Abstract base class for self-supervised learning methods.
public abstract class SSLMethodBase<T> : ISSLMethod<T>
Type Parameters
T: The numeric type used for computations (typically float or double).
- Inheritance: SSLMethodBase<T>
- Implements: ISSLMethod<T>
Remarks
For Beginners: This base class provides common functionality shared by all SSL methods, including parameter management, training mode control, and configuration handling.
Derived classes (SimCLR, MoCo, BYOL, etc.) implement the specific training logic in the TrainStepCore(Tensor<T>, SSLAugmentationContext<T>?) method.
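The split between a public TrainStep and an abstract TrainStepCore follows the template-method pattern: shared bookkeeping lives in the base class, and each derived method supplies only its core logic. The sketch below is a hypothetical, simplified mirror of that structure. It uses double[] instead of Tensor<T> and a plain double loss instead of SSLStepResult<T>, and the class names (SketchSSLMethodBase, ToyMethod) and the step counter are illustrative assumptions, not the AiDotNet implementation.

```csharp
using System;

// Hypothetical, simplified mirror of the base-class pattern (NOT the AiDotNet API).
public abstract class SketchSSLMethodBase
{
    // Shared state owned by the base class (assumed for illustration).
    protected int CurrentStep;

    public int StepsTaken => CurrentStep;

    // Public entry point: validation and bookkeeping live here, once.
    public double TrainStep(double[] batch)
    {
        if (batch == null) throw new ArgumentNullException(nameof(batch));
        double loss = TrainStepCore(batch);
        CurrentStep++;
        return loss;
    }

    // Each derived method (SimCLR, BYOL, ...) supplies only this part.
    protected abstract double TrainStepCore(double[] batch);
}

// Toy derived class: its "loss" is just the mean squared value of the batch.
public sealed class ToyMethod : SketchSSLMethodBase
{
    protected override double TrainStepCore(double[] batch)
    {
        double sum = 0;
        foreach (double x in batch) sum += x * x;
        return batch.Length == 0 ? 0 : sum / batch.Length;
    }
}
```

Keeping TrainStep non-virtual means derived classes cannot accidentally skip the shared bookkeeping.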
Constructors
SSLMethodBase(INeuralNetwork<T>, IProjectorHead<T>?, SSLConfig?)
Initializes a new instance of the SSLMethodBase class.
protected SSLMethodBase(INeuralNetwork<T> encoder, IProjectorHead<T>? projector, SSLConfig? config)
Parameters
encoder (INeuralNetwork<T>): The encoder neural network.
projector (IProjectorHead<T>?): Optional projection head; may be null.
config (SSLConfig?): Optional SSL configuration; may be null.
Fields
NumOps
Numeric operations for generic type T.
protected static readonly INumericOperations<T> NumOps
Field Value
- INumericOperations<T>
_config
The SSL configuration.
protected readonly SSLConfig _config
Field Value
- SSLConfig
_currentEpoch
Current epoch counter.
protected int _currentEpoch
Field Value
- int
_currentStep
Current training step counter.
protected int _currentStep
Field Value
- int
_encoder
The main encoder neural network.
protected readonly INeuralNetwork<T> _encoder
Field Value
- INeuralNetwork<T>
_isTraining
Whether the method is in training mode.
protected bool _isTraining
Field Value
- bool
_projector
The projection head for SSL embeddings.
protected readonly IProjectorHead<T>? _projector
Field Value
- IProjectorHead<T>
Properties
Category
Gets the category of this SSL method.
public abstract SSLMethodCategory Category { get; }
Property Value
- SSLMethodCategory
Remarks
Categories include Contrastive, NonContrastive, Generative, and SelfDistillation.
Engine
Gets the global execution engine for vector operations and GPU/CPU acceleration.
protected IEngine Engine { get; }
Property Value
- IEngine
Remarks
For Beginners: The engine handles hardware-accelerated computations. It automatically selects the best available hardware (GPU if available, otherwise CPU) for matrix operations, which can make SSL training substantially faster when a GPU is present.
Name
Gets the name of this SSL method.
public abstract string Name { get; }
Property Value
- string
Remarks
Examples: "SimCLR", "MoCo v2", "BYOL", "DINO", "MAE"
ParameterCount
Gets the total number of trainable parameters.
public int ParameterCount { get; }
Property Value
- int
RequiresMemoryBank
Indicates whether this method requires a memory bank for negative samples.
public abstract bool RequiresMemoryBank { get; }
Property Value
- bool
Remarks
For Beginners: Memory banks store embeddings from previous batches so they can be used as negative samples in contrastive learning. MoCo uses a memory bank; SimCLR does not.
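A MoCo-style memory bank is usually a fixed-capacity FIFO queue: each batch's embeddings are enqueued and the oldest ones are evicted. The sketch below assumes embeddings are stored as double[] rows; the class name MemoryBankSketch and its methods are hypothetical and are not part of AiDotNet.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical fixed-capacity FIFO memory bank (NOT the AiDotNet API).
public sealed class MemoryBankSketch
{
    private readonly Queue<double[]> _queue = new Queue<double[]>();
    private readonly int _capacity;

    public MemoryBankSketch(int capacity)
    {
        if (capacity <= 0) throw new ArgumentOutOfRangeException(nameof(capacity));
        _capacity = capacity;
    }

    public int Count => _queue.Count;

    // Add one embedding per sample; evict the oldest once capacity is exceeded.
    public void Enqueue(IEnumerable<double[]> batch)
    {
        foreach (double[] embedding in batch)
        {
            _queue.Enqueue((double[])embedding.Clone());
            if (_queue.Count > _capacity) _queue.Dequeue();
        }
    }

    // Snapshot of stored embeddings to use as negatives, oldest first.
    public double[][] GetNegatives() => _queue.ToArray();

    // Analogous to the memory-bank clearing that Reset() is documented to perform.
    public void Clear() => _queue.Clear();
}
```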
UsesMomentumEncoder
Indicates whether this method uses a momentum-updated encoder.
public abstract bool UsesMomentumEncoder { get; }
Property Value
- bool
Remarks
For Beginners: A momentum encoder is a slowly-updated copy of the main encoder. Methods like MoCo, BYOL, and DINO use this to provide stable targets.
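The "slowly-updated copy" is typically maintained with an exponential moving average (EMA) of the online encoder's weights: target = m * target + (1 - m) * online, with m close to 1. The sketch below assumes both encoders' parameters are flattened into double[] arrays; the class name is hypothetical, and how AiDotNet stores and schedules the momentum value is not shown here.

```csharp
using System;

// Hypothetical EMA update for a momentum (target) encoder (NOT the AiDotNet API).
public static class MomentumUpdateSketch
{
    // target[i] = momentum * target[i] + (1 - momentum) * online[i]
    public static void Update(double[] target, double[] online, double momentum)
    {
        if (target.Length != online.Length)
            throw new ArgumentException("Parameter vectors must have the same length.");
        if (momentum < 0 || momentum > 1)
            throw new ArgumentOutOfRangeException(nameof(momentum));

        for (int i = 0; i < target.Length; i++)
            target[i] = momentum * target[i] + (1 - momentum) * online[i];
    }
}
```

With momentum near 1 (for example 0.99), the target drifts only slightly per step, which is what makes its outputs a stable training target.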
Methods
ComputePairwiseDistances(Tensor<T>)
Computes the pairwise squared distances between embeddings.
protected virtual Tensor<T> ComputePairwiseDistances(Tensor<T> embeddings)
Parameters
embeddings (Tensor<T>): Embeddings [N, D].
Returns
- Tensor<T>
Distance matrix [N, N].
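Squared Euclidean distances can be computed from dot products using the identity ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 (x_i . x_j). The sketch below applies that identity to jagged double arrays with one embedding per row; whether ComputePairwiseDistances uses this exact formulation is an assumption, and the class name is hypothetical.

```csharp
using System;

// Hypothetical pairwise squared-distance sketch over rows of X [N, D].
public static class PairwiseDistanceSketch
{
    public static double[,] SquaredDistances(double[][] x)
    {
        int n = x.Length;

        // Precompute squared norms ||x_i||^2 once.
        var norms = new double[n];
        for (int i = 0; i < n; i++) norms[i] = Dot(x[i], x[i]);

        var d = new double[n, n];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
            {
                double v = norms[i] + norms[j] - 2 * Dot(x[i], x[j]);
                // Floating-point rounding can yield tiny negatives; clamp to zero.
                d[i, j] = Math.Max(0.0, v);
            }
        return d;
    }

    private static double Dot(double[] a, double[] b)
    {
        double s = 0;
        for (int k = 0; k < a.Length; k++) s += a[k] * b[k];
        return s;
    }
}
```

The identity turns the N x N computation into dot products, which is what makes it amenable to a single matrix multiplication on accelerated hardware.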
ComputeSimilarityMatrix(Tensor<T>, Tensor<T>, bool)
Computes the similarity matrix between two sets of embeddings.
protected virtual Tensor<T> ComputeSimilarityMatrix(Tensor<T> embeddings1, Tensor<T> embeddings2, bool normalize = true)
Parameters
embeddings1 (Tensor<T>): First set of embeddings [N, D].
embeddings2 (Tensor<T>): Second set of embeddings [M, D].
normalize (bool): Whether to L2-normalize each embedding before computing similarity (default: true).
Returns
- Tensor<T>
Similarity matrix [N, M].
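With normalize set to true, each entry S[i, j] is a dot product of unit vectors, that is, a cosine similarity in [-1, 1]; with normalize set to false it is a raw dot product. The sketch below shows that computation on jagged double arrays. The class name and the 1e-12 epsilon are assumptions for illustration, not values taken from AiDotNet.

```csharp
using System;

// Hypothetical similarity-matrix sketch: S = A * B^T, optionally on unit rows.
public static class SimilarityMatrixSketch
{
    public static double[,] Compute(double[][] a, double[][] b, bool normalize = true)
    {
        if (normalize)
        {
            // New arrays are produced; the caller's inputs are not modified.
            a = Array.ConvertAll(a, Normalize);
            b = Array.ConvertAll(b, Normalize);
        }

        var s = new double[a.Length, b.Length];
        for (int i = 0; i < a.Length; i++)
            for (int j = 0; j < b.Length; j++)
            {
                double dot = 0;
                for (int k = 0; k < a[i].Length; k++) dot += a[i][k] * b[j][k];
                s[i, j] = dot;
            }
        return s;
    }

    private static double[] Normalize(double[] v)
    {
        double sumSq = 0;
        foreach (double x in v) sumSq += x * x;
        double norm = Math.Max(Math.Sqrt(sumSq), 1e-12); // epsilon avoids division by zero
        return Array.ConvertAll(v, x => x / norm);
    }
}
```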
CosineSimilarity(Tensor<T>, Tensor<T>)
Computes the row-wise cosine similarity between two tensors.
protected virtual Tensor<T> CosineSimilarity(Tensor<T> a, Tensor<T> b)
Parameters
a (Tensor<T>): First tensor [batch, dim].
b (Tensor<T>): Second tensor [batch, dim].
Returns
- Tensor<T>
Cosine similarity values [batch].
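Because both inputs are [batch, dim] and the result is [batch], element i compares row i of a with row i of b: cos_i = (a_i . b_i) / (||a_i|| ||b_i||). The sketch below computes that on jagged double arrays; the class name and the epsilon guard are assumptions for illustration.

```csharp
using System;

// Hypothetical row-wise cosine similarity sketch (NOT the AiDotNet API).
public static class CosineSimilaritySketch
{
    public static double[] RowWise(double[][] a, double[][] b, double eps = 1e-12)
    {
        if (a.Length != b.Length) throw new ArgumentException("Batch sizes differ.");

        var result = new double[a.Length];
        for (int i = 0; i < a.Length; i++)
        {
            double dot = 0, na = 0, nb = 0;
            for (int k = 0; k < a[i].Length; k++)
            {
                dot += a[i][k] * b[i][k];
                na += a[i][k] * a[i][k];
                nb += b[i][k] * b[i][k];
            }
            // eps keeps all-zero rows from causing a division by zero.
            result[i] = dot / Math.Max(Math.Sqrt(na) * Math.Sqrt(nb), eps);
        }
        return result;
    }
}
```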
CreateStepResult(T)
Creates a default step result with common metrics.
protected SSLStepResult<T> CreateStepResult(T loss)
Parameters
loss (T): The loss value.
Returns
- SSLStepResult<T>
A step result with populated common fields.
Encode(Tensor<T>)
Encodes input data into learned representations.
public virtual Tensor<T> Encode(Tensor<T> input)
Parameters
input (Tensor<T>): The input tensor to encode.
Returns
- Tensor<T>
The encoded representation tensor.
Remarks
For Beginners: After pretraining, use this to get representations for downstream tasks. The output embeddings can be used for classification, clustering, or similarity search.
EncodeAndProject(Tensor<T>)
Encodes input and projects it to the SSL embedding space.
protected virtual Tensor<T> EncodeAndProject(Tensor<T> input)
Parameters
input (Tensor<T>): The input tensor.
Returns
- Tensor<T>
The projected embedding.
GetAdditionalParameterCount()
Gets the number of additional parameters specific to this SSL method.
protected virtual int GetAdditionalParameterCount()
Returns
- int
The number of additional parameters.
GetAdditionalParameters()
Gets additional parameters specific to this SSL method.
protected virtual Vector<T>? GetAdditionalParameters()
Returns
- Vector<T>
Additional parameters, or null if none.
GetEffectiveLearningRate()
Gets the effective learning rate based on configuration and scheduling.
public virtual double GetEffectiveLearningRate()
Returns
- double
The current learning rate.
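The schedule behind GetEffectiveLearningRate is not specified here. As an illustration only, the sketch below implements one schedule that is common in SSL pretraining: linear warmup followed by cosine decay to zero. The class name, the per-epoch granularity, and the warmup/total parameters are assumptions, not SSLConfig fields.

```csharp
using System;

// Hypothetical warmup + cosine-decay schedule (an illustrative assumption).
public static class LearningRateScheduleSketch
{
    public static double Compute(double baseLr, int epoch, int warmupEpochs, int totalEpochs)
    {
        if (totalEpochs <= 0) throw new ArgumentOutOfRangeException(nameof(totalEpochs));

        // Linear warmup: ramps from baseLr / warmupEpochs up to baseLr.
        if (epoch < warmupEpochs)
            return baseLr * (epoch + 1) / warmupEpochs;

        // Cosine decay over the remaining epochs, from baseLr down to 0.
        int decayEpochs = Math.Max(1, totalEpochs - warmupEpochs);
        double progress = Math.Min(1.0, (double)(epoch - warmupEpochs) / decayEpochs);
        return 0.5 * baseLr * (1 + Math.Cos(Math.PI * progress));
    }
}
```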
GetEffectiveTemperature()
Gets the effective temperature based on configuration and scheduling.
protected virtual double GetEffectiveTemperature()
Returns
- double
The current temperature value.
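In contrastive losses, the temperature tau typically divides similarity scores before a softmax: lower tau sharpens the distribution toward the most similar candidate, and higher tau flattens it. The sketch below shows only that scaling step; where and how each derived method applies the value returned by GetEffectiveTemperature is an assumption, and the class name is hypothetical.

```csharp
using System;
using System.Linq;

// Hypothetical temperature-scaled softmax over similarity scores.
public static class TemperatureSketch
{
    public static double[] Softmax(double[] similarities, double tau)
    {
        if (tau <= 0) throw new ArgumentOutOfRangeException(nameof(tau));

        double[] logits = similarities.Select(s => s / tau).ToArray();
        double max = logits.Max(); // subtract the max for numerical stability
        double[] exps = logits.Select(l => Math.Exp(l - max)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }
}
```

For scores {1, 0}, tau = 1 gives the first entry roughly 0.73 of the probability mass, while tau = 0.1 pushes it very close to 1.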
GetEncoder()
Gets the underlying encoder neural network.
public INeuralNetwork<T> GetEncoder()
Returns
- INeuralNetwork<T>
The encoder network that produces representations.
Remarks
For Beginners: The encoder is the neural network that transforms raw inputs (like images) into learned representations. This is what you keep after pretraining.
GetParameters()
Gets the current parameters of the SSL method for serialization.
public virtual Vector<T> GetParameters()
Returns
- Vector<T>
A vector containing all trainable parameters.
L2Normalize(Tensor<T>)
L2-normalizes a tensor along the last dimension.
protected virtual Tensor<T> L2Normalize(Tensor<T> tensor)
Parameters
tensor (Tensor<T>): The tensor to normalize [batch, dim].
Returns
- Tensor<T>
The normalized tensor.
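Normalizing along the last dimension divides each row of a [batch, dim] tensor by its own L2 norm, so every row becomes a unit vector. The sketch below works on a flat row-major double[] buffer, a common tensor storage layout; the class name and the 1e-12 epsilon for all-zero rows are assumptions, not AiDotNet behavior.

```csharp
using System;

// Hypothetical row-wise L2 normalization over a row-major [batch, dim] buffer.
public static class L2NormalizeSketch
{
    public static double[] Normalize(double[] data, int batch, int dim, double eps = 1e-12)
    {
        if (data.Length != batch * dim) throw new ArgumentException("Shape mismatch.");

        var result = new double[data.Length];
        for (int r = 0; r < batch; r++)
        {
            int offset = r * dim;

            double sumSq = 0;
            for (int c = 0; c < dim; c++) sumSq += data[offset + c] * data[offset + c];

            // eps leaves all-zero rows as zeros instead of producing NaN.
            double norm = Math.Max(Math.Sqrt(sumSq), eps);
            for (int c = 0; c < dim; c++) result[offset + c] = data[offset + c] / norm;
        }
        return result;
    }
}
```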
MatMul(Tensor<T>, Tensor<T>)
Computes matrix multiplication with engine-accelerated dot products.
protected virtual Tensor<T> MatMul(Tensor<T> a, Tensor<T> b)
Parameters
a (Tensor<T>): First matrix [M, K].
b (Tensor<T>): Second matrix [K, N].
Returns
- Tensor<T>
Result matrix [M, N].
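For reference, the operation is C[i, j] = sum over k of A[i, k] * B[k, j]. The plain CPU sketch below states that definition on 2-D arrays; it does not use the engine, and the class name is hypothetical. The i-k-j loop order is a standard choice that walks B and C row by row.

```csharp
using System;

// Hypothetical reference matrix multiply: A [M, K] x B [K, N] -> C [M, N].
public static class MatMulSketch
{
    public static double[,] Multiply(double[,] a, double[,] b)
    {
        int m = a.GetLength(0), k = a.GetLength(1), n = b.GetLength(1);
        if (b.GetLength(0) != k) throw new ArgumentException("Inner dimensions must match.");

        var c = new double[m, n];
        for (int i = 0; i < m; i++)
            for (int p = 0; p < k; p++)
            {
                double aip = a[i, p];
                // Accumulate A[i, p] * row p of B into row i of C.
                for (int j = 0; j < n; j++) c[i, j] += aip * b[p, j];
            }
        return c;
    }
}
```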
OnEpochEnd(int)
Signals the end of an epoch.
public virtual void OnEpochEnd(int epochNumber)
Parameters
epochNumber (int): The completed epoch number.
OnEpochStart(int)
Signals the start of a new epoch.
public virtual void OnEpochStart(int epochNumber)
Parameters
epochNumber (int): The current epoch number.
Reset()
Resets the SSL method to its initial state.
public virtual void Reset()
Remarks
This clears any accumulated state, such as memory banks and running statistics, and resets the momentum encoder if present.
SetAdditionalParameters(Vector<T>, ref int)
Sets additional parameters specific to this SSL method.
protected virtual void SetAdditionalParameters(Vector<T> parameters, ref int offset)
Parameters
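parameters (Vector<T>): The full parameter vector.
offset (int, passed by ref): The current offset into the parameter vector; passed by reference so that it can be advanced past the values this method consumes.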
SetParameters(Vector<T>)
Sets the parameters of the SSL method from a serialized vector.
public virtual void SetParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): The parameter vector to load.
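GetParameters and SetParameters serialize everything into one flat vector, and the ref offset on SetAdditionalParameters suggests each component reads its slice and advances a shared cursor. The sketch below illustrates that pattern with double[] components standing in for the encoder, projector, and method-specific parameters. The class ParameterPackingSketch and its members are hypothetical; the exact ordering AiDotNet uses is not documented here.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical flat-vector packing with a shared ref offset (NOT the AiDotNet API).
public sealed class ParameterPackingSketch
{
    private readonly List<double[]> _components;

    public ParameterPackingSketch(params double[][] components)
    {
        _components = new List<double[]>(components);
    }

    public int ParameterCount
    {
        get { int n = 0; foreach (var c in _components) n += c.Length; return n; }
    }

    // Concatenate all components, in a fixed order, into one flat vector.
    public double[] GetParameters()
    {
        var flat = new double[ParameterCount];
        int offset = 0;
        foreach (var c in _components)
        {
            Array.Copy(c, 0, flat, offset, c.Length);
            offset += c.Length;
        }
        return flat;
    }

    // Read components back in the same order; each reader advances the offset.
    public void SetParameters(double[] flat)
    {
        if (flat.Length != ParameterCount) throw new ArgumentException("Length mismatch.");
        int offset = 0;
        foreach (var c in _components) ReadInto(c, flat, ref offset);
    }

    private static void ReadInto(double[] target, double[] flat, ref int offset)
    {
        Array.Copy(flat, offset, target, 0, target.Length);
        offset += target.Length;
    }

    public double[] Component(int i) => _components[i];
}
```

The key invariant is that reading and writing visit components in the same order, so the offset lines up on both sides.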
SetTrainingMode(bool)
Sets the training mode for the SSL method.
public virtual void SetTrainingMode(bool isTraining)
Parameters
isTraining (bool): True for training mode; false for evaluation mode.
TrainStep(Tensor<T>, SSLAugmentationContext<T>?)
Performs a single training step on a batch of data.
public SSLStepResult<T> TrainStep(Tensor<T> batch, SSLAugmentationContext<T>? augmentationContext = null)
Parameters
batch (Tensor<T>): The input batch tensor (e.g., images).
augmentationContext (SSLAugmentationContext<T>?): Optional augmentation context (default: null); the method may handle augmentation internally.
Returns
- SSLStepResult<T>
The result of the training step including loss and metrics.
Remarks
For Beginners: This runs a single step of the training loop. It:
- Creates augmented views of the input
- Passes views through the encoder
- Computes the SSL loss
- Updates model parameters
TrainStepCore(Tensor<T>, SSLAugmentationContext<T>?)
Implementation-specific training step logic.
protected abstract SSLStepResult<T> TrainStepCore(Tensor<T> batch, SSLAugmentationContext<T>? augmentationContext)
Parameters
batch (Tensor<T>): The input batch tensor.
augmentationContext (SSLAugmentationContext<T>?): Optional augmentation context.
Returns
- SSLStepResult<T>
The result of the training step.