Class SymmetricProjector<T>
Namespace: AiDotNet.SelfSupervisedLearning
Assembly: AiDotNet.dll
Symmetric Projector Head for BYOL and SimSiam-style methods.
public class SymmetricProjector<T> : IProjectorHead<T>
Type Parameters
- T: The numeric type used for computations.
Inheritance
- object
- SymmetricProjector<T>
Implements
- IProjectorHead<T>
Remarks
For Beginners: The symmetric projector is used in BYOL and SimSiam. It consists of a projector MLP followed by a predictor MLP. The predictor makes the online and target branches asymmetric, which is key to avoiding collapse (a degenerate solution in which every input is mapped to the same embedding).
Architecture:
- Projector: Linear → BN → ReLU → Linear → BN
- Predictor: Linear → BN → ReLU → Linear
Key insight: The predictor is only applied to the online branch, creating asymmetry. The target branch only uses the projector.
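To make the layer order and the online/target asymmetry concrete, here is a minimal, self-contained C# sketch on plain `double[batch, features]` arrays. It is not AiDotNet's implementation: `SketchOps` and `MlpSketch` are hypothetical names, biases start at zero, and BatchNorm uses batch statistics only (training mode) with scale 1 and shift 0.

```csharp
using System;

// Hypothetical helpers (not AiDotNet API) operating on [batch, features] arrays.
public static class SketchOps
{
    // y[r, j] = b[j] + sum_i x[r, i] * w[i, j]
    public static double[,] Linear(double[,] x, double[,] w, double[] b)
    {
        int n = x.GetLength(0), din = x.GetLength(1), dout = w.GetLength(1);
        var y = new double[n, dout];
        for (int r = 0; r < n; r++)
            for (int j = 0; j < dout; j++)
            {
                double s = b[j];
                for (int i = 0; i < din; i++) s += x[r, i] * w[i, j];
                y[r, j] = s;
            }
        return y;
    }

    // Normalizes every feature column to zero mean and unit variance over the batch.
    public static double[,] BatchNorm(double[,] x, double eps = 1e-5)
    {
        int n = x.GetLength(0), d = x.GetLength(1);
        var y = new double[n, d];
        for (int j = 0; j < d; j++)
        {
            double mean = 0, variance = 0;
            for (int r = 0; r < n; r++) mean += x[r, j];
            mean /= n;
            for (int r = 0; r < n; r++) variance += (x[r, j] - mean) * (x[r, j] - mean);
            double std = Math.Sqrt(variance / n + eps);
            for (int r = 0; r < n; r++) y[r, j] = (x[r, j] - mean) / std;
        }
        return y;
    }

    // Element-wise max(0, x).
    public static double[,] Relu(double[,] x)
    {
        var y = (double[,])x.Clone();
        for (int r = 0; r < y.GetLength(0); r++)
            for (int j = 0; j < y.GetLength(1); j++)
                y[r, j] = Math.Max(0.0, y[r, j]);
        return y;
    }
}

// Two-layer MLP: Linear -> BN -> ReLU -> Linear, with an optional final BN.
// finalBatchNorm = true matches the documented projector, false the predictor.
public sealed class MlpSketch
{
    private readonly double[,] _w1, _w2;
    private readonly double[] _b1, _b2;
    private readonly bool _finalBatchNorm;

    public MlpSketch(int inDim, int hiddenDim, int outDim, bool finalBatchNorm, Random rng)
    {
        _w1 = Init(inDim, hiddenDim, rng);
        _w2 = Init(hiddenDim, outDim, rng);
        _b1 = new double[hiddenDim];
        _b2 = new double[outDim];
        _finalBatchNorm = finalBatchNorm;
    }

    // Uniform initialization in [-1/sqrt(inDim), 1/sqrt(inDim)].
    private static double[,] Init(int inDim, int outDim, Random rng)
    {
        var w = new double[inDim, outDim];
        double bound = 1.0 / Math.Sqrt(inDim);
        for (int i = 0; i < inDim; i++)
            for (int j = 0; j < outDim; j++)
                w[i, j] = (rng.NextDouble() * 2 - 1) * bound;
        return w;
    }

    public double[,] Forward(double[,] x)
    {
        var h = SketchOps.Relu(SketchOps.BatchNorm(SketchOps.Linear(x, _w1, _b1)));
        var y = SketchOps.Linear(h, _w2, _b2);
        return _finalBatchNorm ? SketchOps.BatchNorm(y) : y;
    }
}
```

With `projector = new MlpSketch(inputDim, hiddenDim, projectionDim, true, rng)` and `predictor = new MlpSketch(projectionDim, predictorHiddenDim, projectionDim, false, rng)`, the online branch computes `predictor.Forward(projector.Forward(h1))` while the target branch stops at a projector output for the other view. In SimSiam that target projector shares weights with the online one (with gradients stopped); in BYOL it is a separate, slowly updated copy.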
Constructors
SymmetricProjector(int, int, int, int, int?)
Initializes a new instance of the SymmetricProjector class.
public SymmetricProjector(int inputDim, int hiddenDim = 4096, int projectionDim = 256, int predictorHiddenDim = 4096, int? seed = null)
Parameters
- inputDim (int): Input dimension from the encoder.
- hiddenDim (int): Hidden dimension of the projector (default: 4096).
- projectionDim (int): Output dimension (default: 256).
- predictorHiddenDim (int): Hidden dimension of the predictor (default: 4096). Set to 0 to disable the predictor.
- seed (int?): Random seed for initialization.
Properties
HasPredictor
Gets whether this projector has a predictor head.
public bool HasPredictor { get; }
Property Value
- bool
HiddenDimension
Gets the hidden dimension (for MLP projectors).
public int? HiddenDimension { get; }
Property Value
- int?
Remarks
Typical values are 2048 to 4096, usually larger than the output dimension.
InputDimension
Gets the input dimension expected by this projector.
public int InputDimension { get; }
Property Value
- int
OutputDimension
Gets the output dimension produced by this projector.
public int OutputDimension { get; }
Property Value
- int
Remarks
Typical values: 128-2048. SimCLR uses 128, MoCo uses 128, BYOL uses 256.
ParameterCount
Gets the total number of trainable parameters.
public int ParameterCount { get; }
Property Value
- int
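As a worked example of what ParameterCount measures, the sketch below estimates the count from the constructor arguments. It rests on assumptions that this page does not confirm: every Linear layer has a bias, each BN layer contributes two learnable values per feature (scale and shift) while running statistics are not counted, and the predictor is omitted entirely when predictorHiddenDim is 0. `ParameterCountSketch` is a hypothetical helper; compare its result with ParameterCount on a real instance to check the assumptions.

```csharp
using System;

// Hypothetical estimate of trainable parameters for the documented architecture.
public static class ParameterCountSketch
{
    // Weight matrix plus bias vector of a fully connected layer (bias is an assumption).
    private static long Linear(long din, long dout) => din * dout + dout;

    // Learnable scale and shift per feature; running mean/variance are not trainable.
    private static long BatchNorm(long d) => 2 * d;

    public static long Estimate(int inputDim, int hiddenDim = 4096, int projectionDim = 256, int predictorHiddenDim = 4096)
    {
        // Projector: Linear -> BN -> ReLU -> Linear -> BN (ReLU has no parameters).
        long projector = Linear(inputDim, hiddenDim) + BatchNorm(hiddenDim)
                       + Linear(hiddenDim, projectionDim) + BatchNorm(projectionDim);
        if (predictorHiddenDim == 0) return projector; // predictor disabled

        // Predictor: Linear -> BN -> ReLU -> Linear.
        long predictor = Linear(projectionDim, predictorHiddenDim) + BatchNorm(predictorHiddenDim)
                       + Linear(predictorHiddenDim, projectionDim);
        return projector + predictor;
    }
}
```

Under these assumptions, `Estimate(2048)` with the default dimensions gives 11,559,936: 9,450,240 in the projector and 2,109,696 in the predictor. Most of the total comes from the first projector layer (2048 × 4096 weights).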
Methods
Backward(Tensor<T>)
Performs the backward pass through the projector.
public Tensor<T> Backward(Tensor<T> gradOutput)
Parameters
- gradOutput (Tensor<T>): The gradient of the loss with respect to the output of the forward pass.
Returns
- Tensor<T>
The gradient with respect to the projector input, to be backpropagated into the encoder.
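The returned tensor is what lets the encoder keep training through the projector. As an illustration of the underlying chain rule, here is a hypothetical sketch of the backward pass for one fully connected layer y = x·W + b, the building block that a full backward pass chains through (the BN and ReLU steps are left out). `LinearBackwardSketch` is not part of AiDotNet; it only shows which quantity goes back to the encoder (`gradX`) and which quantities are parameter gradients (`gradW`, `gradB`).

```csharp
using System;

// Hypothetical backward pass of y = x·W + b.
// Shapes: x [n, din], W [din, dout], gradY [n, dout].
public static class LinearBackwardSketch
{
    public static (double[,] gradX, double[,] gradW, double[] gradB) Backward(
        double[,] x, double[,] w, double[,] gradY)
    {
        int n = x.GetLength(0), din = x.GetLength(1), dout = w.GetLength(1);
        var gradX = new double[n, din];
        var gradW = new double[din, dout];
        var gradB = new double[dout];
        for (int r = 0; r < n; r++)
            for (int j = 0; j < dout; j++)
            {
                double g = gradY[r, j];
                gradB[j] += g;                  // dL/db = column sums of dL/dy
                for (int i = 0; i < din; i++)
                {
                    gradW[i, j] += x[r, i] * g; // dL/dW = x^T · dL/dy
                    gradX[r, i] += w[i, j] * g; // dL/dx = dL/dy · W^T (sent to the encoder)
                }
            }
        return (gradX, gradW, gradB);
    }
}
```

The parameter gradients accumulate over the batch, which is consistent with the class offering ClearGradients to reset them between steps.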
ClearGradients()
Clears accumulated gradients.
public void ClearGradients()
GetParameterGradients()
Gets the gradients computed during the last backward pass.
public Vector<T> GetParameterGradients()
Returns
- Vector<T>
A vector containing gradients for all parameters.
GetParameters()
Gets all trainable parameters of the projector.
public Vector<T> GetParameters()
Returns
- Vector<T>
A vector containing all parameters.
Predict(Tensor<T>)
Applies the predictor head (for online branch only).
public Tensor<T> Predict(Tensor<T> projection)
Parameters
- projection (Tensor<T>): Output from the projector.
Returns
- Tensor<T>
Prediction output.
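The prediction from the online branch is what the SSL loss compares with the target branch's projection of the other view. Assuming the SimSiam formulation (AiDotNet's own loss classes are not documented on this page), the symmetric loss is the average negative cosine similarity D(p1, z2)/2 + D(p2, z1)/2, where each view's prediction is matched against the other view's projection. `SymmetricLossSketch` is a hypothetical, value-only sketch; in real training, z is treated as a constant (stop-gradient).

```csharp
using System;

// Hypothetical SimSiam-style symmetric loss on single embedding vectors.
// p = predictor output (online branch), z = projector output of the other view (target branch).
public static class SymmetricLossSketch
{
    // D(p, z) = -cos(p, z); the small epsilon guards against zero vectors.
    public static double NegativeCosine(double[] p, double[] z)
    {
        double dot = 0, normP = 0, normZ = 0;
        for (int i = 0; i < p.Length; i++)
        {
            dot += p[i] * z[i];
            normP += p[i] * p[i];
            normZ += z[i] * z[i];
        }
        return -dot / (Math.Sqrt(normP) * Math.Sqrt(normZ) + 1e-12);
    }

    // L = D(p1, z2) / 2 + D(p2, z1) / 2
    public static double Loss(double[] p1, double[] z1, double[] p2, double[] z2)
        => 0.5 * NegativeCosine(p1, z2) + 0.5 * NegativeCosine(p2, z1);
}
```

The loss reaches its minimum of -1 when each prediction points in the same direction as the opposite view's projection. BYOL's normalized squared error, 2 - 2·cos(p, z), differs only by a scale and an offset, so it has the same minimizers.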
Project(Tensor<T>)
Projects encoder output to the SSL embedding space.
public Tensor<T> Project(Tensor<T> input)
Parameters
- input (Tensor<T>): The encoder output tensor.
Returns
- Tensor<T>
The projected embedding tensor.
Remarks
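For Beginners: This transforms encoder features into a lower-dimensional space where the SSL loss is computed. Computing the loss on the projection rather than directly on the encoder features helps separate the pretraining objective from the learned representations, which are typically taken from the encoder output for downstream tasks.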
ProjectAndPredict(Tensor<T>)
Projects and predicts in one call (convenience method).
public Tensor<T> ProjectAndPredict(Tensor<T> input)
Parameters
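- input (Tensor<T>): The encoder output tensor.
Returns
- Tensor<T>
The prediction output, i.e. the result of applying Predict to the output of Project.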
Reset()
Resets the projector state (clears any internal buffers).
public void Reset()
SetParameters(Vector<T>)
Sets the parameters of the projector.
public void SetParameters(Vector<T> parameters)
Parameters
- parameters (Vector<T>): The parameter vector to load.
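GetParameters and SetParameters exchange all weights as one flat vector, which is the shape BYOL's target update needs: target ← τ·target + (1 − τ)·online, applied element-wise (the BYOL paper starts τ at 0.996). Whether AiDotNet's BYOL trainer performs this update through these two methods is not stated on this page, so treat the following as a hypothetical sketch. `EmaSketch` is an illustrative name, and `double[]` stands in for `Vector<T>`.

```csharp
using System;

// Hypothetical exponential-moving-average update on flat parameter vectors.
public static class EmaSketch
{
    // Returns tau * target + (1 - tau) * online, element by element.
    public static double[] EmaUpdate(double[] target, double[] online, double tau)
    {
        if (target.Length != online.Length)
            throw new ArgumentException("Parameter vectors must have the same length.");
        var updated = new double[target.Length];
        for (int i = 0; i < target.Length; i++)
            updated[i] = tau * target[i] + (1 - tau) * online[i];
        return updated;
    }
}
```

With the library types, the same update would read both vectors with GetParameters and write the result back to the target projector with SetParameters. Converting between `Vector<T>` and arrays is omitted here because that API is not covered on this page.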
SetTrainingMode(bool)
Sets training or evaluation mode.
public void SetTrainingMode(bool isTraining)
Parameters
- isTraining (bool): True for training mode, false for evaluation mode.
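The mode matters mainly because of the BN layers. In training mode, BN normalizes with the current batch's statistics and updates running averages. In evaluation mode, it uses those stored averages, so outputs do not depend on the other samples in the batch. The single-feature sketch below illustrates that standard behavior. `BatchNormFeatureSketch` is a hypothetical class, and AiDotNet's momentum, epsilon, and variance conventions are not documented here; this sketch uses the biased batch variance for simplicity.

```csharp
using System;

// Hypothetical single-feature BatchNorm showing training vs. evaluation behavior.
public sealed class BatchNormFeatureSketch
{
    public double RunningMean { get; private set; } = 0.0;
    public double RunningVar { get; private set; } = 1.0;
    public bool IsTraining { get; set; } = true;

    private readonly double _momentum;
    private readonly double _eps;

    public BatchNormFeatureSketch(double momentum = 0.1, double eps = 1e-5)
    {
        _momentum = momentum;
        _eps = eps;
    }

    public double[] Forward(double[] batch)
    {
        double mean, variance;
        if (IsTraining)
        {
            // Batch statistics, plus an exponential update of the running averages.
            mean = 0;
            foreach (var v in batch) mean += v;
            mean /= batch.Length;
            variance = 0;
            foreach (var v in batch) variance += (v - mean) * (v - mean);
            variance /= batch.Length;
            RunningMean = (1 - _momentum) * RunningMean + _momentum * mean;
            RunningVar = (1 - _momentum) * RunningVar + _momentum * variance;
        }
        else
        {
            // Stored statistics: works even for a batch of one sample.
            mean = RunningMean;
            variance = RunningVar;
        }

        var output = new double[batch.Length];
        for (int i = 0; i < batch.Length; i++)
            output[i] = (batch[i] - mean) / Math.Sqrt(variance + _eps);
        return output;
    }
}
```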