Class CenteringMechanism<T>
- Namespace: AiDotNet.SelfSupervisedLearning
- Assembly: AiDotNet.dll
Centering mechanism for preventing collapse in self-distillation methods.
public class CenteringMechanism<T>
Type Parameters
T: The numeric type used for computations.
- Inheritance: object → CenteringMechanism<T>
Remarks
For Beginners: Centering is a crucial technique in DINO and similar methods that prevents the teacher network from collapsing to a trivial solution where it outputs the same constant for all inputs.
How it works:
- Maintains a running mean (center) of teacher outputs
- Subtracts the center from teacher outputs before computing loss
- Updates the center with exponential moving average (EMA)
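Why it prevents collapse: Without centering, the teacher could learn to output the same vector for every input (a trivial solution). Subtracting the running mean keeps the teacher outputs zero-centered on average, so no single output dimension can dominate. In DINO, centering is paired with sharpening of the teacher output, because centering alone pushes the outputs toward a uniform distribution.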
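The steps above can be written compactly. This is a sketch under one assumption not spelled out on this page: the momentum m weights the old center, as in the DINO reference code, so the default of 0.9 keeps 90% of the previous center on each update. Here t_1, ..., t_B are the teacher outputs in a batch and c is the center vector.

```latex
% Centering: subtract the current center from each teacher output
\tilde{t}_i = t_i - c, \qquad i = 1, \dots, B

% EMA update of the center with momentum m (assumed to weight the old center)
c \leftarrow m \, c + (1 - m) \, \frac{1}{B} \sum_{i=1}^{B} t_i
```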
Reference: Caron et al., "Emerging Properties in Self-Supervised Vision Transformers" (ICCV 2021)
Constructors
CenteringMechanism(int, double)
Initializes a new instance of the CenteringMechanism class.
public CenteringMechanism(int dimension, double momentum = 0.9)
Parameters
dimension int: Dimension of the output space to center.
momentum double: Momentum for EMA center updates (default: 0.9).
Properties
Dimension
Gets the dimension of the center vector.
public int Dimension { get; }
Property Value
- int
Momentum
Gets the momentum for EMA updates.
public double Momentum { get; }
Property Value
- double
Methods
ApplyCenter(Tensor<T>)
Applies centering to the input tensor.
public Tensor<T> ApplyCenter(Tensor<T> input)
Parameters
input Tensor<T>: Input tensor [batch_size, dim].
Returns
- Tensor<T>
Centered tensor.
CenterAndUpdate(Tensor<T>)
Applies centering and updates in one step (common usage pattern).
public Tensor<T> CenterAndUpdate(Tensor<T> teacherOutput)
Parameters
teacherOutput Tensor<T>: Teacher network output.
Returns
- Tensor<T>
Centered output.
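To make the center-then-update pattern concrete, here is a minimal sketch on plain double arrays. `CenterSketch` is a hypothetical stand-in written for illustration, not the AiDotNet implementation. It assumes the batch is centered with the current center before that center is updated, and that momentum weights the old center; neither detail is confirmed on this page.

```csharp
using System;

// Hypothetical stand-in for illustration only; not the AiDotNet implementation.
public sealed class CenterSketch
{
    private readonly double[] _center;   // running mean of teacher outputs
    private readonly double _momentum;   // weight kept on the old center (assumed convention)

    public CenterSketch(int dimension, double momentum = 0.9)
    {
        _center = new double[dimension]; // starts at zeros
        _momentum = momentum;
    }

    // Returns a copy so callers cannot mutate the internal state.
    public double[] GetCenter() => (double[])_center.Clone();

    // Subtracts the current center from every row of a [batchSize, dim] batch.
    public double[][] ApplyCenter(double[][] batch)
    {
        var centered = new double[batch.Length][];
        for (int i = 0; i < batch.Length; i++)
        {
            centered[i] = new double[_center.Length];
            for (int j = 0; j < _center.Length; j++)
                centered[i][j] = batch[i][j] - _center[j];
        }
        return centered;
    }

    // EMA step: center = m * center + (1 - m) * batchMean. Expects a non-empty batch.
    public void Update(double[][] batch)
    {
        for (int j = 0; j < _center.Length; j++)
        {
            double sum = 0.0;
            for (int i = 0; i < batch.Length; i++)
                sum += batch[i][j];
            double batchMean = sum / batch.Length;
            _center[j] = _momentum * _center[j] + (1.0 - _momentum) * batchMean;
        }
    }

    // Centers with the current value first, then folds the batch into the center.
    public double[][] CenterAndUpdate(double[][] batch)
    {
        var centered = ApplyCenter(batch);
        Update(batch);
        return centered;
    }
}
```

Because the center starts at zeros, the first call returns the batch unchanged, and only later batches are shifted by the accumulated running mean.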
CenterNorm()
Computes the L2 norm of the center (useful for monitoring).
public T CenterNorm()
Returns
- T
CenterStatistics()
Computes statistics about the center (useful for debugging).
public (T mean, T std, T min, T max) CenterStatistics()
Returns
- (T mean, T std, T min, T max)
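For monitoring, the quantities reported by CenterNorm and CenterStatistics can be sketched on a plain array as follows. `CenterDiagnosticsSketch` is a hypothetical helper, and the use of the population standard deviation (dividing by the number of elements) is an assumption; this page does not say which variant the library computes.

```csharp
using System;
using System.Linq;

// Hypothetical helper for illustration only; not the AiDotNet implementation.
public static class CenterDiagnosticsSketch
{
    // L2 norm: square root of the sum of squared components.
    public static double L2Norm(double[] center) =>
        Math.Sqrt(center.Sum(v => v * v));

    // Mean, population standard deviation (assumed), minimum, and maximum of the center.
    public static (double Mean, double Std, double Min, double Max) Statistics(double[] center)
    {
        double mean = center.Average();
        double variance = center.Sum(v => (v - mean) * (v - mean)) / center.Length;
        return (mean, Math.Sqrt(variance), center.Min(), center.Max());
    }
}
```

A norm that keeps growing over training, or a very large gap between the minimum and maximum, can be a sign that the teacher outputs are drifting and is worth checking against the loss curve.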
GetCenter()
Gets the current center values.
public T[] GetCenter()
Returns
- T[]
Copy of the center vector.
Reset()
Resets the center to zeros.
public void Reset()
SetCenter(T[])
Sets the center values directly.
public void SetCenter(T[] center)
Parameters
center T[]: New center values.
Update(Tensor<T>)
Updates the center using EMA with the given batch.
public void Update(Tensor<T> batchOutput)
Parameters
batchOutput Tensor<T>: Batch of outputs from teacher network [batch_size, dim].
UpdateFromMultiple(IList<Tensor<T>>)
Updates the center using multiple batches of outputs.
public void UpdateFromMultiple(IList<Tensor<T>> outputs)
Parameters
outputs IList<Tensor<T>>: List of output tensors.