Class MoCo<T>

Namespace
AiDotNet.SelfSupervisedLearning
Assembly
AiDotNet.dll

MoCo: Momentum Contrast for Unsupervised Visual Representation Learning.

public class MoCo<T> : SSLMethodBase<T>, ISSLMethod<T>

Type Parameters

T

The numeric type used for computations.

Inheritance
SSLMethodBase<T>
MoCo<T>
Implements
ISSLMethod<T>

Remarks

For Beginners: MoCo is a contrastive learning method that uses a momentum encoder and a memory queue to provide a large pool of consistent negative samples without requiring huge batch sizes.

Key innovations:

  • Momentum Encoder: A slowly-updating copy of the encoder for consistent keys
  • Memory Queue: Stores past embeddings as negative samples (65536 by default)
  • Asymmetric design: Query from main encoder, keys from momentum encoder

How MoCo works:

  1. Pass query image through online encoder → query q
  2. Pass key image through momentum encoder → positive key k+
  3. Get negative keys k- from memory queue
  4. Compute InfoNCE loss: pull q closer to k+, push away from k-
  5. Update momentum encoder with an exponential moving average (EMA) of the online encoder's weights
  6. Enqueue new keys, dequeue oldest
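
The loss in step 4 and the update in step 5 can be written out as follows, using the notation of the cited paper (He et al.). The temperature τ, the momentum coefficient m, and the parameter symbols θ_q (online encoder) and θ_k (momentum encoder) are the paper's notation, not names from this API, and this page does not state which values the class uses for τ and m.

```latex
% InfoNCE loss for one query q, its positive key k_+, and K queued negatives k_i^-
\mathcal{L}_q = -\log
  \frac{\exp(q \cdot k_+ / \tau)}
       {\exp(q \cdot k_+ / \tau) + \sum_{i=1}^{K} \exp(q \cdot k_i^- / \tau)}

% EMA update of the momentum encoder parameters, with m \in [0, 1)
\theta_k \leftarrow m\,\theta_k + (1 - m)\,\theta_q
```

The paper reports τ = 0.07 and m = 0.999 as its defaults. With m that close to 1, the keys already in the queue stay comparable to newly encoded ones.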

Reference: He et al., "Momentum Contrast for Unsupervised Visual Representation Learning" (CVPR 2020)

Constructors

MoCo(INeuralNetwork<T>, IMomentumEncoder<T>, IProjectorHead<T>?, IProjectorHead<T>?, int, SSLConfig?)

Initializes a new instance of the MoCo class.

public MoCo(INeuralNetwork<T> encoder, IMomentumEncoder<T> momentumEncoder, IProjectorHead<T>? projector = null, IProjectorHead<T>? momentumProjector = null, int embeddingDim = 128, SSLConfig? config = null)

Parameters

encoder INeuralNetwork<T>

The online encoder network.

momentumEncoder IMomentumEncoder<T>

The momentum encoder (a slowly updated copy of the online encoder).

projector IProjectorHead<T>

Optional projection head for online encoder.

momentumProjector IProjectorHead<T>

Optional projection head for momentum encoder.

embeddingDim int

Dimension of the embeddings stored in the memory bank. Defaults to 128.

config SSLConfig

Optional SSL configuration.

Properties

Category

Gets the category of this SSL method.

public override SSLMethodCategory Category { get; }

Property Value

SSLMethodCategory

Remarks

Categories include Contrastive, NonContrastive, Generative, and SelfDistillation.

MemoryBank

Gets the memory bank used for negative samples.

public IMemoryBank<T> MemoryBank { get; }

Property Value

IMemoryBank<T>

Name

Gets the name of this SSL method.

public override string Name { get; }

Property Value

string

Remarks

Examples: "SimCLR", "MoCo v2", "BYOL", "DINO", "MAE"

RequiresMemoryBank

Indicates whether this method requires a memory bank for negative samples.

public override bool RequiresMemoryBank { get; }

Property Value

bool

Remarks

For Beginners: Memory banks store embeddings from previous batches to use as negative samples in contrastive learning. MoCo uses one; SimCLR does not.
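
As a rough illustration of the idea, the sketch below keeps a fixed-capacity first-in, first-out store of embedding vectors. It is not the AiDotNet IMemoryBank<T> implementation: the class KeyQueueSketch and its members are hypothetical names, and plain double arrays stand in for tensors.

```csharp
using System;

// Illustrative only: a fixed-capacity FIFO of embedding vectors.
// Hypothetical type, not the AiDotNet IMemoryBank<T> implementation.
public sealed class KeyQueueSketch
{
    private readonly double[][] _slots;
    private int _next; // slot that the next Enqueue call overwrites

    public int Count { get; private set; }
    public int Capacity => _slots.Length;

    public KeyQueueSketch(int capacity)
    {
        if (capacity <= 0) throw new ArgumentOutOfRangeException(nameof(capacity));
        _slots = new double[capacity][];
    }

    // Stores a copy of the key. Once the queue is full, the oldest entry is overwritten.
    public void Enqueue(double[] key)
    {
        _slots[_next] = (double[])key.Clone();
        _next = (_next + 1) % _slots.Length;
        if (Count < _slots.Length) Count++;
    }

    // Returns the stored keys, oldest first, for use as negatives.
    public double[][] Snapshot()
    {
        var result = new double[Count][];
        int start = Count < _slots.Length ? 0 : _next;
        for (int i = 0; i < Count; i++)
        {
            result[i] = _slots[(start + i) % _slots.Length];
        }
        return result;
    }
}
```

Overwriting the oldest slot on each enqueue keeps memory use constant and ensures the negatives come from the most recent batches.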

UsesMomentumEncoder

Indicates whether this method uses a momentum-updated encoder.

public override bool UsesMomentumEncoder { get; }

Property Value

bool

Remarks

For Beginners: A momentum encoder is a slowly-updated copy of the main encoder. Methods like MoCo, BYOL, and DINO use this to provide stable targets.
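
A minimal sketch of the "slowly-updated copy" idea, assuming the weights are available as flat arrays. MomentumUpdateSketch and EmaUpdate are hypothetical names for illustration and are not part of the AiDotNet IMomentumEncoder<T> interface.

```csharp
using System;

// Illustrative only: plain arrays stand in for encoder weights.
// Hypothetical helper, not the AiDotNet IMomentumEncoder<T> implementation.
public static class MomentumUpdateSketch
{
    // Blends the online weights into the momentum weights in place:
    // momentum[i] = m * momentum[i] + (1 - m) * online[i]
    public static void EmaUpdate(double[] momentumWeights, double[] onlineWeights, double m)
    {
        if (momentumWeights.Length != onlineWeights.Length)
            throw new ArgumentException("Weight arrays must have the same length.");
        if (m < 0.0 || m > 1.0)
            throw new ArgumentOutOfRangeException(nameof(m), "Momentum must be in [0, 1].");

        for (int i = 0; i < momentumWeights.Length; i++)
        {
            momentumWeights[i] = m * momentumWeights[i] + (1.0 - m) * onlineWeights[i];
        }
    }
}
```

With m close to 1, each call moves the momentum weights only slightly toward the online weights, which is what makes the targets stable from step to step.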

Methods

GetAdditionalParameterCount()

Gets the count of additional parameters.

protected override int GetAdditionalParameterCount()

Returns

int

The number of additional parameters.

GetAdditionalParameters()

Gets additional parameters specific to this SSL method.

protected override Vector<T>? GetAdditionalParameters()

Returns

Vector<T>

Additional parameters, or null if none.

Reset()

Resets the SSL method to its initial state.

public override void Reset()

Remarks

This clears accumulated state such as memory banks and running statistics, and resets the momentum encoder if present.

TrainStepCore(Tensor<T>, SSLAugmentationContext<T>?)

Implementation-specific training step logic.

protected override SSLStepResult<T> TrainStepCore(Tensor<T> batch, SSLAugmentationContext<T>? augmentationContext)

Parameters

batch Tensor<T>

The input batch tensor.

augmentationContext SSLAugmentationContext<T>

Optional augmentation context.

Returns

SSLStepResult<T>

The result of the training step.