Class MoCoV3<T>
- Namespace
- AiDotNet.SelfSupervisedLearning
- Assembly
- AiDotNet.dll
MoCo v3: An Empirical Study of Training Self-Supervised Vision Transformers.
public class MoCoV3<T> : SSLMethodBase<T>, ISSLMethod<T>
Type Parameters
T: The numeric type used for computations.
- Inheritance
- object → SSLMethodBase<T> → MoCoV3<T>
- Implements
- ISSLMethod<T>
Remarks
For Beginners: MoCo v3 adapts momentum contrastive learning specifically for Vision Transformers (ViT). It simplifies the framework by removing the memory queue and using in-batch negatives with a symmetric loss.
Key changes from MoCo v1/v2:
- No memory queue: Uses in-batch negatives like SimCLR
- Symmetric loss: Both views serve as queries and keys
- Prediction head: Adds a predictor MLP on one branch
- ViT optimizations: Frozen random patch projection (the patch embedding layer stays at its random initialization), no batch normalization (BN) in MLP heads
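The symmetric, queue-free loss described above can be sketched in a few lines. This is a minimal illustration on plain arrays, not the code inside MoCoV3<T>: the class name SymmetricContrastiveSketch, its method names, and the use of double[][] instead of Tensor<T> are all assumptions made for the example. Here q1 and q2 stand for the predictor outputs of the online branch for the two views, and k1 and k2 for the momentum-branch outputs. The paper's extra temperature scaling of the loss is omitted.

```csharp
using System;

// Hypothetical helper for illustration only; not part of AiDotNet.
public static class SymmetricContrastiveSketch
{
    // L2-normalize each row so that dot products become cosine similarities.
    private static double[][] Normalize(double[][] rows)
    {
        var result = new double[rows.Length][];
        for (int i = 0; i < rows.Length; i++)
        {
            double sumSq = 0.0;
            foreach (double v in rows[i]) sumSq += v * v;
            double norm = Math.Sqrt(sumSq) + 1e-12;
            result[i] = new double[rows[i].Length];
            for (int d = 0; d < rows[i].Length; d++) result[i][d] = rows[i][d] / norm;
        }
        return result;
    }

    // InfoNCE with in-batch negatives: for query i, key i is the positive
    // and every other key in the batch is a negative.
    public static double InfoNce(double[][] queries, double[][] keys, double temperature)
    {
        var q = Normalize(queries);
        var k = Normalize(keys);
        int n = q.Length;
        double total = 0.0;
        for (int i = 0; i < n; i++)
        {
            var logits = new double[n];
            double max = double.NegativeInfinity;
            for (int j = 0; j < n; j++)
            {
                double dot = 0.0;
                for (int d = 0; d < q[i].Length; d++) dot += q[i][d] * k[j][d];
                logits[j] = dot / temperature;
                max = Math.Max(max, logits[j]);
            }
            // Numerically stable log-softmax of the positive logit.
            double sumExp = 0.0;
            for (int j = 0; j < n; j++) sumExp += Math.Exp(logits[j] - max);
            total += -(logits[i] - max - Math.Log(sumExp));
        }
        return total / n;
    }

    // Symmetric loss: each view's queries are matched against the other view's keys.
    public static double SymmetricLoss(
        double[][] q1, double[][] k1, double[][] q2, double[][] k2, double temperature)
        => InfoNce(q1, k2, temperature) + InfoNce(q2, k1, temperature);
}
```

Because the negatives come from the current batch, the number of negatives per query is the batch size minus one, which is one reason this style of training tends to favor large batches.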
Training stability for ViT:
- Lower learning rates and careful weight initialization
- Gradient clipping and a gradual learning-rate warmup
- The momentum encoder still provides stable targets
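Two of the stabilizers listed above, learning-rate warmup and gradient clipping, are easy to illustrate in isolation. The sketch below is an assumption-laden example rather than the library's implementation: TrainingStabilitySketch and its methods are hypothetical names, the cosine decay after warmup is a common pairing chosen for the example, and gradients are represented as a flat double[].

```csharp
using System;

// Hypothetical helper for illustration only; not part of AiDotNet.
public static class TrainingStabilitySketch
{
    // Linear warmup to baseLr over warmupSteps, then cosine decay to zero at totalSteps.
    public static double LearningRate(int step, int warmupSteps, int totalSteps, double baseLr)
    {
        if (step < warmupSteps)
            return baseLr * (step + 1) / warmupSteps;

        double progress = (double)(step - warmupSteps) / Math.Max(1, totalSteps - warmupSteps);
        return 0.5 * baseLr * (1.0 + Math.Cos(Math.PI * Math.Min(1.0, progress)));
    }

    // Rescale the gradient vector in place so its L2 norm does not exceed maxNorm.
    public static void ClipByGlobalNorm(double[] gradients, double maxNorm)
    {
        double sumSq = 0.0;
        foreach (double g in gradients) sumSq += g * g;
        double norm = Math.Sqrt(sumSq);
        if (norm <= maxNorm || norm == 0.0) return;

        double scale = maxNorm / norm;
        for (int i = 0; i < gradients.Length; i++) gradients[i] *= scale;
    }
}
```

Warmup keeps early updates small while the randomly initialized network is most sensitive, and clipping bounds the size of any single update if a gradient spike occurs later in training.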
Reference: Chen et al., "An Empirical Study of Training Self-Supervised Vision Transformers" (ICCV 2021)
Constructors
MoCoV3(INeuralNetwork<T>, IMomentumEncoder<T>, IProjectorHead<T>, IProjectorHead<T>, IProjectorHead<T>?, SSLConfig?)
Initializes a new instance of the MoCoV3 class.
public MoCoV3(INeuralNetwork<T> encoder, IMomentumEncoder<T> momentumEncoder, IProjectorHead<T> projector, IProjectorHead<T> momentumProjector, IProjectorHead<T>? predictor = null, SSLConfig? config = null)
Parameters
encoder (INeuralNetwork<T>): The online encoder network (ViT recommended).
momentumEncoder (IMomentumEncoder<T>): The momentum encoder.
projector (IProjectorHead<T>): Projection head for the online encoder.
momentumProjector (IProjectorHead<T>): Projection head for the momentum encoder.
predictor (IProjectorHead<T>?): Optional predictor head, applied to the online branch only. Defaults to null.
config (SSLConfig?): Optional SSL configuration. Defaults to null.
Properties
Category
Gets the category of this SSL method.
public override SSLMethodCategory Category { get; }
Property Value
SSLMethodCategory
Remarks
Categories include Contrastive, NonContrastive, Generative, and SelfDistillation.
Name
Gets the name of this SSL method.
public override string Name { get; }
Property Value
string
Remarks
Examples: "SimCLR", "MoCo v2", "BYOL", "DINO", "MAE"
RequiresMemoryBank
Indicates whether this method requires a memory bank for negative samples.
public override bool RequiresMemoryBank { get; }
Property Value
bool
Remarks
For Beginners: Memory banks store embeddings from previous batches to use as negative samples in contrastive learning. MoCo v1/v2 use one; SimCLR and MoCo v3 rely on in-batch negatives instead.
UsesMomentumEncoder
Indicates whether this method uses a momentum-updated encoder.
public override bool UsesMomentumEncoder { get; }
Property Value
bool
Remarks
For Beginners: A momentum encoder is a slowly-updated copy of the main encoder. Methods like MoCo, BYOL, and DINO use this to provide stable targets.
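The "slowly-updated copy" is usually maintained as an exponential moving average (EMA) of the online encoder's weights. The sketch below shows that update on flat parameter arrays. It is an illustration under assumptions, not the IMomentumEncoder<T> implementation: MomentumUpdateSketch is a hypothetical name, and the momentum coefficient m (for example 0.99) is a value chosen for the example.

```csharp
using System;

// Hypothetical helper for illustration only; not part of AiDotNet.
public static class MomentumUpdateSketch
{
    // EMA update: momentum = m * momentum + (1 - m) * online.
    // A value of m close to 1 makes the momentum encoder change slowly.
    public static void Update(double[] momentumParams, double[] onlineParams, double m)
    {
        if (momentumParams.Length != onlineParams.Length)
            throw new ArgumentException("Parameter vectors must have the same length.");

        for (int i = 0; i < momentumParams.Length; i++)
            momentumParams[i] = m * momentumParams[i] + (1.0 - m) * onlineParams[i];
    }
}
```

No gradients flow into the momentum parameters; they change only through this update, which is why they provide targets that drift slowly while the online encoder is still learning quickly.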
Methods
OnEpochStart(int)
Signals the start of a new epoch.
public override void OnEpochStart(int epochNumber)
Parameters
epochNumber (int): The current epoch number.
TrainStepCore(Tensor<T>, SSLAugmentationContext<T>?)
Implementation-specific training step logic.
protected override SSLStepResult<T> TrainStepCore(Tensor<T> batch, SSLAugmentationContext<T>? augmentationContext)
Parameters
batch (Tensor<T>): The input batch tensor.
augmentationContext (SSLAugmentationContext<T>?): Optional augmentation context.
Returns
- SSLStepResult<T>
The result of the training step.