Class StopGradient<T>
- Namespace
- AiDotNet.SelfSupervisedLearning
- Assembly
- AiDotNet.dll
Provides stop-gradient operations for self-supervised learning.
public static class StopGradient<T>
Type Parameters
T: The numeric type used for computations (typically float or double).
- Inheritance
- object → StopGradient<T>
Remarks
For Beginners: Stop-gradient (also called "detach" in PyTorch) prevents gradients from flowing through a tensor during backpropagation. This is crucial for several SSL methods:
- SimSiam: Stop-gradient on one branch prevents collapse without momentum encoder
- BYOL: Target network outputs are detached (no gradients to momentum encoder)
- MoCo: Memory bank entries and momentum encoder outputs are detached
Why stop-gradient?
Without stop-gradient, the model could "cheat" by making both branches output constants, resulting in representation collapse. Stop-gradient forces asymmetry that prevents this.
Example usage:
// SimSiam: asymmetric loss with stop-gradient
var z1 = encoder(x1); // Online branch
var z2 = encoder(x2); // Online branch
var p1 = predictor(z1);
var p2 = predictor(z2);
// Loss with stop-gradient - gradients only flow through predictor side
var loss = -CosineSimilarity(p1, StopGradient<T>.Detach(z2)).Mean()
           - CosineSimilarity(p2, StopGradient<T>.Detach(z1)).Mean();
Methods
Detach(Tensor<T>)
Detaches a tensor from the computation graph, preventing gradient flow.
public static Tensor<T> Detach(Tensor<T> tensor)
Parameters
tensor (Tensor<T>): The tensor to detach.
Returns
- Tensor<T>
A copy of the tensor that won't contribute to gradients.
Remarks
For Beginners: This creates a copy of the tensor that acts as a "constant" during backpropagation. Gradients won't flow back through this tensor.
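A minimal usage sketch, reusing the z2 projection from the class-level example above (how z2 is produced depends on your encoder and is assumed here):
// z2: a Tensor<double> projection of the second view (see the class-level example)
var z2Detached = StopGradient<double>.Detach(z2);
// z2Detached holds the same values but is treated as a constant during backpropagation,
// so no gradients reach the encoder through this branch.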
Detach(Vector<T>)
Applies stop-gradient to a vector.
public static Vector<T> Detach(Vector<T> vector)
Parameters
vector (Vector<T>): The vector to detach.
Returns
- Vector<T>
A copy of the vector that won't contribute to gradients.
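The same idea for single-sample embeddings; here `embedding` is a hypothetical Vector<float> produced by a projection head:
// embedding: a Vector<float> from a projection head (hypothetical)
var target = StopGradient<float>.Detach(embedding);
// target can be used as a fixed regression target without sending gradients upstream.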
DetachBatch(params Tensor<T>[])
Detaches a batch of tensors from the computation graph.
public static Tensor<T>[] DetachBatch(params Tensor<T>[] tensors)
Parameters
tensors (Tensor<T>[]): The tensors to detach.
Returns
- Tensor<T>[]
Copies of the tensors that won't contribute to gradients.
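A sketch of the MoCo-style case mentioned in the class remarks, where several momentum-encoder outputs are detached in one call; k1 and k2 are hypothetical Tensor<double> keys:
// k1, k2: Tensor<double> outputs of a momentum encoder (hypothetical)
var detachedKeys = StopGradient<double>.DetachBatch(k1, k2);
// detachedKeys[0] and detachedKeys[1] behave as constants during backpropagation,
// so the momentum encoder receives no gradients.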
SymmetricLoss(Tensor<T>, Tensor<T>, Tensor<T>, Tensor<T>, Func<Tensor<T>, Tensor<T>, T>)
Computes the symmetric loss with stop-gradient for SimSiam-style training.
public static T SymmetricLoss(Tensor<T> prediction1, Tensor<T> target1, Tensor<T> prediction2, Tensor<T> target2, Func<Tensor<T>, Tensor<T>, T> lossFunction)
Parameters
prediction1 (Tensor<T>): Prediction from view 1 through the predictor.
target1 (Tensor<T>): Target for prediction 1 (will be detached).
prediction2 (Tensor<T>): Prediction from view 2 through the predictor.
target2 (Tensor<T>): Target for prediction 2 (will be detached).
lossFunction (Func<Tensor<T>, Tensor<T>, T>): Function that computes the loss between a prediction and a target.
Returns
- T
The symmetric loss value.
Remarks
For Beginners: This implements the symmetric loss used in SimSiam:
L = 0.5 * (loss(p1, stop_grad(z2)) + loss(p2, stop_grad(z1)))
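A sketch of calling SymmetricLoss in a SimSiam step, assuming p1/p2 are predictor outputs and z1/z2 are encoder projections (as in the class-level example), and NegativeCosineSimilarity is a hypothetical loss helper; following the formula above, each prediction is paired with the other view's projection:
// p1, p2: predictor outputs; z1, z2: encoder projections (see the class-level example)
Func<Tensor<double>, Tensor<double>, double> negCosine = NegativeCosineSimilarity; // hypothetical helper
double loss = StopGradient<double>.SymmetricLoss(p1, z2, p2, z1, negCosine);
// The targets z2 and z1 are detached inside the call, so gradients flow only through p1 and p2.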
ZeroGrad(Tensor<T>)
Creates a zero-gradient version of a tensor for the backward pass.
public static Tensor<T> ZeroGrad(Tensor<T> tensor)
Parameters
tensor (Tensor<T>): The tensor whose shape the zero tensor will match.
Returns
- Tensor<T>
A zero tensor with the same shape (for gradient accumulation).
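A sketch of how ZeroGrad might be used when writing a custom backward pass, assuming `detachedTarget` is a tensor that was detached earlier (hypothetical variable):
// detachedTarget: a previously detached Tensor<double> (hypothetical)
var zeroGrad = StopGradient<double>.ZeroGrad(detachedTarget);
// zeroGrad has the same shape as detachedTarget but contains only zeros,
// so accumulating it contributes no gradient to the detached branch.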