Class MoCo<T>
Namespace: AiDotNet.SelfSupervisedLearning
Assembly: AiDotNet.dll
MoCo: Momentum Contrast for Unsupervised Visual Representation Learning.
public class MoCo<T> : SSLMethodBase<T>, ISSLMethod<T>
Type Parameters
T: The numeric type used for computations.
Inheritance: SSLMethodBase<T> → MoCo<T>
Implements: ISSLMethod<T>
Remarks
For Beginners: MoCo is a contrastive learning method that uses a momentum encoder and a memory queue to provide a large pool of consistent negative samples without requiring huge batch sizes.
Key innovations:
- Momentum Encoder: A slowly-updating copy of the encoder for consistent keys
- Memory Queue: Stores past embeddings as negative samples (65536 by default)
- Asymmetric design: Query from main encoder, keys from momentum encoder
How MoCo works:
- Pass the query image through the online encoder → query q
- Pass the key image through the momentum encoder → positive key k+
- Take the negative keys k- from the memory queue
- Compute the InfoNCE loss, which pulls q toward k+ and pushes it away from every k-
- Update the momentum encoder as an exponential moving average (EMA) of the online encoder's weights
- Enqueue the new keys and dequeue the oldest ones
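The loss in the steps above can be illustrated with a minimal, self-contained sketch. It is not the AiDotNet implementation: the class name `InfoNceSketch`, the use of plain `double[]` vectors instead of `Tensor<T>`, and the temperature default of 0.07 (the value reported in the MoCo paper) are assumptions made for illustration. All embeddings are assumed to be L2-normalized.

```csharp
using System;
using System.Linq;

// Hypothetical helper for illustration only; not part of AiDotNet.
public static class InfoNceSketch
{
    // Dot product of two equal-length vectors.
    private static double Dot(double[] a, double[] b)
    {
        double sum = 0.0;
        for (int i = 0; i < a.Length; i++)
        {
            sum += a[i] * b[i];
        }
        return sum;
    }

    // InfoNCE loss for one query q, its positive key k+, and the negatives k-
    // taken from the queue: -log( exp(q.k+/t) / (exp(q.k+/t) + sum exp(q.k-/t)) ).
    public static double Loss(
        double[] query,
        double[] positiveKey,
        double[][] negativeKeys,
        double temperature = 0.07) // assumed default, taken from the paper
    {
        double positiveLogit = Dot(query, positiveKey) / temperature;
        double[] negativeLogits = negativeKeys
            .Select(k => Dot(query, k) / temperature)
            .ToArray();

        // Subtract the largest logit before exponentiating for numerical stability.
        double max = negativeLogits.Aggregate(positiveLogit, Math.Max);
        double sumExp = Math.Exp(positiveLogit - max)
                        + negativeLogits.Sum(l => Math.Exp(l - max));

        return -(positiveLogit - max - Math.Log(sumExp));
    }
}
```

The loss is small when the query is most similar to its positive key and grows when a negative key from the queue is more similar to the query than the positive key is.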
Reference: He et al., "Momentum Contrast for Unsupervised Visual Representation Learning" (CVPR 2020)
Constructors
MoCo(INeuralNetwork<T>, IMomentumEncoder<T>, IProjectorHead<T>?, IProjectorHead<T>?, int, SSLConfig?)
Initializes a new instance of the MoCo class.
public MoCo(INeuralNetwork<T> encoder, IMomentumEncoder<T> momentumEncoder, IProjectorHead<T>? projector = null, IProjectorHead<T>? momentumProjector = null, int embeddingDim = 128, SSLConfig? config = null)
Parameters
encoder (INeuralNetwork<T>): The online encoder network.
momentumEncoder (IMomentumEncoder<T>): The momentum encoder (copy of main encoder).
projector (IProjectorHead<T>): Optional projection head for online encoder.
momentumProjector (IProjectorHead<T>): Optional projection head for momentum encoder.
embeddingDim (int): Dimension of embeddings for memory bank.
config (SSLConfig): Optional SSL configuration.
Properties
Category
Gets the category of this SSL method.
public override SSLMethodCategory Category { get; }
Property Value
- SSLMethodCategory
Remarks
Categories include Contrastive, NonContrastive, Generative, and SelfDistillation.
MemoryBank
Gets the memory bank used for negative samples.
public IMemoryBank<T> MemoryBank { get; }
Property Value
- IMemoryBank<T>
Name
Gets the name of this SSL method.
public override string Name { get; }
Property Value
- string
Remarks
Examples: "SimCLR", "MoCo v2", "BYOL", "DINO", "MAE"
RequiresMemoryBank
Indicates whether this method requires a memory bank for negative samples.
public override bool RequiresMemoryBank { get; }
Property Value
- bool
Remarks
For Beginners: Memory banks store embeddings from previous batches for reuse as negative samples in contrastive learning. MoCo uses one; SimCLR does not.
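The memory-bank idea can be sketched as a fixed-capacity first-in, first-out queue. This is a minimal illustration, not the `IMemoryBank<T>` implementation: the class `EmbeddingQueueSketch` and its members are hypothetical, and embeddings are plain `double[]` vectors.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical FIFO queue of embeddings for illustration only; not part of AiDotNet.
public sealed class EmbeddingQueueSketch
{
    private readonly Queue<double[]> _items = new Queue<double[]>();

    public EmbeddingQueueSketch(int capacity)
    {
        if (capacity <= 0)
        {
            throw new ArgumentOutOfRangeException(nameof(capacity));
        }
        Capacity = capacity;
    }

    // Maximum number of stored embeddings (MoCo's queue holds 65536).
    public int Capacity { get; }

    // Number of embeddings currently stored.
    public int Count => _items.Count;

    // Adds the keys of the current batch, evicting the oldest entries when full.
    public void Enqueue(IEnumerable<double[]> keys)
    {
        foreach (double[] key in keys)
        {
            if (_items.Count == Capacity)
            {
                _items.Dequeue();
            }
            _items.Enqueue(key);
        }
    }

    // Returns the stored embeddings, oldest first, for use as negatives.
    public double[][] Snapshot() => _items.ToArray();
}
```

Because the queue is decoupled from the batch, the number of negatives is set by `Capacity` rather than by the batch size.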
UsesMomentumEncoder
Indicates whether this method uses a momentum-updated encoder.
public override bool UsesMomentumEncoder { get; }
Property Value
- bool
Remarks
For Beginners: A momentum encoder is a slowly-updated copy of the main encoder. Methods like MoCo, BYOL, and DINO use this to provide stable targets.
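The "slowly-updated copy" can be made concrete with a sketch of the exponential moving average (EMA) rule. It is not the `IMomentumEncoder<T>` implementation: `MomentumUpdateSketch` is a hypothetical helper, parameters are flattened into `double[]` arrays, and the momentum default of 0.999 is the value reported in the MoCo paper rather than a documented default of this class.

```csharp
using System;

// Hypothetical helper for illustration only; not part of AiDotNet.
public static class MomentumUpdateSketch
{
    // Updates the momentum (key) encoder parameters in place:
    //   key = momentum * key + (1 - momentum) * online
    public static void Update(
        double[] keyParameters,
        double[] onlineParameters,
        double momentum = 0.999) // assumed default, taken from the paper
    {
        if (keyParameters.Length != onlineParameters.Length)
        {
            throw new ArgumentException("Parameter vectors must have the same length.");
        }

        for (int i = 0; i < keyParameters.Length; i++)
        {
            keyParameters[i] = momentum * keyParameters[i]
                               + (1.0 - momentum) * onlineParameters[i];
        }
    }
}
```

A momentum close to 1 means the key encoder changes only slightly per step, which keeps the keys already stored in the queue consistent with newly computed ones.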
Methods
GetAdditionalParameterCount()
Gets the count of additional parameters.
protected override int GetAdditionalParameterCount()
Returns
- int
The number of additional parameters.
GetAdditionalParameters()
Gets additional parameters specific to this SSL method.
protected override Vector<T>? GetAdditionalParameters()
Returns
- Vector<T>
Additional parameters, or null if none.
Reset()
Resets the SSL method to its initial state.
public override void Reset()
Remarks
This clears accumulated state such as memory banks and running statistics, and resets the momentum encoder if present.
TrainStepCore(Tensor<T>, SSLAugmentationContext<T>?)
Implementation-specific training step logic.
protected override SSLStepResult<T> TrainStepCore(Tensor<T> batch, SSLAugmentationContext<T>? augmentationContext)
Parameters
batch (Tensor<T>): The input batch tensor.
augmentationContext (SSLAugmentationContext<T>): Optional augmentation context.
Returns
- SSLStepResult<T>
The result of the training step.