Class TeacherStudentSSL<T>
Namespace: AiDotNet.SelfSupervisedLearning
Assembly: AiDotNet.dll
Base class for teacher-student self-supervised learning methods.
public abstract class TeacherStudentSSL<T> : SSLMethodBase<T>, ISSLMethod<T>
Type Parameters
T: The numeric type used for computations.
Inheritance: SSLMethodBase<T> → TeacherStudentSSL<T>
Implements: ISSLMethod<T>
Remarks
For Beginners: Teacher-student SSL methods use two networks: a student that learns from gradients and a teacher that provides targets. The teacher is typically updated as an exponential moving average (EMA) of the student, providing stable learning targets.
Common components:
- Student network: Trained with backpropagation
- Teacher network: Updated with EMA (momentum encoder)
- Centering: Prevents collapse by centering teacher outputs
- Multi-crop: Uses multiple augmented views of different sizes
Methods using this pattern: DINO, iBOT, EsViT, DINOv2
Constructors
TeacherStudentSSL(INeuralNetwork<T>, IMomentumEncoder<T>, IProjectorHead<T>, IProjectorHead<T>, int, SSLConfig?)
Initializes a new instance of the TeacherStudentSSL class.
protected TeacherStudentSSL(INeuralNetwork<T> studentEncoder, IMomentumEncoder<T> teacherEncoder, IProjectorHead<T> studentProjector, IProjectorHead<T> teacherProjector, int outputDim, SSLConfig? config = null)
Parameters
studentEncoder (INeuralNetwork<T>): The student encoder network.
teacherEncoder (IMomentumEncoder<T>): The teacher encoder (momentum-updated).
studentProjector (IProjectorHead<T>): Projection head for the student.
teacherProjector (IProjectorHead<T>): Projection head for the teacher.
outputDim (int): Output dimension used for centering.
config (SSLConfig?): Optional SSL configuration; defaults to null.
Fields
Augmentation
Augmentation policies for creating views.
protected readonly SSLAugmentationPolicies<T> Augmentation
Field Value: SSLAugmentationPolicies<T>
BaseMomentum
Base momentum value for teacher updates.
protected readonly double BaseMomentum
Field Value: double
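BaseMomentum is described only as the base value for teacher updates. DINO-style methods commonly ramp the momentum from this base value toward 1.0 over training with a cosine schedule; whether TeacherStudentSSL<T> does so is not documented here. The following is a minimal sketch of that schedule under this assumption, using double instead of T. MomentumSchedule and At are hypothetical names, not part of AiDotNet.

```csharp
using System;

// Hypothetical helper (not an AiDotNet API) illustrating a DINO-style cosine momentum schedule.
// Assumption: momentum rises from baseMomentum at step 0 to 1.0 at totalSteps.
public static class MomentumSchedule
{
    public static double At(double baseMomentum, int step, int totalSteps)
    {
        if (totalSteps <= 0) throw new ArgumentOutOfRangeException(nameof(totalSteps));

        // Training progress in [0, 1]; steps past the end keep momentum at 1.0.
        double progress = Math.Min(1.0, Math.Max(0.0, (double)step / totalSteps));

        // Cosine factor falls from 1 (start of training) to 0 (end of training).
        double cosine = (Math.Cos(Math.PI * progress) + 1.0) / 2.0;

        return 1.0 - (1.0 - baseMomentum) * cosine;
    }
}
```

With a base momentum of 0.996 (the value DINO uses), the momentum starts at 0.996, reaches 0.998 halfway through, and approaches 1.0 at the final step, so the teacher changes more and more slowly as training proceeds.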
Centering
Centering mechanism to prevent collapse.
protected readonly CenteringMechanism<T> Centering
Field Value: CenteringMechanism<T>
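In DINO, centering keeps a running mean (the center) of the teacher outputs and subtracts it before the teacher softmax, which discourages any single output dimension from dominating. The following is a minimal sketch of that idea on plain double arrays, assuming an EMA update of the center with momentum 0.9 (the DINO default). SimpleCentering is a hypothetical class and does not reflect the actual CenteringMechanism<T> API; its vector length plays the role of the outputDim constructor parameter.

```csharp
using System;

// Hypothetical stand-in for a centering mechanism (not the AiDotNet CenteringMechanism<T> API).
public sealed class SimpleCentering
{
    private readonly double[] _center;
    private readonly double _momentum;

    public SimpleCentering(int outputDim, double momentum = 0.9)
    {
        _center = new double[outputDim]; // the center starts at zero
        _momentum = momentum;
    }

    // Copy of the current center, so callers cannot mutate internal state.
    public double[] Center => (double[])_center.Clone();

    // Subtracts the current center from every teacher output row.
    public double[][] Apply(double[][] teacherOutputs)
    {
        var result = new double[teacherOutputs.Length][];
        for (int i = 0; i < teacherOutputs.Length; i++)
        {
            result[i] = new double[_center.Length];
            for (int j = 0; j < _center.Length; j++)
                result[i][j] = teacherOutputs[i][j] - _center[j];
        }
        return result;
    }

    // EMA update of the center toward the batch mean of the teacher outputs.
    public void Update(double[][] teacherOutputs)
    {
        int n = teacherOutputs.Length;
        if (n == 0) return;
        for (int j = 0; j < _center.Length; j++)
        {
            double mean = 0.0;
            for (int i = 0; i < n; i++) mean += teacherOutputs[i][j];
            mean /= n;
            _center[j] = _momentum * _center[j] + (1.0 - _momentum) * mean;
        }
    }
}
```

Because the center moves slowly, it reflects the average teacher output over many batches rather than the current batch alone.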
TeacherEncoder
The teacher encoder (a momentum-updated copy of the student).
protected readonly IMomentumEncoder<T> TeacherEncoder
Field Value: IMomentumEncoder<T>
TeacherProjector
The teacher projection head.
protected readonly IProjectorHead<T> TeacherProjector
Field Value: IProjectorHead<T>
Properties
NumGlobalCrops
Number of global crops (larger views used by both student and teacher).
protected int NumGlobalCrops { get; set; }
Property Value: int
NumLocalCrops
Number of local crops (smaller views used by student only).
protected int NumLocalCrops { get; set; }
Property Value: int
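Because the teacher sees only the global crops while the student sees every crop, DINO computes the loss over each (teacher view, student view) pair except those where both sides are the same global crop. The sketch below enumerates those pairs, assuming the student's view list holds the global crops first and the local crops after them; MultiCropPairs is a hypothetical helper, and this class may organize its views differently.

```csharp
using System.Collections.Generic;

// Hypothetical helper (not an AiDotNet API) enumerating multi-crop loss pairs.
public static class MultiCropPairs
{
    // Returns (teacherView, studentView) index pairs. Teacher indices refer to global
    // crops only; student indices run over the global crops first, then the local crops.
    public static List<(int Teacher, int Student)> Enumerate(int numGlobalCrops, int numLocalCrops)
    {
        var pairs = new List<(int Teacher, int Student)>();
        int totalStudentViews = numGlobalCrops + numLocalCrops;
        for (int t = 0; t < numGlobalCrops; t++)
        {
            for (int s = 0; s < totalStudentViews; s++)
            {
                // Skip comparing a global crop with itself.
                if (s != t) pairs.Add((t, s));
            }
        }
        return pairs;
    }
}
```

With 2 global and 6 local crops this yields 2 * 8 - 2 = 14 pairs, so each local crop is matched against both teacher views.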
RequiresMemoryBank
Indicates whether this method requires a memory bank for negative samples.
public override bool RequiresMemoryBank { get; }
Property Value: bool
Remarks
For Beginners: Memory banks store embeddings from previous batches to use as negative samples in contrastive learning. MoCo uses one; SimCLR does not.
UsesMomentumEncoder
Indicates whether this method uses a momentum-updated encoder.
public override bool UsesMomentumEncoder { get; }
Property Value: bool
Remarks
For Beginners: A momentum encoder is a slowly updated copy of the main encoder. Methods like MoCo, BYOL, and DINO use this to provide stable targets.
Methods
CreateMultiCropViews(Tensor<T>)
Creates augmented views for teacher-student training.
protected virtual (List<Tensor<T>> globalViews, List<Tensor<T>> localViews) CreateMultiCropViews(Tensor<T> batch)
Parameters
batch (Tensor<T>): The input batch.
Returns
(List<Tensor<T>> globalViews, List<Tensor<T>> localViews): The global and local crop views.
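How CreateMultiCropViews chooses its crops is not specified here. Multi-crop methods usually rely on a RandomResizedCrop-style sampler: pick a random fraction of the image area and a random aspect ratio, then a random position. The sketch below samples one crop box under that assumption; MultiCropSampler and its parameters are illustrative, not taken from AiDotNet. In DINO's reference setup, global crops cover 40% to 100% of the image area and local crops 5% to 40%.

```csharp
using System;

// Hypothetical helper (not an AiDotNet API): RandomResizedCrop-style crop-box sampling.
public static class MultiCropSampler
{
    public static (int X, int Y, int Width, int Height) Sample(
        Random rng, int imageWidth, int imageHeight, double minScale, double maxScale)
    {
        // Fraction of the image area the crop should cover, uniform in [minScale, maxScale].
        double area = (double)imageWidth * imageHeight;
        double targetArea = area * (minScale + rng.NextDouble() * (maxScale - minScale));

        // Aspect ratio (width / height), log-uniform in [3/4, 4/3].
        double logMin = Math.Log(3.0 / 4.0), logMax = Math.Log(4.0 / 3.0);
        double aspect = Math.Exp(logMin + rng.NextDouble() * (logMax - logMin));

        // Crop size, clamped to the image bounds and to at least one pixel.
        int width = Math.Clamp((int)Math.Round(Math.Sqrt(targetArea * aspect)), 1, imageWidth);
        int height = Math.Clamp((int)Math.Round(Math.Sqrt(targetArea / aspect)), 1, imageHeight);

        // Random top-left corner that keeps the crop inside the image.
        int x = rng.Next(0, imageWidth - width + 1);
        int y = rng.Next(0, imageHeight - height + 1);
        return (x, y, width, height);
    }
}
```

A global crop would then call Sample(rng, w, h, 0.4, 1.0) and a local crop Sample(rng, w, h, 0.05, 0.4); each box would be cut out and resized (to 224 and 96 pixels respectively in DINO).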
ForwardStudent(Tensor<T>)
Performs a forward pass through the student network.
protected virtual Tensor<T> ForwardStudent(Tensor<T> view)
Parameters
view (Tensor<T>): The augmented view to encode.
Returns
Tensor<T>: The student network's output for the view.
ForwardTeacher(Tensor<T>)
Performs a forward pass through the teacher network (no gradients are computed).
protected virtual Tensor<T> ForwardTeacher(Tensor<T> view)
Parameters
view (Tensor<T>): The augmented view to encode.
Returns
Tensor<T>: The teacher network's output for the view.
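ForwardStudent and ForwardTeacher only produce outputs; the comparison between them happens in the derived method's loss. In DINO, the teacher output is centered and sharpened with a low temperature, the student output is softened with a higher one, and the loss is the cross-entropy between the two distributions. The sketch below computes that loss for one view pair, assuming the temperatures used in DINO (0.04 for the teacher, 0.1 for the student). DinoStyleLoss is a hypothetical helper operating on double arrays; derived classes such as iBOT add further terms.

```csharp
using System;

// Hypothetical helper (not an AiDotNet API) sketching a DINO-style loss on double arrays.
public static class DinoStyleLoss
{
    // Numerically stable log-softmax of logits / temperature.
    public static double[] LogSoftmax(double[] logits, double temperature)
    {
        var scaled = new double[logits.Length];
        double max = double.NegativeInfinity;
        for (int i = 0; i < logits.Length; i++)
        {
            scaled[i] = logits[i] / temperature;
            max = Math.Max(max, scaled[i]);
        }
        double sumExp = 0.0;
        foreach (double s in scaled) sumExp += Math.Exp(s - max);
        double logSum = max + Math.Log(sumExp);
        for (int i = 0; i < scaled.Length; i++) scaled[i] -= logSum;
        return scaled;
    }

    // Cross-entropy H(P_teacher, P_student) for one (teacher view, student view) pair.
    // In training the teacher side is a constant target: no gradients flow into it.
    public static double CrossEntropy(double[] teacherOut, double[] center, double[] studentOut,
                                      double teacherTemp = 0.04, double studentTemp = 0.1)
    {
        // Center the teacher output, then sharpen it with the lower temperature.
        var centered = new double[teacherOut.Length];
        for (int i = 0; i < teacherOut.Length; i++) centered[i] = teacherOut[i] - center[i];
        double[] logTeacher = LogSoftmax(centered, teacherTemp);
        double[] logStudent = LogSoftmax(studentOut, studentTemp);

        double loss = 0.0;
        for (int i = 0; i < logTeacher.Length; i++)
            loss -= Math.Exp(logTeacher[i]) * logStudent[i];
        return loss;
    }
}
```

With all-zero outputs over three dimensions both distributions are uniform and the loss equals ln 3; a student that agrees with a confident teacher gets a much smaller loss than one that disagrees.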
GetAdditionalParameterCount()
Gets the count of additional parameters.
protected override int GetAdditionalParameterCount()
Returns
int: The number of additional parameters.
GetAdditionalParameters()
Gets additional parameters specific to this SSL method.
protected override Vector<T>? GetAdditionalParameters()
Returns
Vector<T>?: Additional parameters, or null if none.
OnEpochStart(int)
Signals the start of a new epoch.
public override void OnEpochStart(int epochNumber)
Parameters
epochNumber (int): The current epoch number.
UpdateStudent(T)
Updates the student network's parameters using the computed gradients.
protected virtual void UpdateStudent(T learningRate)
Parameters
learningRate (T): The learning rate (step size) applied to the gradients.
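UpdateStudent takes only a learning rate, so the optimizer it applies is not visible from the signature. The simplest instance is plain gradient descent, p ← p - learningRate * g, sketched here on double arrays as an assumption. StudentUpdate.SgdStep is a hypothetical helper; practical teacher-student training often uses momentum SGD or AdamW instead.

```csharp
using System;

// Hypothetical helper (not an AiDotNet API): one plain gradient-descent step.
public static class StudentUpdate
{
    // In-place update: parameters[i] <- parameters[i] - learningRate * gradients[i].
    public static void SgdStep(double[] parameters, double[] gradients, double learningRate)
    {
        if (parameters.Length != gradients.Length)
            throw new ArgumentException("Parameter and gradient lengths must match.");

        for (int i = 0; i < parameters.Length; i++)
            parameters[i] -= learningRate * gradients[i];
    }
}
```

Only the student is updated this way; the teacher receives no gradient step and instead follows the student through the EMA in UpdateTeacher.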
UpdateTeacher()
Updates the teacher network as an exponential moving average (EMA) of the student.
protected virtual void UpdateTeacher()
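UpdateTeacher applies the exponential moving average described in the class remarks: each teacher parameter moves a small step toward the corresponding student parameter, θ_teacher ← m * θ_teacher + (1 - m) * θ_student. The following is a minimal sketch on double arrays, assuming the momentum m is supplied by the caller (for example BaseMomentum or a scheduled value derived from it). TeacherEma is a hypothetical helper, not part of AiDotNet.

```csharp
using System;

// Hypothetical helper (not an AiDotNet API): EMA update of teacher parameters.
public static class TeacherEma
{
    // In-place: teacher[i] <- momentum * teacher[i] + (1 - momentum) * student[i].
    public static void Update(double[] teacher, double[] student, double momentum)
    {
        if (teacher.Length != student.Length)
            throw new ArgumentException("Teacher and student must have the same number of parameters.");
        if (momentum < 0.0 || momentum > 1.0)
            throw new ArgumentOutOfRangeException(nameof(momentum));

        for (int i = 0; i < teacher.Length; i++)
            teacher[i] = momentum * teacher[i] + (1.0 - momentum) * student[i];
    }
}
```

A momentum close to 1 keeps the teacher stable, since each call moves it only slightly; a momentum of exactly 1 freezes it.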