
Class CenteringMechanism<T>

Namespace
AiDotNet.SelfSupervisedLearning
Assembly
AiDotNet.dll

Centering mechanism for preventing collapse in self-distillation methods.

public class CenteringMechanism<T>

Type Parameters

T

The numeric type used for computations.

Inheritance
CenteringMechanism<T>

Remarks

For Beginners: Centering is a crucial technique in DINO and similar methods that prevents the teacher network from collapsing to a trivial solution where it outputs the same constant for all inputs.

How it works:

  • Maintains a running mean (center) of teacher outputs
  • Subtracts the center from teacher outputs before computing loss
  • Updates the center with an exponential moving average (EMA)
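The three steps above can be sketched on plain arrays as shown here. This is a minimal, hypothetical illustration, not AiDotNet's implementation. It uses `double[][]` rows in place of `Tensor<T>`, and the class name `CenteringSketch` is invented. It also assumes the DINO-style update, in which the new center is `momentum * center + (1 - momentum) * batchMean`.

```csharp
using System;

// Hypothetical stand-alone sketch of DINO-style centering on plain double arrays.
// It does not use AiDotNet's Tensor<T>; the EMA rule below is an assumption.
public sealed class CenteringSketch
{
    private readonly double[] _center;
    private readonly double _momentum;

    public CenteringSketch(int dimension, double momentum = 0.9)
    {
        _center = new double[dimension]; // the center starts at zero
        _momentum = momentum;
    }

    // Returns a copy so callers cannot mutate the running center.
    public double[] Center => (double[])_center.Clone();

    // Subtracts the running center from every row of a [batch, dim] batch.
    public double[][] Apply(double[][] batch)
    {
        var result = new double[batch.Length][];
        for (int i = 0; i < batch.Length; i++)
        {
            result[i] = new double[_center.Length];
            for (int j = 0; j < _center.Length; j++)
                result[i][j] = batch[i][j] - _center[j];
        }
        return result;
    }

    // EMA update: center = m * center + (1 - m) * (column-wise mean of the batch).
    public void Update(double[][] batch)
    {
        for (int j = 0; j < _center.Length; j++)
        {
            double mean = 0.0;
            for (int i = 0; i < batch.Length; i++) mean += batch[i][j];
            mean /= batch.Length;
            _center[j] = _momentum * _center[j] + (1.0 - _momentum) * mean;
        }
    }
}
```

With `momentum = 0.9`, a batch whose column means are `[2, 4]` moves a zero center to roughly `[0.2, 0.4]`.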

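Why it prevents collapse: Without centering, the teacher can settle on a trivial solution, such as emitting the same output vector for every input so that one dimension dominates regardless of the image. Subtracting the running mean keeps the outputs zero-centered on average, so no single dimension can dominate across all inputs. Centering on its own pushes the teacher toward a uniform output distribution, so DINO pairs it with sharpening (a low teacher softmax temperature). The two effects counterbalance each other, and together they avoid collapse.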
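Here is a small numeric illustration of the collapse argument under stated assumptions. The hypothetical `CollapseDemo` class feeds a fully collapsed teacher, one that produces identical logits for every input, through an EMA center. It then applies a softmax with an illustrative teacher temperature of 0.04. Once the center has caught up with the constant output, the centered logits are zero and the distribution is uniform, so the constant output no longer favors any dimension. Without centering, the same logits at that temperature produce an almost one-hot distribution.

```csharp
using System;
using System.Linq;

// Hypothetical demo: a "collapsed" teacher emits the same logits for every input.
public static class CollapseDemo
{
    // Numerically stable softmax with a temperature (lower = sharper).
    public static double[] Softmax(double[] logits, double temperature)
    {
        double max = logits.Max();
        double[] exps = logits.Select(z => Math.Exp((z - max) / temperature)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }

    // Runs `steps` EMA updates of a zero-initialized center on the constant output,
    // then returns the softmax of the centered output. Momentum/temperature are illustrative.
    public static double[] CenteredDistribution(double[] constantOutput, int steps,
        double momentum = 0.9, double temperature = 0.04)
    {
        var center = new double[constantOutput.Length];
        for (int s = 0; s < steps; s++)
            for (int j = 0; j < center.Length; j++)
                center[j] = momentum * center[j] + (1.0 - momentum) * constantOutput[j];

        double[] centered = constantOutput.Select((v, j) => v - center[j]).ToArray();
        return Softmax(centered, temperature);
    }
}
```

For example, `CollapseDemo.CenteredDistribution(new[] { 5.0, 0.0, 0.0 }, 500)` yields approximately `[1/3, 1/3, 1/3]`, while `CollapseDemo.Softmax(new[] { 5.0, 0.0, 0.0 }, 0.04)` puts nearly all its mass on the first entry.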

Reference: Caron et al., "Emerging Properties in Self-Supervised Vision Transformers" (ICCV 2021)

Constructors

CenteringMechanism(int, double)

Initializes a new instance of the CenteringMechanism class.

public CenteringMechanism(int dimension, double momentum = 0.9)

Parameters

dimension int

Dimension of the output space to center.

momentum double

Momentum for EMA center updates (default: 0.9).
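This page does not confirm that this class uses the center update from the DINO paper; the formula below assumes it does. Under that assumption, the center c after a batch of B teacher outputs g(x_1), ..., g(x_B) with momentum m follows the first formula. If the batch mean stays at a constant μ and the center starts at zero, the closed form after k updates is the second formula. With m = 0.9, about 22 updates are needed for the center to reach 90% of μ, since 0.9^22 ≈ 0.098. A larger momentum therefore gives a slower-moving, smoother center.

```latex
c \leftarrow m\,c + (1 - m)\,\frac{1}{B}\sum_{i=1}^{B} g(x_i),
\qquad
c_k = \bigl(1 - m^{k}\bigr)\,\mu \quad \text{(constant batch mean } \mu,\ c_0 = 0\text{)}
```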

Properties

Dimension

Gets the dimension of the center vector.

public int Dimension { get; }

Property Value

int

Momentum

Gets the momentum for EMA updates.

public double Momentum { get; }

Property Value

double

Methods

ApplyCenter(Tensor<T>)

Applies centering to the input tensor.

public Tensor<T> ApplyCenter(Tensor<T> input)

Parameters

input Tensor<T>

Input tensor of shape [batch_size, dim].

Returns

Tensor<T>

Centered tensor.

CenterAndUpdate(Tensor<T>)

Applies centering and updates the center in one step (the common usage pattern).

public Tensor<T> CenterAndUpdate(Tensor<T> teacherOutput)

Parameters

teacherOutput Tensor<T>

Teacher network output.

Returns

Tensor<T>

Centered output.
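This page does not state the order in which centering and updating happen. In DINO's pseudocode, the teacher output is first centered with the current center, and the center is updated afterwards. As a result, a batch is never centered with a value that already includes its own mean. The hypothetical `CenterThenUpdate.Run` helper sketches that ordering on plain arrays. Treat the ordering as an assumption about this library, not documented behavior.

```csharp
using System;

// Hypothetical helper mirroring DINO's ordering (assumed, not confirmed for AiDotNet):
// 1) center with the current center, 2) update the center from the raw batch.
public static class CenterThenUpdate
{
    public static double[][] Run(double[] center, double[][] teacherOutput, double momentum)
    {
        int dim = center.Length;
        var centered = new double[teacherOutput.Length][];
        for (int i = 0; i < teacherOutput.Length; i++)
        {
            centered[i] = new double[dim];
            for (int j = 0; j < dim; j++)
                centered[i][j] = teacherOutput[i][j] - center[j]; // uses the old center
        }

        for (int j = 0; j < dim; j++)
        {
            double mean = 0.0;
            foreach (var row in teacherOutput) mean += row[j];
            mean /= teacherOutput.Length;
            center[j] = momentum * center[j] + (1.0 - momentum) * mean; // in-place EMA
        }
        return centered;
    }
}
```

For example, with a center of `[1]`, outputs `[[3], [5]]`, and momentum 0.5, the result is `[[2], [4]]` and the center moves to `2.5`.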

CenterNorm()

Computes the L2 norm of the center (useful for monitoring).

public T CenterNorm()

Returns

T
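The L2 norm is the square root of the sum of the squared center entries. The following minimal sketch computes it on a `double[]` instead of `T`; the `CenterMonitoring` name is hypothetical. Logging this value every few steps is a cheap way to track the scale of the center during training.

```csharp
using System;

// Hypothetical monitoring helper operating on a plain copy of the center.
public static class CenterMonitoring
{
    // L2 norm: sqrt(sum of squared entries).
    public static double L2Norm(double[] center)
    {
        double sumSq = 0.0;
        foreach (double v in center) sumSq += v * v;
        return Math.Sqrt(sumSq);
    }
}
```

For a center of `[3, 4]` the norm is `5`.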

CenterStatistics()

Computes statistics about the center (useful for debugging).

public (T mean, T std, T min, T max) CenterStatistics()

Returns

(T mean, T std, T min, T max)
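The four statistics can be sketched on a plain array as follows. This page does not say whether `std` is the population or the sample standard deviation. The hypothetical `CenterStats.Compute` helper assumes the population form, which divides by n.

```csharp
using System;
using System.Linq;

// Hypothetical debugging helper; population standard deviation (divide by n) is assumed.
public static class CenterStats
{
    public static (double mean, double std, double min, double max) Compute(double[] center)
    {
        double mean = center.Average();
        double variance = center.Sum(v => (v - mean) * (v - mean)) / center.Length;
        return (mean, Math.Sqrt(variance), center.Min(), center.Max());
    }
}
```

For the values `[2, 4, 4, 4, 5, 5, 7, 9]` this returns a mean of 5, a standard deviation of 2, a minimum of 2, and a maximum of 9.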

GetCenter()

Gets the current center values.

public T[] GetCenter()

Returns

T[]

Copy of the center vector.

Reset()

Resets the center to zeros.

public void Reset()

SetCenter(T[])

Sets the center values directly.

public void SetCenter(T[] center)

Parameters

center T[]

New center values.

Update(Tensor<T>)

Updates the center using EMA with the given batch.

public void Update(Tensor<T> batchOutput)

Parameters

batchOutput Tensor<T>

Batch of outputs from the teacher network, of shape [batch_size, dim].

UpdateFromMultiple(IList<Tensor<T>>)

Updates the center using multiple batches of outputs.

public void UpdateFromMultiple(IList<Tensor<T>> outputs)

Parameters

outputs IList<Tensor<T>>

List of output tensors.
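This page does not specify how several tensors are combined. One plausible reading, sketched here as an assumption, pools every row from all tensors into a single mean and applies one EMA step. This resembles how DINO's reference code averages the teacher outputs of both global crops together. The result differs from averaging per-batch means when the batches have different sizes. It also differs from calling `Update` once per tensor, which would apply several EMA steps. The `MultiBatchCenter` name is hypothetical.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical multi-batch update (pooling strategy is an assumption).
public static class MultiBatchCenter
{
    // Pools every row from every batch into one mean, then applies a single EMA step.
    public static void UpdateFromMultiple(double[] center, IList<double[][]> outputs, double momentum)
    {
        int dim = center.Length;
        var sum = new double[dim];
        int rows = 0;
        foreach (var batch in outputs)
        {
            foreach (var row in batch)
            {
                for (int j = 0; j < dim; j++) sum[j] += row[j];
                rows++;
            }
        }
        if (rows == 0) return; // nothing to update from

        for (int j = 0; j < dim; j++)
            center[j] = momentum * center[j] + (1.0 - momentum) * (sum[j] / rows);
    }
}
```

For example, with batches `[[1], [3]]` and `[[8]]`, the pooled mean is 4 (not the mean-of-means 5). A zero center with momentum 0.5 therefore moves to `2`.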