Class SimSiam<T>
- Namespace
- AiDotNet.SelfSupervisedLearning
- Assembly
- AiDotNet.dll
SimSiam: Exploring Simple Siamese Representation Learning.
public class SimSiam<T> : SSLMethodBase<T>, ISSLMethod<T>
Type Parameters
T: The numeric type used for computations.
- Inheritance
- SSLMethodBase<T>
- SimSiam<T>
- Implements
- ISSLMethod<T>
Remarks
For Beginners: SimSiam shows that simple Siamese networks can learn meaningful representations without negative pairs, a momentum encoder, or large batches. The key ingredient is the stop-gradient operation applied to one branch.
Key innovations:
- No negatives: Like BYOL, it does not need negative samples.
- No momentum encoder: Unlike BYOL, both branches share the same encoder weights.
- No large batches: It works with standard batch sizes (256 in the paper's default setup) and does not rely on very large batches.
- Stop-gradient: The key to preventing collapse
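Architecture (two augmented views x₁ and x₂ of the same image; all weights are shared):
Branch 1: x₁ → encoder → projector → z₁ → predictor → p₁
Branch 2: x₂ → encoder → projector → z₂ (stop-gradient)
The roles are then swapped, so p₂ is also compared with stopgrad(z₁).
Loss: L = ½·D(p₁, stopgrad(z₂)) + ½·D(p₂, stopgrad(z₁)), where D(p, z) = −(p/‖p‖₂)·(z/‖z‖₂) is the negative cosine similarity (the ½ weighting follows the paper).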
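Why it works: Without the stop-gradient, the network quickly collapses to a trivial solution in which every input maps to the same constant vector. Treating one branch's output as a fixed target blocks that shortcut. The predictor makes one branch "predict" the other, which produces useful gradients for learning.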
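The symmetric loss can be sketched on plain arrays. This is a minimal, hypothetical illustration. The class `SimSiamLossSketch` and its methods are not part of AiDotNet. It uses `double[]` vectors and weights each term by ½, as in the paper. Because there is no automatic differentiation here, the stop-gradient has no runtime effect. In a real training loop, z₁ and z₂ would be treated as constants when gradients are computed.

```csharp
using System;

// Hypothetical sketch of the SimSiam objective; not the AiDotNet implementation.
public static class SimSiamLossSketch
{
    // D(p, z) = -cos(p, z): negative cosine similarity, in [-1, 1].
    public static double NegativeCosine(double[] p, double[] z)
    {
        if (p.Length != z.Length)
            throw new ArgumentException("p and z must have the same length.");

        double dot = 0, normP = 0, normZ = 0;
        for (int i = 0; i < p.Length; i++)
        {
            dot += p[i] * z[i];
            normP += p[i] * p[i];
            normZ += z[i] * z[i];
        }

        // Small epsilon avoids division by zero for all-zero vectors.
        const double eps = 1e-12;
        return -dot / (Math.Sqrt(normP) * Math.Sqrt(normZ) + eps);
    }

    // L = 0.5 * D(p1, sg(z2)) + 0.5 * D(p2, sg(z1)).
    // sg (stop-gradient) is implicit: with no autograd, z1 and z2 are plain constants.
    public static double SymmetricLoss(double[] p1, double[] z1, double[] p2, double[] z2)
    {
        return 0.5 * NegativeCosine(p1, z2) + 0.5 * NegativeCosine(p2, z1);
    }
}
```

The loss reaches its minimum of about −1 when each prediction points in the same direction as the other view's projection. Vector length does not matter, because D compares directions only.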
Reference: Chen and He, "Exploring Simple Siamese Representation Learning" (CVPR 2021)
Constructors
SimSiam(INeuralNetwork<T>, SymmetricProjector<T>, SSLConfig?)
Initializes a new instance of the SimSiam class.
public SimSiam(INeuralNetwork<T> encoder, SymmetricProjector<T> projector, SSLConfig? config = null)
Parameters
encoder INeuralNetwork<T>: The encoder network (shared by both branches).
projector SymmetricProjector<T>: Symmetric projector with predictor.
config SSLConfig: Optional SSL configuration.
Properties
Category
Gets the category of this SSL method.
public override SSLMethodCategory Category { get; }
Property Value
- SSLMethodCategory
Remarks
Categories include Contrastive, NonContrastive, Generative, and SelfDistillation.
Name
Gets the name of this SSL method.
public override string Name { get; }
Property Value
- string
Remarks
Examples: "SimCLR", "MoCo v2", "BYOL", "DINO", "MAE"
RequiresMemoryBank
Indicates whether this method requires a memory bank for negative samples.
public override bool RequiresMemoryBank { get; }
Property Value
- bool
Remarks
For Beginners: Memory banks store embeddings from previous batches to use as negative samples in contrastive learning. MoCo uses one; SimCLR does not. SimSiam does not use negative samples, so it has no need for a memory bank.
UsesMomentumEncoder
Indicates whether this method uses a momentum-updated encoder.
public override bool UsesMomentumEncoder { get; }
Property Value
- bool
Remarks
For Beginners: A momentum encoder is a slowly updated copy of the main encoder. Methods like MoCo, BYOL, and DINO use one to provide stable targets. SimSiam does not: both of its branches share the same encoder weights.
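A momentum encoder is typically updated as an exponential moving average (EMA) of the online encoder's weights: target ← m·target + (1 − m)·online. The sketch below assumes flat `double[]` parameter arrays. The `MomentumUpdateSketch.Update` helper is hypothetical and is not AiDotNet's implementation.

```csharp
using System;

// Hypothetical EMA update of the kind used by momentum encoders (e.g. MoCo, BYOL).
public static class MomentumUpdateSketch
{
    // Updates target in place: target[i] = m * target[i] + (1 - m) * online[i].
    public static void Update(double[] target, double[] online, double m)
    {
        if (target.Length != online.Length)
            throw new ArgumentException("Parameter arrays must have the same length.");
        if (m < 0.0 || m > 1.0)
            throw new ArgumentOutOfRangeException(nameof(m), "Momentum must be in [0, 1].");

        for (int i = 0; i < target.Length; i++)
            target[i] = m * target[i] + (1.0 - m) * online[i];
    }
}
```

A value of m close to 1 (such as 0.99) makes the target drift slowly. Setting m = 0 makes the target an exact copy of the online weights, which is SimSiam's situation by design. Both branches use the same weights, and collapse is prevented by the stop-gradient instead of a slowly moving target.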
Methods
Create(INeuralNetwork<T>, int, int, int, int)
Creates a SimSiam instance with default configuration.
public static SimSiam<T> Create(INeuralNetwork<T> encoder, int encoderOutputDim, int projectionDim = 2048, int hiddenDim = 2048, int predictorHiddenDim = 512)
Parameters
encoder INeuralNetwork<T>: The backbone encoder.
encoderOutputDim int: Output dimension of the encoder.
projectionDim int: Dimension of the projection space (default: 2048).
hiddenDim int: Hidden dimension of the projector MLP (default: 2048).
predictorHiddenDim int: Hidden dimension of the predictor (default: 512).
Returns
- SimSiam<T>
A configured SimSiam instance.
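The default dimensions match the paper's setup: a 2048-d projection and a 512-d predictor hidden layer, which makes the predictor a bottleneck. The sketch below lists linear-layer shapes under an assumed layout taken from the paper: three linear layers in the projector and two in the predictor, each counted with a bias. The helper class `SimSiamShapeSketch` is hypothetical, and AiDotNet's `SymmetricProjector<T>` may organize its layers differently.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical shape bookkeeping; assumes the paper's projector/predictor layout.
public static class SimSiamShapeSketch
{
    // Projector: encoderOutputDim -> hiddenDim -> hiddenDim -> projectionDim.
    public static List<(int In, int Out)> ProjectorLayers(int encoderOutputDim, int hiddenDim, int projectionDim) =>
        new List<(int In, int Out)> { (encoderOutputDim, hiddenDim), (hiddenDim, hiddenDim), (hiddenDim, projectionDim) };

    // Predictor: projectionDim -> predictorHiddenDim -> projectionDim (a bottleneck).
    public static List<(int In, int Out)> PredictorLayers(int projectionDim, int predictorHiddenDim) =>
        new List<(int In, int Out)> { (projectionDim, predictorHiddenDim), (predictorHiddenDim, projectionDim) };

    // Counts weights plus one bias per output unit (batch-norm parameters ignored).
    public static long ParameterCount(List<(int In, int Out)> layers)
    {
        long total = 0;
        foreach (var (inDim, outDim) in layers)
            total += (long)inDim * outDim + outDim;
        return total;
    }
}
```

Under this layout, the predictor's input and output sizes both equal `projectionDim`. That equality is what lets D(p, z) compare p and z directly. With the defaults, the predictor has roughly 2.1 million weights and biases.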
TrainStepCore(Tensor<T>, SSLAugmentationContext<T>?)
Implements the SimSiam-specific logic for a single training step.
protected override SSLStepResult<T> TrainStepCore(Tensor<T> batch, SSLAugmentationContext<T>? augmentationContext)
Parameters
batch Tensor<T>: The input batch tensor.
augmentationContext SSLAugmentationContext<T>: Optional augmentation context.
Returns
- SSLStepResult<T>
The result of the training step.
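At a high level, one SimSiam step encodes two augmented views with the same weights. It then compares each view's prediction with the other view's projection. The sketch below shows only that forward computation, under stated assumptions. `f` stands for the shared encoder plus projector and `h` for the predictor, both passed as hypothetical `Func<double[], double[]>` delegates. It is not how `TrainStepCore` is actually written, and it omits the backward pass and optimizer update.

```csharp
using System;

// Hypothetical forward half of a SimSiam step; not the AiDotNet implementation.
public static class SimSiamStepSketch
{
    // D(p, z) = -cos(p, z).
    public static double NegativeCosine(double[] p, double[] z)
    {
        if (p.Length != z.Length)
            throw new ArgumentException("p and z must have the same length.");
        double dot = 0, normP = 0, normZ = 0;
        for (int i = 0; i < p.Length; i++)
        {
            dot += p[i] * z[i];
            normP += p[i] * p[i];
            normZ += z[i] * z[i];
        }
        return -dot / (Math.Sqrt(normP) * Math.Sqrt(normZ) + 1e-12);
    }

    // f: shared encoder + projector (same weights for both views); h: predictor.
    public static double ForwardLoss(double[] view1, double[] view2,
                                     Func<double[], double[]> f, Func<double[], double[]> h)
    {
        double[] z1 = f(view1);
        double[] z2 = f(view2);
        double[] p1 = h(z1);
        double[] p2 = h(z2);
        // In an autograd framework, z1 and z2 would be detached here (stop-gradient).
        return 0.5 * NegativeCosine(p1, z2) + 0.5 * NegativeCosine(p2, z1);
    }
}
```

Suppose `f` collapses to a constant function and `h` is the identity. Then every pair of views gets a loss of about −1, the minimum possible value. Reaching the minimum therefore does not guarantee useful features, which is why the stop-gradient matters.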