
Interface ISSLMethod<T>

Namespace
AiDotNet.SelfSupervisedLearning
Assembly
AiDotNet.dll

Defines the contract for self-supervised learning methods.

public interface ISSLMethod<T>

Type Parameters

T

The numeric type used for computations (typically float or double).

Remarks

For Beginners: Self-supervised learning methods learn useful representations from unlabeled data. They create "pretext tasks" that provide supervision signals without human labels.

Each SSL method implements this interface and provides:

  • A training step that processes batches and returns loss
  • Access to the learned encoder for downstream tasks
  • Encoding functionality to transform inputs into representations

Example usage:

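// Create an SSL method (assumes `encoder` and `config` are already defined)
var simclr = new SimCLR<float>(encoder, config);

// Train for one step on a batch; `batch` and `augmentationContext` are assumed to exist
var result = simclr.TrainStep(batch, augmentationContext);
Console.WriteLine($"Loss: {result.Loss}");

// Get learned representations for new inputs
var embeddings = simclr.Encode(newData);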

Properties

Category

Gets the category of this SSL method.

SSLMethodCategory Category { get; }

Property Value

SSLMethodCategory

Remarks

Categories include Contrastive, NonContrastive, Generative, and SelfDistillation.

Name

Gets the name of this SSL method.

string Name { get; }

Property Value

string

Remarks

Examples: "SimCLR", "MoCo v2", "BYOL", "DINO", "MAE"

ParameterCount

Gets the total number of trainable parameters.

int ParameterCount { get; }

Property Value

int

RequiresMemoryBank

Indicates whether this method requires a memory bank for negative samples.

bool RequiresMemoryBank { get; }

Property Value

bool

Remarks

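For Beginners: Memory banks store embeddings from previous batches so they can be reused as negative samples in contrastive learning. MoCo uses one (a queue of encoded keys); SimCLR does not, and instead draws its negatives from the other samples in the current batch.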
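A memory bank can be pictured as a fixed-capacity first-in, first-out queue: each batch pushes its embeddings in, and the oldest ones fall out once the queue is full. The sketch below illustrates that idea only. `EmbeddingQueue` is a hypothetical helper, not an AiDotNet type, and it stores `float[]` rows instead of `Tensor<T>`.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper (not part of AiDotNet): a fixed-capacity FIFO queue of
// embeddings, similar in spirit to the negative-sample queue used by MoCo.
public sealed class EmbeddingQueue
{
    private readonly Queue<float[]> _items = new();
    private readonly int _capacity;

    public EmbeddingQueue(int capacity)
    {
        if (capacity <= 0) throw new ArgumentOutOfRangeException(nameof(capacity));
        _capacity = capacity;
    }

    public int Count => _items.Count;

    // Adds the embeddings from the latest batch; once full, the oldest entries are evicted.
    public void Enqueue(IEnumerable<float[]> batchEmbeddings)
    {
        foreach (var embedding in batchEmbeddings)
        {
            _items.Enqueue(embedding);
            if (_items.Count > _capacity)
                _items.Dequeue();
        }
    }

    // Returns a snapshot of the stored embeddings, oldest first, for use as negatives.
    public IReadOnlyList<float[]> GetNegatives() => _items.ToArray();

    // Clears all stored embeddings, e.g. when the method is reset.
    public void Clear() => _items.Clear();
}
```

Because the queue is refilled from recent batches, the number of negatives is decoupled from the batch size, which is the main reason a method would opt into a memory bank.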

UsesMomentumEncoder

Indicates whether this method uses a momentum-updated encoder.

bool UsesMomentumEncoder { get; }

Property Value

bool

Remarks

For Beginners: A momentum encoder is a slowly updated copy of the main encoder, typically maintained as an exponential moving average (EMA) of the main encoder's weights. Methods like MoCo, BYOL, and DINO use one to provide stable targets.
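The EMA update behind a momentum encoder is a one-line formula applied to every weight: target = m * target + (1 - m) * online. The sketch below applies it to flat parameter arrays. `MomentumUpdate.Apply` is a hypothetical helper for illustration; it uses `double[]` rather than AiDotNet's `Vector<T>` and is not how the library necessarily implements the update.

```csharp
using System;

public static class MomentumUpdate
{
    // Exponential moving average (EMA) update used by momentum encoders:
    // target = m * target + (1 - m) * online, applied element-wise in place.
    // Values of m close to 1 (for example 0.99 or 0.999) make the target change slowly.
    public static void Apply(double[] target, double[] online, double momentum)
    {
        if (target.Length != online.Length)
            throw new ArgumentException("Parameter vectors must have the same length.");
        if (momentum < 0.0 || momentum > 1.0)
            throw new ArgumentOutOfRangeException(nameof(momentum));

        for (int i = 0; i < target.Length; i++)
            target[i] = momentum * target[i] + (1.0 - momentum) * online[i];
    }
}
```

No gradients flow into the target copy; it only changes through this update, which is what keeps its outputs stable from step to step.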

Methods

Encode(Tensor<T>)

Encodes input data into learned representations.

Tensor<T> Encode(Tensor<T> input)

Parameters

input Tensor<T>

The input tensor to encode.

Returns

Tensor<T>

The encoded representation tensor.

Remarks

For Beginners: After pretraining, use this to get representations for downstream tasks. The output embeddings can be used for classification, clustering, or similarity search.
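As an example of the similarity-search use case, the sketch below finds the stored embedding closest to a query by cosine similarity. It assumes the rows of the `Tensor<T>` returned by `Encode` have already been copied into `float[]` arrays, one per input; that conversion is not shown because it depends on the tensor API. `EmbeddingSearch` is a hypothetical helper, not part of AiDotNet.

```csharp
using System;

public static class EmbeddingSearch
{
    // Cosine similarity between two embedding vectors of equal length.
    public static double Cosine(float[] a, float[] b)
    {
        if (a.Length != b.Length) throw new ArgumentException("Length mismatch.");
        double dot = 0.0, normA = 0.0, normB = 0.0;
        for (int i = 0; i < a.Length; i++)
        {
            dot += (double)a[i] * b[i];
            normA += (double)a[i] * a[i];
            normB += (double)b[i] * b[i];
        }
        // Small epsilon avoids division by zero for all-zero vectors.
        return dot / (Math.Sqrt(normA) * Math.Sqrt(normB) + 1e-12);
    }

    // Returns the index of the stored embedding most similar to the query.
    public static int NearestNeighbor(float[] query, float[][] database)
    {
        if (database.Length == 0) throw new ArgumentException("Database is empty.");
        int best = 0;
        double bestScore = double.NegativeInfinity;
        for (int i = 0; i < database.Length; i++)
        {
            double score = Cosine(query, database[i]);
            if (score > bestScore)
            {
                bestScore = score;
                best = i;
            }
        }
        return best;
    }
}
```

Cosine similarity is a common choice here because many SSL objectives compare normalized embeddings, so direction matters more than magnitude.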

GetEncoder()

Gets the underlying encoder neural network.

INeuralNetwork<T> GetEncoder()

Returns

INeuralNetwork<T>

The encoder network that produces representations.

Remarks

For Beginners: The encoder is the neural network that transforms raw inputs (like images) into learned representations. This is what you keep after pretraining.

GetParameters()

Gets the current parameters of the SSL method for serialization.

Vector<T> GetParameters()

Returns

Vector<T>

A vector containing all trainable parameters.
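Because the parameters come back as one flat vector, persisting them can be as simple as writing a length prefix followed by the values. The sketch below assumes the `Vector<T>` has been copied into a `double[]`; `ParameterIO` and its file layout are hypothetical and are not an AiDotNet serialization format.

```csharp
using System;
using System.IO;
using System.Text;

public static class ParameterIO
{
    // Writes a flat parameter vector as an Int32 length prefix followed by the raw values.
    public static void Save(Stream stream, double[] parameters)
    {
        using var writer = new BinaryWriter(stream, Encoding.UTF8, leaveOpen: true);
        writer.Write(parameters.Length);
        foreach (double p in parameters)
            writer.Write(p);
    }

    // Reads a vector in the layout produced by Save.
    public static double[] Load(Stream stream)
    {
        using var reader = new BinaryReader(stream, Encoding.UTF8, leaveOpen: true);
        int length = reader.ReadInt32();
        var parameters = new double[length];
        for (int i = 0; i < length; i++)
            parameters[i] = reader.ReadDouble();
        return parameters;
    }
}
```

The loaded array would then be converted back and passed to `SetParameters`; storing the length makes it possible to detect a mismatch with `ParameterCount` before loading.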

OnEpochEnd(int)

Called at the end of each training epoch.

void OnEpochEnd(int epochNumber)

Parameters

epochNumber int

The current epoch number (0-indexed).

Remarks

For Beginners: This method is called after each epoch completes. SSL methods can use it for cleanup, logging, or updating statistics.
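One typical end-of-epoch task is reporting the mean of the per-step losses. The sketch below shows a hypothetical `EpochLossTracker` (not an AiDotNet type) that records each step's loss and, when the epoch ends, logs the mean and resets itself for the next epoch.

```csharp
using System;

// Hypothetical helper (not an AiDotNet type): accumulates per-step losses so an
// end-of-epoch hook can report the epoch's mean loss.
public sealed class EpochLossTracker
{
    private double _sum;
    private int _steps;

    // Call once per training step with that step's loss.
    public void Record(double loss)
    {
        _sum += loss;
        _steps++;
    }

    // Call at epoch end: logs and returns the mean loss (NaN if no steps were recorded),
    // then clears the accumulator for the next epoch.
    public double EndEpoch(int epochNumber)
    {
        double mean = _steps == 0 ? double.NaN : _sum / _steps;
        Console.WriteLine($"Epoch {epochNumber}: mean loss {mean:F4} over {_steps} steps");
        _sum = 0.0;
        _steps = 0;
        return mean;
    }
}
```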

OnEpochStart(int)

Called at the start of each training epoch.

void OnEpochStart(int epochNumber)

Parameters

epochNumber int

The current epoch number (0-indexed).

Remarks

For Beginners: This method is called before each epoch begins. SSL methods can use it to update learning rate schedules, momentum schedules, and similar per-epoch settings.
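As an example of a momentum schedule, BYOL raises its target-network momentum from a base value toward 1 with a cosine curve: tau = 1 - (1 - baseTau) * (cos(pi * k / K) + 1) / 2. BYOL applies this per training step; computing it once per epoch from `epochNumber`, as below, is an assumption made for illustration. `MomentumSchedule` is a hypothetical helper.

```csharp
using System;

public static class MomentumSchedule
{
    // Cosine momentum schedule in the style of BYOL:
    // tau = 1 - (1 - baseTau) * (cos(pi * k / K) + 1) / 2,
    // which equals baseTau at k = 0 and rises smoothly to 1 at k = K.
    public static double Cosine(double baseTau, int epoch, int totalEpochs)
    {
        if (totalEpochs <= 0) throw new ArgumentOutOfRangeException(nameof(totalEpochs));
        double progress = Math.Clamp((double)epoch / totalEpochs, 0.0, 1.0);
        return 1.0 - (1.0 - baseTau) * (Math.Cos(Math.PI * progress) + 1.0) / 2.0;
    }
}
```

Increasing the momentum over time makes the target encoder change more and more slowly as training converges.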

Reset()

Resets the SSL method to its initial state.

void Reset()

Remarks

This clears any accumulated state, such as memory banks and running statistics, and resets the momentum encoder if one is present.

SetParameters(Vector<T>)

Sets the parameters of the SSL method from a serialized vector.

void SetParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

The parameter vector to load, typically one previously obtained from GetParameters().
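A parameter round trip only works if both sides agree on the vector's length and layout. The toy model below illustrates that contract: `GetParameters` flattens weights then bias, and `SetParameters` rejects a vector whose length differs from `ParameterCount`. `ToyLinearModel` is hypothetical, and whether AiDotNet's implementations validate the length this way is an assumption.

```csharp
using System;

// Hypothetical toy model (not an AiDotNet type) illustrating the parameter
// round-trip contract: the vector length must equal ParameterCount.
public sealed class ToyLinearModel
{
    private readonly double[] _weights;
    private double _bias;

    public ToyLinearModel(int inputSize) => _weights = new double[inputSize];

    public int ParameterCount => _weights.Length + 1;

    // Flattens the weights followed by the bias into one vector.
    public double[] GetParameters()
    {
        var result = new double[ParameterCount];
        Array.Copy(_weights, result, _weights.Length);
        result[^1] = _bias;
        return result;
    }

    // Loads a vector in the same layout that GetParameters produces.
    public void SetParameters(double[] parameters)
    {
        if (parameters.Length != ParameterCount)
            throw new ArgumentException(
                $"Expected {ParameterCount} parameters but got {parameters.Length}.", nameof(parameters));
        Array.Copy(parameters, _weights, _weights.Length);
        _bias = parameters[^1];
    }
}
```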

TrainStep(Tensor<T>, SSLAugmentationContext<T>?)

Performs a single training step on a batch of data.

SSLStepResult<T> TrainStep(Tensor<T> batch, SSLAugmentationContext<T>? augmentationContext = null)

Parameters

batch Tensor<T>

The input batch tensor (e.g., images).

augmentationContext SSLAugmentationContext<T>

Optional augmentation context; defaults to null. Some methods may handle augmentation internally instead.

Returns

SSLStepResult<T>

The result of the training step including loss and metrics.

Remarks

For Beginners: This is the core step of the training loop. It:

  1. Creates augmented views of the input
  2. Passes views through the encoder
  3. Computes the SSL loss
  4. Updates model parameters
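Step 3 differs from method to method. As one concrete instance, the sketch below computes SimCLR's NT-Xent (normalized temperature-scaled cross-entropy) loss on the projections of two augmented views of a batch: each vector's partner view is its positive and the other 2N - 2 vectors are negatives. It is an illustrative sketch, not AiDotNet's implementation: `NtXent` is a hypothetical name, it uses `double[][]` instead of `Tensor<T>`, and it computes only the loss value, leaving out gradients and the parameter update in step 4.

```csharp
using System;

// Illustrative sketch of SimCLR's NT-Xent loss (not AiDotNet's implementation).
public static class NtXent
{
    // viewA[i] and viewB[i] are projections of two augmented views of sample i.
    // The loss is averaged over all 2N anchors.
    public static double Loss(double[][] viewA, double[][] viewB, double temperature)
    {
        int n = viewA.Length;
        if (n == 0 || viewB.Length != n)
            throw new ArgumentException("Both views must be non-empty and have the same batch size.");
        if (temperature <= 0)
            throw new ArgumentOutOfRangeException(nameof(temperature));

        // Stack the views and L2-normalize so dot products are cosine similarities.
        var z = new double[2 * n][];
        for (int i = 0; i < n; i++)
        {
            z[i] = Normalize(viewA[i]);
            z[i + n] = Normalize(viewB[i]);
        }

        double total = 0.0;
        var logits = new double[2 * n];
        for (int i = 0; i < 2 * n; i++)
        {
            int positive = (i + n) % (2 * n);

            // Similarity logits to every other vector (the anchor itself is excluded).
            double maxLogit = double.NegativeInfinity;
            for (int k = 0; k < 2 * n; k++)
            {
                if (k == i) continue;
                logits[k] = Dot(z[i], z[k]) / temperature;
                maxLogit = Math.Max(maxLogit, logits[k]);
            }

            // Numerically stable log-sum-exp over k != i.
            double sumExp = 0.0;
            for (int k = 0; k < 2 * n; k++)
            {
                if (k != i) sumExp += Math.Exp(logits[k] - maxLogit);
            }
            double logSumExp = maxLogit + Math.Log(sumExp);

            // Cross-entropy term: -log(exp(logit_pos) / sum over k != i of exp(logit_k)).
            total += logSumExp - logits[positive];
        }
        return total / (2 * n);
    }

    private static double Dot(double[] a, double[] b)
    {
        double sum = 0.0;
        for (int i = 0; i < a.Length; i++) sum += a[i] * b[i];
        return sum;
    }

    private static double[] Normalize(double[] v)
    {
        double norm = Math.Sqrt(Dot(v, v));
        if (norm == 0.0) throw new ArgumentException("Cannot normalize a zero vector.");
        var result = new double[v.Length];
        for (int i = 0; i < v.Length; i++) result[i] = v[i] / norm;
        return result;
    }
}
```

The loss is low when each view is most similar to its own partner and high when it is closer to another sample's view, which is the signal that drives the encoder toward augmentation-invariant representations.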