
Class SymmetricProjector<T>

Namespace
AiDotNet.SelfSupervisedLearning
Assembly
AiDotNet.dll

Symmetric projector head for BYOL- and SimSiam-style methods.

public class SymmetricProjector<T> : IProjectorHead<T>

Type Parameters

T

The numeric type used for computations.

Inheritance
object → SymmetricProjector<T>
Implements
IProjectorHead<T>

Remarks

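For Beginners: The symmetric projector is used in BYOL and SimSiam. It consists of a projector MLP followed by a predictor MLP. The predictor creates asymmetry between the online and target branches. Together with a stop-gradient on the target branch, this asymmetry is key to avoiding collapse, where the network maps every input to the same embedding.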

Architecture:

  • Projector: Linear → BN → ReLU → Linear → BN
  • Predictor: Linear → BN → ReLU → Linear

Key insight: The predictor is only applied to the online branch, creating asymmetry. The target branch only uses the projector.
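The following is a minimal, library-independent sketch of where the two branches meet in the loss. It implements the symmetrized negative cosine similarity used by SimSiam (BYOL's normalized squared error, 2 - 2·cos(p, z), differs only by a scale and a shift). The class `SymmetricLossSketch` and its methods are hypothetical illustrations, not part of AiDotNet, and plain `double[]` arrays stand in for `Tensor<T>`.

```csharp
using System;

// Hypothetical helper (not part of AiDotNet): SimSiam-style symmetrized loss.
public static class SymmetricLossSketch
{
    // Negative cosine similarity D(p, z) = -(p . z) / (||p|| * ||z||).
    // p comes from the online branch (projector, then predictor);
    // z comes from the target branch (projector only). In real training z is
    // treated as a constant (stop-gradient); plain arrays carry no gradients,
    // so that step is implicit here.
    public static double NegativeCosine(double[] p, double[] z)
    {
        if (p.Length != z.Length)
            throw new ArgumentException("p and z must have the same length.");

        double dot = 0, pNormSq = 0, zNormSq = 0;
        for (int i = 0; i < p.Length; i++)
        {
            dot += p[i] * z[i];
            pNormSq += p[i] * p[i];
            zNormSq += z[i] * z[i];
        }
        // The small epsilon avoids division by zero for all-zero vectors.
        return -dot / (Math.Sqrt(pNormSq) * Math.Sqrt(zNormSq) + 1e-12);
    }

    // Each view's prediction is matched against the other view's projection.
    public static double Loss(double[] p1, double[] z2, double[] p2, double[] z1)
        => 0.5 * NegativeCosine(p1, z2) + 0.5 * NegativeCosine(p2, z1);
}
```

Because each view's prediction is compared with the other view's projection, the predictor appears on only one side of each term, which is exactly the asymmetry described above.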

Constructors

SymmetricProjector(int, int, int, int, int?)

Initializes a new instance of the SymmetricProjector<T> class.

public SymmetricProjector(int inputDim, int hiddenDim = 4096, int projectionDim = 256, int predictorHiddenDim = 4096, int? seed = null)

Parameters

inputDim int

Dimension of the encoder output that this projector receives.

hiddenDim int

Hidden dimension of the projector (default: 4096).

projectionDim int

Output (embedding) dimension of the projector (default: 256).

predictorHiddenDim int

Hidden dimension of the predictor (default: 4096). Set to 0 to disable the predictor.

seed int?

Optional random seed for weight initialization.
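As a hypothetical illustration of how the constructor arguments are assumed to translate into layer widths, the record below mirrors the documented defaults and the rule that `predictorHiddenDim = 0` disables the predictor. The name `ProjectorShapes` is invented for this sketch, and the predictor's output width (mapping back to `projectionDim`) is an assumption based on its output being compared with the target projection.

```csharp
using System;

// Hypothetical illustration (not AiDotNet code) of how the constructor
// arguments are assumed to map onto layer widths. Defaults mirror the
// documented constructor defaults.
public sealed record ProjectorShapes(
    int InputDim,
    int HiddenDim = 4096,
    int ProjectionDim = 256,
    int PredictorHiddenDim = 4096)
{
    // Documented rule: predictorHiddenDim = 0 disables the predictor
    // (this sketch treats any non-positive value the same way).
    public bool HasPredictor => PredictorHiddenDim > 0;

    // Projector widths: InputDim -> HiddenDim -> ProjectionDim.
    public int[] ProjectorWidths => new[] { InputDim, HiddenDim, ProjectionDim };

    // Predictor widths (assumed): ProjectionDim -> PredictorHiddenDim -> ProjectionDim,
    // so the online prediction can be compared with the target projection.
    public int[] PredictorWidths => HasPredictor
        ? new[] { ProjectionDim, PredictorHiddenDim, ProjectionDim }
        : Array.Empty<int>();
}
```

With the actual class, `new SymmetricProjector<double>(inputDim: 2048, predictorHiddenDim: 0)` should likewise report `HasPredictor` as false, per the parameter description above.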

Properties

HasPredictor

Gets whether this projector has a predictor head (false when it was constructed with predictorHiddenDim set to 0).

public bool HasPredictor { get; }

Property Value

bool

HiddenDimension

Gets the hidden dimension (for MLP projectors).

public int? HiddenDimension { get; }

Property Value

int?

Remarks

Typical values: 2048 to 4096, usually larger than the output dimension.

InputDimension

Gets the input dimension expected by this projector.

public int InputDimension { get; }

Property Value

int

OutputDimension

Gets the output dimension produced by this projector.

public int OutputDimension { get; }

Property Value

int

Remarks

Typical values: 128 to 2048. SimCLR and MoCo (v1/v2) use 128, BYOL uses 256, and SimSiam uses 2048.

ParameterCount

Gets the total number of trainable parameters.

public int ParameterCount { get; }

Property Value

int
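A rough way to predict `ParameterCount` from the constructor arguments is sketched below. It is a hypothetical estimate, not AiDotNet's formula: it assumes each Linear layer has a weight matrix plus a bias vector and each BatchNorm layer has a learnable scale and shift, with running statistics excluded. If the real implementation omits biases before BatchNorm, for example, the numbers will differ.

```csharp
using System;

// Hypothetical estimate, not AiDotNet's formula. Assumes every Linear layer
// has a weight matrix plus a bias vector and every BatchNorm layer has a
// learnable scale (gamma) and shift (beta); BN running statistics are not
// counted because they are not trained by gradient descent.
public static class ParameterCountSketch
{
    static int Linear(int inDim, int outDim) => inDim * outDim + outDim;
    static int BatchNorm(int dim) => 2 * dim;

    public static int Estimate(int inputDim, int hiddenDim, int projectionDim, int predictorHiddenDim)
    {
        // Projector: Linear -> BN -> ReLU -> Linear -> BN (ReLU has no parameters).
        int projector = Linear(inputDim, hiddenDim) + BatchNorm(hiddenDim)
                      + Linear(hiddenDim, projectionDim) + BatchNorm(projectionDim);
        if (predictorHiddenDim <= 0)
            return projector; // predictor disabled

        // Predictor: Linear -> BN -> ReLU -> Linear, mapping back to projectionDim.
        int predictor = Linear(projectionDim, predictorHiddenDim) + BatchNorm(predictorHiddenDim)
                      + Linear(predictorHiddenDim, projectionDim);
        return projector + predictor;
    }
}
```

If these assumptions match the implementation, comparing the estimate with `ParameterCount` is a quick sanity check after construction.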

Methods

Backward(Tensor<T>)

Performs the backward pass through the projector.

public Tensor<T> Backward(Tensor<T> gradOutput)

Parameters

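gradOutput Tensor<T>

Gradient of the loss with respect to the output of the preceding forward pass.

Returns

Tensor<T>

The gradient of the loss with respect to the projector's input, to be backpropagated into the encoder.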
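To show what a backward pass computes, here is a generic sketch for a single Linear layer y = Wx + b; `Backward` chains this kind of step through every layer (the BatchNorm and ReLU derivatives are omitted here). The helper `LinearBackwardSketch` is hypothetical, uses `double` arrays instead of `Tensor<T>`, and accumulates parameter gradients with `+=`, which is the behavior that `ClearGradients()` suggests.

```csharp
using System;

// Generic illustration (not AiDotNet code) of the backward pass through one
// Linear layer y = W x + b.
public static class LinearBackwardSketch
{
    // w is [outDim, inDim], x is the input cached during the forward pass,
    // gradY is dL/dy. Returns dL/dx and accumulates dL/dW and dL/db into
    // gradW and gradB (so they must be cleared between training steps).
    public static double[] Backward(double[,] w, double[] x, double[] gradY,
                                    double[,] gradW, double[] gradB)
    {
        int outDim = w.GetLength(0), inDim = w.GetLength(1);
        var gradX = new double[inDim];
        for (int o = 0; o < outDim; o++)
        {
            gradB[o] += gradY[o];                // dL/db = dL/dy
            for (int i = 0; i < inDim; i++)
            {
                gradW[o, i] += gradY[o] * x[i];  // dL/dW = dL/dy (outer product) x
                gradX[i] += w[o, i] * gradY[o];  // dL/dx = W^T dL/dy
            }
        }
        return gradX;
    }
}
```

The returned `gradX` of the first layer plays the role of the tensor that `Backward` hands back for encoder backpropagation.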

ClearGradients()

Clears accumulated gradients.

public void ClearGradients()

GetParameterGradients()

Gets the gradients computed during the last backward pass.

public Vector<T> GetParameterGradients()

Returns

Vector<T>

A vector containing gradients for all parameters.

GetParameters()

Gets all trainable parameters of the projector.

public Vector<T> GetParameters()

Returns

Vector<T>

A vector containing all parameters.
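Because parameters and gradients are exposed as flat vectors, an optimizer step can be written over them directly. The sketch below is a hypothetical plain SGD update with `double[]` standing in for `Vector<T>`; it assumes `GetParameters()` and `GetParameterGradients()` use the same element ordering.

```csharp
using System;

// Hypothetical sketch (not AiDotNet code): one plain SGD step over flat
// parameter and gradient vectors. double[] stands in for Vector<T>.
public static class FlatSgdSketch
{
    public static double[] Step(double[] parameters, double[] gradients, double learningRate)
    {
        if (parameters.Length != gradients.Length)
            throw new ArgumentException("Parameter and gradient vectors must align.");

        var updated = new double[parameters.Length];
        for (int i = 0; i < parameters.Length; i++)
            updated[i] = parameters[i] - learningRate * gradients[i]; // theta <- theta - lr * grad
        return updated;
    }
}
```

Under that assumption, one training step would read both vectors, compute the update, write it back with `SetParameters(Vector<T>)`, and call `ClearGradients()` before the next backward pass.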

Predict(Tensor<T>)

Applies the predictor head. Use it on the online branch only.

public Tensor<T> Predict(Tensor<T> projection)

Parameters

projection Tensor<T>

Output from the projector.

Returns

Tensor<T>

Prediction output.

Project(Tensor<T>)

Projects encoder output to the SSL embedding space.

public Tensor<T> Project(Tensor<T> input)

Parameters

input Tensor<T>

The encoder output tensor.

Returns

Tensor<T>

The projected embedding tensor.

Remarks

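For Beginners: This transforms encoder features into a lower-dimensional space where the SSL loss is computed. Because the loss acts on the projector's output rather than directly on the encoder's features, the encoder's features (the representation you keep for downstream tasks) are less specialized to the pretraining objective.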
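The projection itself is an MLP applied to each feature vector. This hypothetical sketch projects a single vector through Linear → ReLU → Linear; the BatchNorm layers from the architecture above are left out because they need statistics from a whole batch. `ProjectSketch` is not AiDotNet code, and the weights are passed in explicitly for clarity.

```csharp
using System;

// Hypothetical sketch (not AiDotNet code): projecting one encoder feature
// vector with Linear -> ReLU -> Linear. BatchNorm is omitted because it
// needs a whole batch.
public static class ProjectSketch
{
    // y = W x + b, with W shaped [outDim, inDim].
    static double[] Linear(double[,] w, double[] b, double[] x)
    {
        int outDim = w.GetLength(0), inDim = w.GetLength(1);
        var y = new double[outDim];
        for (int o = 0; o < outDim; o++)
        {
            double sum = b[o];
            for (int i = 0; i < inDim; i++) sum += w[o, i] * x[i];
            y[o] = sum;
        }
        return y;
    }

    public static double[] Project(double[] features,
                                   double[,] w1, double[] b1,
                                   double[,] w2, double[] b2)
    {
        double[] hidden = Linear(w1, b1, features);
        for (int i = 0; i < hidden.Length; i++)
            hidden[i] = Math.Max(0.0, hidden[i]); // ReLU
        return Linear(w2, b2, hidden);            // length = w2.GetLength(0)
    }
}
```

The output length is set by the last layer's weight shape, which is how `projectionDim` controls the embedding size independently of the encoder's width.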

ProjectAndPredict(Tensor<T>)

Projects and predicts in one call (convenience method).

public Tensor<T> ProjectAndPredict(Tensor<T> input)

Parameters

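input Tensor<T>

The encoder output tensor.

Returns

Tensor<T>

The prediction output, equivalent to calling Predict(Tensor<T>) on the result of Project(Tensor<T>).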

Reset()

Resets the projector state (clears any internal buffers).

public void Reset()

SetParameters(Vector<T>)

Sets the parameters of the projector.

public void SetParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

The parameter vector to load.
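One common use of `GetParameters()` and `SetParameters(Vector<T>)` in BYOL-style training is updating the target branch as an exponential moving average of the online branch (SimSiam instead shares weights between the branches, so it needs no such step). The sketch below is hypothetical and uses `double[]` for `Vector<T>`; the BYOL paper starts the decay rate τ at 0.996 and increases it toward 1 during training.

```csharp
using System;

// Hypothetical sketch (not AiDotNet code) of a BYOL-style target update:
// target <- tau * target + (1 - tau) * online, element-wise over flat
// parameter vectors.
public static class EmaSketch
{
    public static double[] Update(double[] target, double[] online, double tau)
    {
        if (target.Length != online.Length)
            throw new ArgumentException("Both projectors must have the same parameter layout.");
        if (tau < 0 || tau > 1)
            throw new ArgumentOutOfRangeException(nameof(tau));

        var blended = new double[target.Length];
        for (int i = 0; i < target.Length; i++)
            blended[i] = tau * target[i] + (1 - tau) * online[i];
        return blended;
    }
}
```

This assumes the target projector was built with the same constructor arguments as the online one, so both vectors share a layout; this page does not say how a target copy's predictor, which the target branch never uses, should be handled.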

SetTrainingMode(bool)

Sets training or evaluation mode.

public void SetTrainingMode(bool isTraining)

Parameters

isTraining bool

True for training mode, false for evaluation.
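Training versus evaluation mode matters mainly for the BatchNorm layers in the architecture. The hypothetical `BatchNorm1dSketch` below shows the standard behavior for a single feature: training mode normalizes with the current batch's statistics and updates running averages, while evaluation mode uses those running averages. It is not AiDotNet's implementation, and the learnable scale and shift are omitted.

```csharp
using System;

// Hypothetical sketch (not AiDotNet code) of BatchNorm for one feature.
// Learnable gamma/beta are omitted (equivalent to gamma = 1, beta = 0).
public sealed class BatchNorm1dSketch
{
    private readonly double _momentum;
    private readonly double _eps;

    public double RunningMean { get; private set; } = 0.0;
    public double RunningVar { get; private set; } = 1.0;
    public bool IsTraining { get; set; } = true;

    public BatchNorm1dSketch(double momentum = 0.1, double eps = 1e-5)
    {
        _momentum = momentum;
        _eps = eps;
    }

    public double[] Forward(double[] batch)
    {
        double mean, variance;
        if (IsTraining)
        {
            // Training mode: use this batch's statistics...
            mean = 0;
            foreach (double v in batch) mean += v;
            mean /= batch.Length;

            variance = 0;
            foreach (double v in batch) variance += (v - mean) * (v - mean);
            variance /= batch.Length; // biased estimate; some frameworks store the unbiased one

            // ...and fold them into the running averages used at evaluation time.
            RunningMean = (1 - _momentum) * RunningMean + _momentum * mean;
            RunningVar = (1 - _momentum) * RunningVar + _momentum * variance;
        }
        else
        {
            // Evaluation mode: use the running averages, so outputs do not depend on the batch.
            mean = RunningMean;
            variance = RunningVar;
        }

        var output = new double[batch.Length];
        for (int i = 0; i < batch.Length; i++)
            output[i] = (batch[i] - mean) / Math.Sqrt(variance + _eps);
        return output;
    }
}
```

Evaluation mode is also what makes single-sample inference meaningful, since a batch of one has zero variance.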