Class StopGradient<T>

Namespace
AiDotNet.SelfSupervisedLearning
Assembly
AiDotNet.dll

Provides stop-gradient operations for self-supervised learning.

public static class StopGradient<T>

Type Parameters

T

The numeric type used for computations (typically float or double).

Inheritance
object → StopGradient<T>

Remarks

For Beginners: Stop-gradient (also called "detach" in PyTorch) prevents gradients from flowing through a tensor during backpropagation. This is crucial for several self-supervised learning (SSL) methods:

  • SimSiam: Stop-gradient on one branch prevents collapse without momentum encoder
  • BYOL: Target network outputs are detached (no gradients to momentum encoder)
  • MoCo: Memory bank entries and momentum encoder outputs are detached

Why stop-gradient?

Without stop-gradient, the model could "cheat" by making both branches output constants, resulting in representation collapse. Stop-gradient forces asymmetry that prevents this.

Example usage:

// SimSiam: asymmetric loss with stop-gradient
var z1 = encoder(x1);  // Representation of augmented view 1
var z2 = encoder(x2);  // Representation of augmented view 2
var p1 = predictor(z1);  // Predictor head applied to each representation
var p2 = predictor(z2);

// Loss with stop-gradient - gradients only flow through the predictor side
var loss = -CosineSimilarity(p1, StopGradient<T>.Detach(z2)).Mean()
         - CosineSimilarity(p2, StopGradient<T>.Detach(z1)).Mean();

Methods

Detach(Tensor<T>)

Detaches a tensor from the computation graph, preventing gradient flow.

public static Tensor<T> Detach(Tensor<T> tensor)

Parameters

tensor Tensor<T>

The tensor to detach.

Returns

Tensor<T>

A copy of the tensor that won't contribute to gradients.

Remarks

For Beginners: This creates a copy of the tensor that acts as a "constant" during backpropagation. Gradients won't flow back through this tensor.
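
A minimal sketch of how Detach might be used in a BYOL-style update. The onlineEncoder.Forward and targetEncoder.Forward calls and the lossFunction helper are illustrative assumptions, not part of the documented API:

// The target (momentum) encoder's output is detached so that backpropagation
// only updates the online encoder; the target network receives no gradients.
// onlineEncoder.Forward / targetEncoder.Forward are hypothetical methods
// returning Tensor<float>.
Tensor<float> onlineOutput = onlineEncoder.Forward(view1);
Tensor<float> targetOutput = targetEncoder.Forward(view2);

Tensor<float> detachedTarget = StopGradient<float>.Detach(targetOutput);
var loss = lossFunction(onlineOutput, detachedTarget);  // hypothetical loss helper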

Detach(Vector<T>)

Applies stop-gradient to a vector.

public static Vector<T> Detach(Vector<T> vector)

Parameters

vector Vector<T>

The vector to detach.

Returns

Vector<T>

A copy of the vector that won't contribute to gradients.
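
A brief sketch for the vector overload, assuming embedding is a Vector<float> produced by an encoder elsewhere:

// Detach a single embedding before storing it as a training target,
// e.g. in a memory bank; no gradients flow back to the encoder that produced it.
Vector<float> frozenEmbedding = StopGradient<float>.Detach(embedding);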

DetachBatch(params Tensor<T>[])

Detaches a batch of tensors from the computation graph.

public static Tensor<T>[] DetachBatch(params Tensor<T>[] tensors)

Parameters

tensors Tensor<T>[]

The tensors to detach.

Returns

Tensor<T>[]

Copies of the tensors that won't contribute to gradients.
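
A sketch of detaching several momentum-encoder outputs at once, MoCo-style. The keyEncoder.Forward calls are hypothetical and stand in for whatever produces the key tensors:

// Key (momentum) encoder outputs should never receive gradients.
Tensor<float>[] keys = StopGradient<float>.DetachBatch(
    keyEncoder.Forward(key1),
    keyEncoder.Forward(key2),
    keyEncoder.Forward(key3));

// The detached keys act as constants during backpropagation and can be
// enqueued into the memory bank safely.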

SymmetricLoss(Tensor<T>, Tensor<T>, Tensor<T>, Tensor<T>, Func<Tensor<T>, Tensor<T>, T>)

Computes the symmetric loss with stop-gradient for SimSiam-style training.

public static T SymmetricLoss(Tensor<T> prediction1, Tensor<T> target1, Tensor<T> prediction2, Tensor<T> target2, Func<Tensor<T>, Tensor<T>, T> lossFunction)

Parameters

prediction1 Tensor<T>

Prediction from view 1 through predictor.

target1 Tensor<T>

Target from view 1 (will be detached).

prediction2 Tensor<T>

Prediction from view 2 through predictor.

target2 Tensor<T>

Target from view 2 (will be detached).

lossFunction Func<Tensor<T>, Tensor<T>, T>

Function to compute loss between prediction and target.

Returns

T

The symmetric loss value.

Remarks

For Beginners: This implements the symmetric loss used in SimSiam:

L = 0.5 * (loss(p1, stop_grad(z2)) + loss(p2, stop_grad(z1)))
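
A sketch of a SimSiam-style step using SymmetricLoss, assuming hypothetical encoder.Forward / predictor.Forward methods and a NegativeCosineSimilarity helper matching Func<Tensor<float>, Tensor<float>, float>:

// Two augmented views of the same input (augmentation not shown).
Tensor<float> z1 = encoder.Forward(view1);
Tensor<float> z2 = encoder.Forward(view2);
Tensor<float> p1 = predictor.Forward(z1);
Tensor<float> p2 = predictor.Forward(z2);

// Hypothetical loss: negative cosine similarity averaged over the batch.
Func<Tensor<float>, Tensor<float>, float> negCosine = NegativeCosineSimilarity;

// Targets are detached inside the method and combined per the formula above.
float loss = StopGradient<float>.SymmetricLoss(p1, z1, p2, z2, negCosine);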

ZeroGrad(Tensor<T>)

Creates a zero-gradient version of a tensor for the backward pass.

public static Tensor<T> ZeroGrad(Tensor<T> tensor)

Parameters

tensor Tensor<T>

The tensor whose shape the result will match.

Returns

Tensor<T>

A zero tensor with the same shape (for gradient accumulation).
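
A brief sketch of where ZeroGrad might fit in hand-rolled gradient bookkeeping; the surrounding variables are illustrative:

// A detached branch contributes no gradient. ZeroGrad produces a zero tensor
// shaped like that branch's output so accumulation code can treat both
// branches uniformly.
Tensor<float> gradForDetachedBranch = StopGradient<float>.ZeroGrad(detachedOutput);

// Adding gradForDetachedBranch into a gradient accumulator leaves it unchanged,
// mirroring the fact that no gradient flows through the stop-gradient branch.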