Interface ISSLMethod<T>
- Namespace: AiDotNet.SelfSupervisedLearning
- Assembly: AiDotNet.dll
Defines the contract for self-supervised learning methods.
public interface ISSLMethod<T>
Type Parameters
- T
The numeric type used for computations (typically float or double).
Remarks
For Beginners: Self-supervised learning methods learn useful representations from unlabeled data. They create "pretext tasks" that provide supervision signals without human labels.
Each SSL method implements this interface and provides:
- A training step that processes batches and returns loss
- Access to the learned encoder for downstream tasks
- Encoding functionality to transform inputs into representations
Example usage:
// Create an SSL method
var simclr = new SimCLR<float>(encoder, config);
// Train for one step
var result = simclr.TrainStep(batch, augmentationContext);
Console.WriteLine($"Loss: {result.Loss}");
// Get learned representations
var embeddings = simclr.Encode(newData);
Properties
Category
Gets the category of this SSL method.
SSLMethodCategory Category { get; }
Property Value
- SSLMethodCategory
Remarks
Categories include Contrastive, NonContrastive, Generative, and SelfDistillation.
Name
Gets the name of this SSL method.
string Name { get; }
Property Value
- string
Remarks
Examples: "SimCLR", "MoCo v2", "BYOL", "DINO", "MAE"
ParameterCount
Gets the total number of trainable parameters.
int ParameterCount { get; }
Property Value
- int
RequiresMemoryBank
Indicates whether this method requires a memory bank for negative samples.
bool RequiresMemoryBank { get; }
Property Value
- bool
Remarks
For Beginners: Memory banks store embeddings from previous batches so they can be reused as negative samples in contrastive learning. MoCo uses one; SimCLR does not.
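To make the idea concrete, here is a minimal sketch of a memory bank as a fixed-capacity first-in, first-out queue. `EmbeddingQueue` and its members are hypothetical stand-ins written for this illustration, not AiDotNet types, and embeddings are plain `float[]` arrays rather than `Tensor<T>`. How the library actually stores negatives may differ.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in, not an AiDotNet type: a fixed-capacity FIFO of embeddings.
public sealed class EmbeddingQueue
{
    private readonly Queue<float[]> _items = new Queue<float[]>();
    private readonly int _capacity;

    public EmbeddingQueue(int capacity)
    {
        if (capacity <= 0) throw new ArgumentOutOfRangeException(nameof(capacity));
        _capacity = capacity;
    }

    public int Count => _items.Count;

    // Adds the newest embeddings, evicting the oldest ones once the queue is full.
    public void Enqueue(IEnumerable<float[]> batchEmbeddings)
    {
        foreach (var embedding in batchEmbeddings)
        {
            if (_items.Count == _capacity) _items.Dequeue();
            _items.Enqueue(embedding);
        }
    }

    // Snapshot of the stored embeddings, oldest first, for use as negative samples.
    public float[][] Negatives() => _items.ToArray();
}
```

A method for which RequiresMemoryBank is true would presumably refresh a structure like this after each training step, so that every batch is contrasted against embeddings from earlier batches.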
UsesMomentumEncoder
Indicates whether this method uses a momentum-updated encoder.
bool UsesMomentumEncoder { get; }
Property Value
Remarks
For Beginners: A momentum encoder is a slowly-updated copy of the main encoder. Methods like MoCo, BYOL, and DINO use this to provide stable targets.
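The "slowly-updated copy" is usually maintained as an exponential moving average of the main encoder's weights. The sketch below shows that update on plain `double[]` parameter vectors. `MomentumUpdate.Apply` is a hypothetical helper written for this illustration, not part of AiDotNet, and the library's own update rule is not documented on this page.

```csharp
using System;

// Hypothetical helper, not an AiDotNet type.
public static class MomentumUpdate
{
    // In-place exponential moving average:
    // target <- momentum * target + (1 - momentum) * online
    public static void Apply(double[] target, double[] online, double momentum)
    {
        if (target.Length != online.Length)
            throw new ArgumentException("Parameter vectors must have the same length.");

        for (int i = 0; i < target.Length; i++)
        {
            target[i] = momentum * target[i] + (1.0 - momentum) * online[i];
        }
    }
}
```

With a momentum close to 1 (for example 0.99), the target moves only a small fraction of the way toward the online weights on each call, which is what keeps its outputs stable.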
Methods
Encode(Tensor<T>)
Encodes input data into learned representations.
Tensor<T> Encode(Tensor<T> input)
Parameters
- input Tensor<T>
The input tensor to encode.
Returns
- Tensor<T>
The encoded representation tensor.
Remarks
For Beginners: After pretraining, use this to get representations for downstream tasks. The output embeddings can be used for classification, clustering, or similarity search.
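As one illustration of the similarity-search use case, embeddings are often compared by cosine similarity. The sketch below assumes each embedding has already been copied out of the encoded tensor into a `double[]`; `EmbeddingSimilarity` is a hypothetical helper, not an AiDotNet type.

```csharp
using System;

// Hypothetical helper, not an AiDotNet type.
public static class EmbeddingSimilarity
{
    // Cosine similarity in [-1, 1]; returns 0 if either vector has zero length.
    public static double Cosine(double[] a, double[] b)
    {
        if (a.Length != b.Length)
            throw new ArgumentException("Embeddings must have the same dimension.");

        double dot = 0.0, normA = 0.0, normB = 0.0;
        for (int i = 0; i < a.Length; i++)
        {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }

        if (normA == 0.0 || normB == 0.0) return 0.0;
        return dot / (Math.Sqrt(normA) * Math.Sqrt(normB));
    }
}
```

Ranking candidate embeddings by this score against a query embedding gives a simple nearest-neighbor search.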
GetEncoder()
Gets the underlying encoder neural network.
INeuralNetwork<T> GetEncoder()
Returns
- INeuralNetwork<T>
The encoder network that produces representations.
Remarks
For Beginners: The encoder is the neural network that transforms raw inputs (like images) into learned representations. This is what you keep after pretraining.
GetParameters()
Gets the current parameters of the SSL method for serialization.
Vector<T> GetParameters()
Returns
- Vector<T>
A vector containing all trainable parameters.
OnEpochEnd(int)
Called at the end of each training epoch.
void OnEpochEnd(int epochNumber)
Parameters
- epochNumber int
The current epoch number (0-indexed).
Remarks
For Beginners: This method is called after each epoch completes. Methods use it for cleanup, logging, or updating statistics.
OnEpochStart(int)
Called at the start of each training epoch.
void OnEpochStart(int epochNumber)
Parameters
- epochNumber int
The current epoch number (0-indexed).
Remarks
For Beginners: This method is called before each epoch begins. Methods use it to update learning rate schedules, momentum schedules, etc.
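As an example of a momentum schedule that could be refreshed at the start of each epoch, the sketch below raises the momentum from a base value toward 1 along a cosine curve, as described for BYOL. `MomentumSchedule.Cosine` is a hypothetical helper; this page does not say which schedules AiDotNet's implementations use, so treat the per-epoch granularity and the formula as assumptions.

```csharp
using System;

// Hypothetical helper, not an AiDotNet type.
public static class MomentumSchedule
{
    // Returns baseMomentum at epoch 0 and rises to 1.0 at epoch == totalEpochs.
    public static double Cosine(double baseMomentum, int epoch, int totalEpochs)
    {
        if (totalEpochs <= 0) throw new ArgumentOutOfRangeException(nameof(totalEpochs));

        // Fraction of training completed, clamped to [0, 1].
        double progress = Math.Min(1.0, Math.Max(0.0, (double)epoch / totalEpochs));
        return 1.0 - (1.0 - baseMomentum) * (Math.Cos(Math.PI * progress) + 1.0) / 2.0;
    }
}
```

An implementation could call a function like this from OnEpochStart with the 0-indexed epochNumber and store the result for use in its momentum encoder update.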
Reset()
Resets the SSL method to its initial state.
void Reset()
Remarks
This clears any accumulated state, such as memory banks and running statistics, and resets the momentum encoder if one is present.
SetParameters(Vector<T>)
Sets the parameters of the SSL method from a serialized vector.
void SetParameters(Vector<T> parameters)
Parameters
- parameters Vector<T>
The parameter vector to load.
TrainStep(Tensor<T>, SSLAugmentationContext<T>?)
Performs a single training step on a batch of data.
SSLStepResult<T> TrainStep(Tensor<T> batch, SSLAugmentationContext<T>? augmentationContext = null)
Parameters
- batch Tensor<T>
The input batch tensor (e.g., images).
- augmentationContext SSLAugmentationContext<T>
Optional augmentation context. It can be omitted (null) when the method handles augmentation internally.
Returns
- SSLStepResult<T>
The result of the training step including loss and metrics.
Remarks
For Beginners: This is the main training loop step. It:
- Creates augmented views of the input
- Passes views through the encoder
- Computes the SSL loss
- Updates model parameters
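The way TrainStep fits together with the epoch hooks can be sketched as a small driver loop. To keep the example self-contained, it uses a hypothetical `IStepMethod` interface that mirrors only three members of ISSLMethod<T>, with `float[]` batches and a bare `double` loss instead of Tensor<T> and SSLStepResult<T>. `PretrainLoop` and `ToyMethod` are likewise illustrative and not part of AiDotNet.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in mirroring three ISSLMethod<T> members.
public interface IStepMethod
{
    void OnEpochStart(int epochNumber);
    double TrainStep(float[] batch);   // returns only the loss
    void OnEpochEnd(int epochNumber);
}

public static class PretrainLoop
{
    // Calls the epoch hooks around every batch and returns the mean loss per epoch.
    public static List<double> Run(IStepMethod method, IReadOnlyList<float[]> batches, int epochs)
    {
        var meanLosses = new List<double>();
        for (int epoch = 0; epoch < epochs; epoch++)   // epoch numbers are 0-indexed
        {
            method.OnEpochStart(epoch);

            double total = 0.0;
            foreach (var batch in batches)
            {
                total += method.TrainStep(batch);
            }

            method.OnEpochEnd(epoch);
            meanLosses.Add(batches.Count == 0 ? 0.0 : total / batches.Count);
        }
        return meanLosses;
    }
}

// Toy implementation used only to exercise the loop; its "loss" is the batch length.
public sealed class ToyMethod : IStepMethod
{
    public List<string> Calls { get; } = new List<string>();
    public void OnEpochStart(int epochNumber) => Calls.Add($"start:{epochNumber}");
    public double TrainStep(float[] batch) { Calls.Add("step"); return batch.Length; }
    public void OnEpochEnd(int epochNumber) => Calls.Add($"end:{epochNumber}");
}
```

With a real ISSLMethod<T>, the same shape would apply, reading the loss from the returned SSLStepResult<T> as in the usage example near the top of this page.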