Class TeacherStudentSSL<T>

Namespace
AiDotNet.SelfSupervisedLearning
Assembly
AiDotNet.dll

Base class for teacher-student self-supervised learning methods.

public abstract class TeacherStudentSSL<T> : SSLMethodBase<T>, ISSLMethod<T>

Type Parameters

T

The numeric type used for computations.

Inheritance
SSLMethodBase<T>
TeacherStudentSSL<T>
Implements
ISSLMethod<T>

Remarks

For Beginners: Teacher-student SSL methods use two networks: a student that is trained by backpropagation and a teacher that supplies the targets the student learns to match. The teacher is typically not trained directly. Instead, its weights are an exponential moving average (EMA) of the student's weights, so its targets change slowly and remain stable.

Common components:

  • Student network: Trained with backpropagation
  • Teacher network: Updated with EMA (momentum encoder)
  • Centering: Prevents collapse by subtracting a running mean from teacher outputs
  • Multi-crop: Uses multiple augmented views of different sizes

Methods using this pattern: DINO, iBOT, EsViT, DINOv2
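The EMA update mentioned above can be sketched as follows. This is a minimal illustration on plain double arrays, not the library's implementation: EmaUpdate and the sample weights are hypothetical names introduced here, and the real class works with IMomentumEncoder<T> and a generic numeric type T.

```csharp
using System;

// Hypothetical helper, not part of AiDotNet: blends student weights into the
// teacher in place using teacher = m * teacher + (1 - m) * student.
void EmaUpdate(double[] teacher, double[] student, double momentum)
{
    if (teacher.Length != student.Length)
        throw new ArgumentException("Parameter vectors must have the same length.");

    for (int i = 0; i < teacher.Length; i++)
        teacher[i] = momentum * teacher[i] + (1.0 - momentum) * student[i];
}

double[] teacherWeights = { 1.0, 2.0, 3.0 };
double[] studentWeights = { 2.0, 4.0, 6.0 };

// With momentum 0.9 the teacher moves 10% of the way toward the student.
EmaUpdate(teacherWeights, studentWeights, momentum: 0.9);
Console.WriteLine(string.Join(", ", teacherWeights));
```

A momentum close to 1 keeps the teacher nearly fixed between steps, which is what makes its outputs usable as stable targets.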

Constructors

TeacherStudentSSL(INeuralNetwork<T>, IMomentumEncoder<T>, IProjectorHead<T>, IProjectorHead<T>, int, SSLConfig?)

Initializes a new instance of the TeacherStudentSSL<T> class.

protected TeacherStudentSSL(INeuralNetwork<T> studentEncoder, IMomentumEncoder<T> teacherEncoder, IProjectorHead<T> studentProjector, IProjectorHead<T> teacherProjector, int outputDim, SSLConfig? config = null)

Parameters

studentEncoder INeuralNetwork<T>

The student encoder network.

teacherEncoder IMomentumEncoder<T>

The teacher encoder (momentum-updated).

studentProjector IProjectorHead<T>

Projection head for student.

teacherProjector IProjectorHead<T>

Projection head for teacher.

outputDim int

Output dimension for centering.

config SSLConfig?

Optional SSL configuration. Defaults to null.

Fields

Augmentation

Augmentation policies for creating views.

protected readonly SSLAugmentationPolicies<T> Augmentation

Field Value

SSLAugmentationPolicies<T>

BaseMomentum

Base momentum value for teacher updates.

protected readonly double BaseMomentum

Field Value

double
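A "base" momentum suggests a value that a schedule starts from. This page does not say whether or how the class schedules momentum, so the following is only a sketch of the cosine ramp commonly used by DINO-style methods, under that assumption. MomentumForEpoch and the value 0.996 are hypothetical and chosen for illustration.

```csharp
using System;

// Hypothetical helper, not part of AiDotNet: cosine ramp from baseMomentum
// at epoch 0 up to 1.0 at the last epoch.
double MomentumForEpoch(double baseMomentum, int epoch, int totalEpochs)
{
    if (totalEpochs <= 1) return baseMomentum;

    double progress = (double)epoch / (totalEpochs - 1);
    return 1.0 - (1.0 - baseMomentum) * (Math.Cos(Math.PI * progress) + 1.0) / 2.0;
}

const int epochs = 5;
for (int e = 0; e < epochs; e++)
    Console.WriteLine($"epoch {e}: momentum = {MomentumForEpoch(0.996, e, epochs):F4}");
```

Under such a schedule the teacher follows the student more loosely early in training and is almost frozen by the end.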

Centering

Centering mechanism to prevent collapse.

protected readonly CenteringMechanism<T> Centering

Field Value

CenteringMechanism<T>
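As a rough illustration of what centering does, the sketch below keeps a running center (an EMA of the batch mean of teacher outputs) and subtracts it from each output. It follows the DINO-style recipe and assumes nothing about the actual CenteringMechanism<T> API. BatchMean, UpdateCenter, and ApplyCenter are hypothetical helpers.

```csharp
using System;

// Hypothetical helpers for illustration; not AiDotNet APIs.

// Averages the teacher outputs of one batch, dimension by dimension.
double[] BatchMean(double[][] outputs)
{
    var mean = new double[outputs[0].Length];
    foreach (double[] row in outputs)
        for (int d = 0; d < mean.Length; d++)
            mean[d] += row[d] / outputs.Length;
    return mean;
}

// Moves the running center toward the current batch mean (EMA), in place.
void UpdateCenter(double[] center, double[][] outputs, double centerMomentum)
{
    double[] mean = BatchMean(outputs);
    for (int d = 0; d < center.Length; d++)
        center[d] = centerMomentum * center[d] + (1.0 - centerMomentum) * mean[d];
}

// Subtracts the running center from a single teacher output.
double[] ApplyCenter(double[] output, double[] center)
{
    var centered = new double[output.Length];
    for (int d = 0; d < output.Length; d++)
        centered[d] = output[d] - center[d];
    return centered;
}

double[][] teacherOutputs =
{
    new[] { 1.0, 3.0 },
    new[] { 3.0, 5.0 },
};
double[] runningCenter = { 0.0, 0.0 };

UpdateCenter(runningCenter, teacherOutputs, centerMomentum: 0.9);
double[] centeredFirst = ApplyCenter(teacherOutputs[0], runningCenter);
Console.WriteLine(string.Join(", ", centeredFirst));
```

Subtracting the center stops any single output dimension from dominating across the whole batch, which is one way a teacher-student model can collapse.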

TeacherEncoder

The teacher encoder (momentum-updated copy of student).

protected readonly IMomentumEncoder<T> TeacherEncoder

Field Value

IMomentumEncoder<T>

TeacherProjector

The teacher projection head.

protected readonly IProjectorHead<T> TeacherProjector

Field Value

IProjectorHead<T>

Properties

NumGlobalCrops

Number of global crops (larger views used by both student and teacher).

protected int NumGlobalCrops { get; set; }

Property Value

int

NumLocalCrops

Number of local crops (smaller views used by student only).

protected int NumLocalCrops { get; set; }

Property Value

int

RequiresMemoryBank

Indicates whether this method requires a memory bank for negative samples.

public override bool RequiresMemoryBank { get; }

Property Value

bool

Remarks

For Beginners: Memory banks store embeddings from previous batches to use as negative samples in contrastive learning. MoCo uses this, SimCLR does not.

UsesMomentumEncoder

Indicates whether this method uses a momentum-updated encoder.

public override bool UsesMomentumEncoder { get; }

Property Value

bool

Remarks

For Beginners: A momentum encoder is a slowly-updated copy of the main encoder. Methods like MoCo, BYOL, and DINO use this to provide stable targets.

Methods

CreateMultiCropViews(Tensor<T>)

Creates augmented views for teacher-student training.

protected virtual (List<Tensor<T>> globalViews, List<Tensor<T>> localViews) CreateMultiCropViews(Tensor<T> batch)

Parameters

batch Tensor<T>

Input batch.

Returns

(List<Tensor<T>> globalViews, List<Tensor<T>> localViews)

Global and local crop views.
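To make the multi-crop idea concrete, here is a simplified sketch that crops a one-dimensional signal instead of an image tensor. It is not the library's implementation: RandomCrop, CreateViews, the crop counts (2 global, 6 local), and the scale ranges (40-100% and 5-40%) are assumptions picked for illustration.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper: copies a random contiguous window covering between
// minScale and maxScale of the signal.
double[] RandomCrop(double[] signal, double minScale, double maxScale, Random rng)
{
    double scale = minScale + rng.NextDouble() * (maxScale - minScale);
    int length = Math.Max(1, (int)Math.Round(scale * signal.Length));
    int start = rng.Next(0, signal.Length - length + 1);
    return signal[start..(start + length)];
}

// Hypothetical stand-in for CreateMultiCropViews: large global crops plus
// small local crops of the same input.
(List<double[]> globalViews, List<double[]> localViews) CreateViews(
    double[] signal, int numGlobalCrops, int numLocalCrops, Random rng)
{
    var globalViews = new List<double[]>();
    var localViews = new List<double[]>();

    for (int i = 0; i < numGlobalCrops; i++)
        globalViews.Add(RandomCrop(signal, 0.4, 1.0, rng));   // assumed scale range
    for (int i = 0; i < numLocalCrops; i++)
        localViews.Add(RandomCrop(signal, 0.05, 0.4, rng));   // assumed scale range

    return (globalViews, localViews);
}

var input = new double[100];
var random = new Random(42);
var views = CreateViews(input, numGlobalCrops: 2, numLocalCrops: 6, random);

Console.WriteLine($"{views.globalViews.Count} global, {views.localViews.Count} local");
```

In a DINO-style setup the teacher would see only the global views while the student sees both sets, matching the NumGlobalCrops and NumLocalCrops descriptions above.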

ForwardStudent(Tensor<T>)

Performs a forward pass through the student network.

protected virtual Tensor<T> ForwardStudent(Tensor<T> view)

Parameters

view Tensor<T>

The augmented view to pass through the student network.

Returns

Tensor<T>

ForwardTeacher(Tensor<T>)

Performs a forward pass through the teacher network (no gradients).

protected virtual Tensor<T> ForwardTeacher(Tensor<T> view)

Parameters

view Tensor<T>

The augmented view to pass through the teacher network.

Returns

Tensor<T>

GetAdditionalParameterCount()

Gets the count of additional parameters.

protected override int GetAdditionalParameterCount()

Returns

int

The number of additional parameters.

GetAdditionalParameters()

Gets additional parameters specific to this SSL method.

protected override Vector<T>? GetAdditionalParameters()

Returns

Vector<T>

Additional parameters, or null if none.

OnEpochStart(int)

Signals the start of a new epoch.

public override void OnEpochStart(int epochNumber)

Parameters

epochNumber int

The current epoch number.

UpdateStudent(T)

Updates the student network's parameters using the computed gradients.

protected virtual void UpdateStudent(T learningRate)

Parameters

learningRate T

The learning rate for the student update.

UpdateTeacher()

Updates the teacher network as an EMA of the student.

protected virtual void UpdateTeacher()
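The toy loop below illustrates one plausible ordering of the two updates: a gradient step for the student, then an EMA step for the teacher. It uses a single scalar weight and a made-up quadratic loss. The call order inside the actual training step is an assumption, not something this page states.

```csharp
using System;

double studentWeight = 0.0;   // trained by gradient descent
double teacherWeight = 0.0;   // never receives gradients
const double target = 1.0;
const double learningRate = 0.5;
const double momentum = 0.9;

for (int step = 0; step < 20; step++)
{
    // Gradient of the toy loss 0.5 * (w - target)^2 with respect to w.
    double gradient = studentWeight - target;

    // 1. Student step: plain gradient descent.
    studentWeight -= learningRate * gradient;

    // 2. Teacher step: EMA of the freshly updated student.
    teacherWeight = momentum * teacherWeight + (1.0 - momentum) * studentWeight;
}

Console.WriteLine($"student = {studentWeight:F4}, teacher = {teacherWeight:F4}");
```

The teacher trails the student throughout, which is the smoothing effect the EMA is meant to provide.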