
Class MoCoV3<T>

Namespace
AiDotNet.SelfSupervisedLearning
Assembly
AiDotNet.dll

MoCo v3: An Empirical Study of Training Self-Supervised Vision Transformers.

public class MoCoV3<T> : SSLMethodBase<T>, ISSLMethod<T>

Type Parameters

T

The numeric type used for computations.

Inheritance
SSLMethodBase<T>
MoCoV3<T>
Implements
ISSLMethod<T>

Remarks

For Beginners: MoCo v3 adapts momentum contrastive learning to Vision Transformers (ViT), although the same recipe also works with convolutional encoders such as ResNet. It simplifies the framework by removing the memory queue and instead using in-batch negatives with a symmetric loss.

Key changes from MoCo v1/v2:

  • No memory queue: Uses in-batch negatives, as SimCLR does
  • Symmetric loss: Each view serves as both a query and a key
  • Prediction head: Adds a predictor MLP on the online (query) branch only
  • ViT optimizations: A frozen, randomly initialized patch projection layer; no BN in MLP heads
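The symmetric loss with in-batch negatives can be sketched as follows. This is a minimal illustration, not the AiDotNet implementation: the class and method names are hypothetical, embeddings are plain double[][] arrays (one row per sample), and the temperature tau is an assumed hyperparameter. q1 and q2 stand for the online-branch outputs (encoder, projector, predictor) for the two views; k1 and k2 stand for the momentum-branch outputs.

```csharp
using System;
using System.Linq;

// Hypothetical helper class for illustration; not part of AiDotNet.
public static class SymmetricContrastiveSketch
{
    // L2-normalizes each row so dot products become cosine similarities.
    public static double[][] Normalize(double[][] x) =>
        x.Select(row =>
        {
            double norm = Math.Sqrt(row.Sum(v => v * v));
            return row.Select(v => v / Math.Max(norm, 1e-12)).ToArray();
        }).ToArray();

    // InfoNCE with in-batch negatives: row i of k is the positive for row i of q,
    // and every other row of k in the batch acts as a negative.
    public static double Contrastive(double[][] q, double[][] k, double tau)
    {
        double[][] qn = Normalize(q);
        double[][] kn = Normalize(k);
        int n = qn.Length;
        double total = 0.0;
        for (int i = 0; i < n; i++)
        {
            // Logits for query i against every key in the batch.
            double[] logits = new double[n];
            for (int j = 0; j < n; j++)
                logits[j] = qn[i].Zip(kn[j], (a, b) => a * b).Sum() / tau;

            // Numerically stable negative log-softmax at the positive index i.
            double max = logits.Max();
            double logSumExp = max + Math.Log(logits.Sum(l => Math.Exp(l - max)));
            total += logSumExp - logits[i];
        }
        return total / n;
    }

    // Symmetric loss: each view's queries are contrasted against the other view's keys.
    public static double Symmetric(double[][] q1, double[][] q2, double[][] k1, double[][] k2, double tau) =>
        Contrastive(q1, k2, tau) + Contrastive(q2, k1, tau);
}
```

With a batch of N samples, each query has one positive and N - 1 negatives drawn from the same batch, which is why no queue is required.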

Training stability for ViT:

  • Lower learning rates and careful initialization
  • Learning-rate warmup and gradient clipping
  • Momentum encoder still provides stable targets
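Warmup and gradient clipping are standard techniques; a minimal sketch under assumed settings follows. The schedule shape (linear warmup, then half-cycle cosine decay), the per-epoch granularity, and all names are illustrative assumptions, not values taken from AiDotNet or the paper.

```csharp
using System;
using System.Linq;

// Hypothetical helpers for illustration; not part of AiDotNet.
public static class StabilitySketch
{
    // Linear warmup to baseLr over warmupEpochs, then half-cycle cosine decay toward 0.
    public static double LearningRate(double baseLr, int epoch, int warmupEpochs, int totalEpochs)
    {
        if (epoch < warmupEpochs)
            return baseLr * (epoch + 1) / warmupEpochs;
        double progress = (double)(epoch - warmupEpochs) / Math.Max(1, totalEpochs - warmupEpochs);
        return baseLr * 0.5 * (1.0 + Math.Cos(Math.PI * progress));
    }

    // Rescales the gradient in place so its global L2 norm is at most maxNorm.
    // Returns the norm measured before clipping.
    public static double ClipByGlobalNorm(double[] grad, double maxNorm)
    {
        double norm = Math.Sqrt(grad.Sum(g => g * g));
        if (norm > maxNorm)
        {
            double scale = maxNorm / norm;
            for (int i = 0; i < grad.Length; i++) grad[i] *= scale;
        }
        return norm;
    }
}
```

Warmup keeps early updates small while the randomly initialized ViT is most fragile; clipping bounds any single step that would otherwise spike.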

Reference: Chen et al., "An Empirical Study of Training Self-Supervised Vision Transformers" (ICCV 2021)

Constructors

MoCoV3(INeuralNetwork<T>, IMomentumEncoder<T>, IProjectorHead<T>, IProjectorHead<T>, IProjectorHead<T>?, SSLConfig?)

Initializes a new instance of the MoCoV3 class.

public MoCoV3(INeuralNetwork<T> encoder, IMomentumEncoder<T> momentumEncoder, IProjectorHead<T> projector, IProjectorHead<T> momentumProjector, IProjectorHead<T>? predictor = null, SSLConfig? config = null)

Parameters

encoder INeuralNetwork<T>

The online encoder network (ViT recommended).

momentumEncoder IMomentumEncoder<T>

The momentum encoder, a slowly updated copy of the online encoder that produces the keys.

projector IProjectorHead<T>

Projection head for the online encoder.

momentumProjector IProjectorHead<T>

Projection head for the momentum encoder.

predictor IProjectorHead<T>

Optional predictor head, applied to the online branch only. Defaults to null.

config SSLConfig

Optional SSL configuration.

Properties

Category

Gets the category of this SSL method.

public override SSLMethodCategory Category { get; }

Property Value

SSLMethodCategory

Remarks

Categories include Contrastive, NonContrastive, Generative, and SelfDistillation.

Name

Gets the name of this SSL method.

public override string Name { get; }

Property Value

string

Remarks

Examples: "SimCLR", "MoCo v2", "BYOL", "DINO", "MAE"

RequiresMemoryBank

Indicates whether this method requires a memory bank for negative samples.

public override bool RequiresMemoryBank { get; }

Property Value

bool

Remarks

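For Beginners: Memory banks store embeddings from previous batches to use as negative samples in contrastive learning. MoCo v1 and v2 use one (a queue of keys); SimCLR does not. MoCo v3 also drops the queue and draws its negatives from the current batch instead.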
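For contrast, a MoCo v1/v2-style memory bank is typically a fixed-size FIFO queue of key embeddings. The sketch below is a hypothetical illustration of that idea, not an AiDotNet type; MoCo v3 does not need it because its negatives come from the current batch.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical FIFO memory bank in the style of MoCo v1/v2; not part of AiDotNet.
public sealed class KeyQueueSketch
{
    private readonly Queue<double[]> _keys = new Queue<double[]>();
    private readonly int _capacity;

    public KeyQueueSketch(int capacity)
    {
        if (capacity <= 0) throw new ArgumentOutOfRangeException(nameof(capacity));
        _capacity = capacity;
    }

    public int Count => _keys.Count;

    // Adds this batch's keys and evicts the oldest entries beyond capacity.
    public void Enqueue(IEnumerable<double[]> batchKeys)
    {
        foreach (double[] key in batchKeys)
        {
            _keys.Enqueue(key);
            if (_keys.Count > _capacity) _keys.Dequeue();
        }
    }

    // Snapshot of the stored keys (oldest first), used as extra negatives for the next batch.
    public double[][] Negatives() => _keys.ToArray();
}
```

The queue decouples the number of negatives from the batch size, at the cost of keys that were computed by slightly older encoder weights.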

UsesMomentumEncoder

Indicates whether this method uses a momentum-updated encoder.

public override bool UsesMomentumEncoder { get; }

Property Value

bool

Remarks

For Beginners: A momentum encoder is a slowly-updated copy of the main encoder. Methods like MoCo, BYOL, and DINO use this to provide stable targets.
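The slow update is an exponential moving average (EMA) of the online weights: theta_k <- m * theta_k + (1 - m) * theta_q, with m close to 1. The sketch below applies it to flat parameter arrays; the flat-array representation and the names are assumptions for illustration only.

```csharp
using System;

// Hypothetical helper for illustration; not part of AiDotNet.
public static class MomentumUpdateSketch
{
    // In-place EMA: theta_k <- m * theta_k + (1 - m) * theta_q.
    // The momentum parameters receive no gradients; they only track the online weights.
    public static void Update(double[] momentumParams, double[] onlineParams, double m)
    {
        if (momentumParams.Length != onlineParams.Length)
            throw new ArgumentException("Parameter vectors must have the same length.");
        if (m < 0.0 || m > 1.0)
            throw new ArgumentOutOfRangeException(nameof(m));
        for (int i = 0; i < momentumParams.Length; i++)
            momentumParams[i] = m * momentumParams[i] + (1.0 - m) * onlineParams[i];
    }
}
```

A larger m makes the targets change more slowly and therefore more stably; m = 1 would freeze the momentum encoder entirely.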

Methods

OnEpochStart(int)

Signals the start of a new epoch.

public override void OnEpochStart(int epochNumber)

Parameters

epochNumber int

The current epoch number.
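This page does not document what OnEpochStart does for MoCoV3. One plausible use of a per-epoch hook, sketched here purely as an assumption, is to raise the momentum coefficient m toward 1 over training with a cosine ramp, as BYOL does. The 0-based epoch numbering, the per-epoch (rather than per-iteration) granularity, and the names are hypothetical.

```csharp
using System;

// Hypothetical helper for illustration; not part of AiDotNet.
public static class MomentumScheduleSketch
{
    // BYOL-style cosine ramp: m(e) = 1 - (1 - baseM) * (cos(pi * e / total) + 1) / 2.
    // Returns baseM at epoch 0 and 1.0 at the final epoch.
    public static double MomentumForEpoch(double baseM, int epochNumber, int totalEpochs)
    {
        if (totalEpochs <= 0) throw new ArgumentOutOfRangeException(nameof(totalEpochs));
        double progress = Math.Min(1.0, (double)epochNumber / totalEpochs);
        return 1.0 - (1.0 - baseM) * (Math.Cos(Math.PI * progress) + 1.0) / 2.0;
    }
}
```

Ramping m toward 1 lets the targets follow the online encoder closely early on and then settle as training converges.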

TrainStepCore(Tensor<T>, SSLAugmentationContext<T>?)

Implementation-specific training step logic.

protected override SSLStepResult<T> TrainStepCore(Tensor<T> batch, SSLAugmentationContext<T>? augmentationContext)

Parameters

batch Tensor<T>

The input batch tensor.

augmentationContext SSLAugmentationContext<T>

Optional augmentation context.

Returns

SSLStepResult<T>

The result of the training step.
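The data flow of a single step can be outlined as below. This is a hedged sketch, not the actual TrainStepCore: delegates stand in for the encoder, projector, predictor, optimizer, and EMA update, and every name is hypothetical. The ordering (queries, then the momentum update, then keys, loss, and optimizer step) mirrors the public MoCo v3 reference code, where the momentum update runs just before the keys are computed; AiDotNet's exact ordering is not documented on this page.

```csharp
using System;

// Hypothetical outline of one MoCo v3 training step; none of these names are AiDotNet APIs.
public static class TrainStepSketch
{
    public static double Step(
        double[][] view1,
        double[][] view2,
        Func<double[][], double[][]> onlineBranch,         // encoder -> projector -> predictor
        Func<double[][], double[][]> momentumBranch,       // momentum encoder -> momentum projector
        Func<double[][], double[][], double> contrastive,  // ctr(queries, keys) with in-batch negatives
        Action<double> backpropAndOptimize,                // gradient step on the online branch only
        Action updateMomentumEncoder)                      // EMA update of the momentum weights
    {
        // 1. Queries for both augmented views from the online branch.
        double[][] q1 = onlineBranch(view1);
        double[][] q2 = onlineBranch(view2);

        // 2. Move the momentum weights toward the online weights (no gradients).
        updateMomentumEncoder();

        // 3. Keys from the momentum branch; a real trainer treats these as constants.
        double[][] k1 = momentumBranch(view1);
        double[][] k2 = momentumBranch(view2);

        // 4. Symmetric loss: each view's queries are scored against the other view's keys.
        double loss = contrastive(q1, k2) + contrastive(q2, k1);

        // 5. Backpropagate and update the online encoder, projector, and predictor.
        backpropAndOptimize(loss);
        return loss;
    }
}
```

Because the keys come from the momentum branch, a real implementation stops gradients there; only the online encoder, projector, and predictor are changed by the optimizer.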