Class BYOL<T>
- Namespace: AiDotNet.SelfSupervisedLearning
- Assembly: AiDotNet.dll
BYOL: Bootstrap Your Own Latent - Self-supervised learning without negative samples.
public class BYOL<T> : SSLMethodBase<T>, ISSLMethod<T>
Type Parameters
T: The numeric type used for computations.
- Inheritance: object → SSLMethodBase<T> → BYOL<T>
- Implements: ISSLMethod<T>
Remarks
For Beginners: BYOL learns representations without requiring negative samples. An online network is trained to predict the output of a target network, and the target network's weights are an exponential moving average (EMA) of the online network's weights.
Key innovations:
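- No negatives: Unlike SimCLR or MoCo, BYOL does not need negative samples.
- Asymmetric architecture: The online network has a predictor head; the target network does not.
- EMA target: The target network is a slowly moving average of the online network.
- Symmetric loss: Each augmented view is passed through both networks, and the loss is summed over both orderings (view 1 predicts view 2's target, and vice versa).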
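The EMA target update can be sketched on flat parameter arrays. This is a minimal illustration, assuming parameters are stored as `double[]`; `EmaUpdate` and the decay value 0.99 are hypothetical and not part of the AiDotNet API. After each online optimization step, every target parameter moves a small fraction toward its online counterpart: θ_target ← τ·θ_target + (1 − τ)·θ_online.

```csharp
using System;

// Hypothetical helper (not an AiDotNet API): blends online parameters into
// target parameters in place. tau close to 1 means the target changes slowly.
static void EmaUpdate(double[] target, double[] online, double tau)
{
    if (target.Length != online.Length)
        throw new ArgumentException("Parameter arrays must have the same length.");
    for (int i = 0; i < target.Length; i++)
        target[i] = tau * target[i] + (1.0 - tau) * online[i];
}

double[] targetParams = { 0.0, 1.0 };
double[] onlineParams = { 1.0, 1.0 };

// One update with tau = 0.99: the first target parameter moves 1% of the way
// toward its online value; the second already matches and stays put.
EmaUpdate(targetParams, onlineParams, 0.99);
Console.WriteLine(string.Join(", ", targetParams));
```

Because only the online parameters receive gradients, this blend is the only way the target ever changes.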
Architecture:
Online: encoder → projector → predictor → p
Target: encoder → projector → z (stop-gradient)
Loss: MSE(normalize(p), normalize(z))
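The loss line can be made concrete. For unit-normalized vectors p̄ and z̄, the squared Euclidean distance ‖p̄ − z̄‖² equals 2 − 2·cos(p, z), which is how the BYOL paper writes the loss. The sketch below computes that quantity on plain `double[]` arrays and sums it over both view orderings to form the symmetric loss. `Normalize`, `ByolPairLoss`, and the sample vectors are illustrative names, not library members; a literal mean-squared error would additionally divide by the embedding dimension, which only rescales the loss.

```csharp
using System;

// Scale a vector to unit L2 norm (a small epsilon guards against division by zero).
static double[] Normalize(double[] v)
{
    double norm = 0.0;
    foreach (double x in v) norm += x * x;
    norm = Math.Sqrt(norm) + 1e-12;
    var result = new double[v.Length];
    for (int i = 0; i < v.Length; i++) result[i] = v[i] / norm;
    return result;
}

// Squared distance between the normalized prediction p and target projection z.
// Equals 2 - 2 * cos(p, z): 0 when aligned, 4 when pointing in opposite directions.
static double ByolPairLoss(double[] p, double[] z)
{
    double[] pn = Normalize(p);
    double[] zn = Normalize(z);
    double sum = 0.0;
    for (int i = 0; i < pn.Length; i++)
    {
        double d = pn[i] - zn[i];
        sum += d * d;
    }
    return sum;
}

// Symmetric loss: view 1's prediction targets view 2's projection, and vice versa.
double[] p1 = { 1.0, 0.0 }; // online prediction for view 1
double[] z2 = { 0.0, 2.0 }; // target projection of view 2 (orthogonal to p1)
double[] p2 = { 3.0, 3.0 }; // online prediction for view 2
double[] z1 = { 1.0, 1.0 }; // target projection of view 1 (same direction as p2)
double symmetricLoss = ByolPairLoss(p1, z2) + ByolPairLoss(p2, z1);
Console.WriteLine(symmetricLoss);
```

Normalizing first makes the loss depend only on direction, so the online network cannot reduce it by shrinking or inflating its outputs.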
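Why it doesn't collapse: A trivial solution, in which every input maps to the same output, would minimize the loss. Empirically, the combination of the predictor (asymmetry), EMA updates (the target moves slowly), and batch normalization prevents training from reaching it.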
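A small experiment shows why collapse is a genuine risk: the loss on its own is minimized by a network that ignores its input. In the sketch below, `CollapsedNetwork` is a hypothetical network that returns the same constant vector for every input, and `CosineLoss` writes the BYOL loss as 2 − 2·cos(p, z). Two unrelated inputs still yield a loss of (numerically) zero, so nothing in the loss itself rules this solution out; BYOL relies on the architectural safeguards listed in the remark to avoid it.

```csharp
using System;

// BYOL-style loss written as 2 - 2 * cosine similarity between p and z.
static double CosineLoss(double[] p, double[] z)
{
    double dot = 0, pp = 0, zz = 0;
    for (int i = 0; i < p.Length; i++)
    {
        dot += p[i] * z[i];
        pp += p[i] * p[i];
        zz += z[i] * z[i];
    }
    return 2.0 - 2.0 * dot / (Math.Sqrt(pp * zz) + 1e-12);
}

// Hypothetical collapsed network: it ignores its input and always returns the same vector.
static double[] CollapsedNetwork(double[] input) => new[] { 0.6, 0.8 };

double[] imageA = { 0.1, 0.9, 0.3 };
double[] imageB = { 5.0, -2.0, 7.0 };

// Unrelated inputs, identical outputs: the loss is (near) zero.
double collapsedLoss = CosineLoss(CollapsedNetwork(imageA), CollapsedNetwork(imageB));
Console.WriteLine(collapsedLoss);
```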
Reference: Grill et al., "Bootstrap Your Own Latent - A New Approach to Self-Supervised Learning" (NeurIPS 2020)
Constructors
BYOL(INeuralNetwork<T>, IMomentumEncoder<T>, SymmetricProjector<T>, SymmetricProjector<T>, SSLConfig?)
Initializes a new instance of the BYOL class.
public BYOL(INeuralNetwork<T> encoder, IMomentumEncoder<T> targetEncoder, SymmetricProjector<T> onlineProjector, SymmetricProjector<T> targetProjector, SSLConfig? config = null)
Parameters
encoder (INeuralNetwork<T>): The online encoder network.
targetEncoder (IMomentumEncoder<T>): The target encoder, a momentum-updated copy of the online encoder.
onlineProjector (SymmetricProjector<T>): The projector used by the online network, including the predictor head.
targetProjector (SymmetricProjector<T>): The projector used by the target network, without a predictor head.
config (SSLConfig?): Optional SSL configuration (default: null).
Properties
Category
Gets the category of this SSL method.
public override SSLMethodCategory Category { get; }
Property Value: SSLMethodCategory
Remarks
Categories include Contrastive, NonContrastive, Generative, and SelfDistillation.
Name
Gets the name of this SSL method.
public override string Name { get; }
Property Value: string
Remarks
Examples: "SimCLR", "MoCo v2", "BYOL", "DINO", "MAE"
RequiresMemoryBank
Indicates whether this method requires a memory bank for negative samples.
public override bool RequiresMemoryBank { get; }
Property Value: bool
Remarks
For Beginners: Memory banks store embeddings from previous batches so they can be reused as negative samples in contrastive learning. MoCo uses one; SimCLR does not.
UsesMomentumEncoder
Indicates whether this method uses a momentum-updated encoder.
public override bool UsesMomentumEncoder { get; }
Property Value: bool
Remarks
For Beginners: A momentum encoder is a slowly-updated copy of the main encoder. Methods like MoCo, BYOL, and DINO use this to provide stable targets.
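In the BYOL paper, the target's EMA decay rate τ is not constant: it starts at τ_base = 0.996 and rises to 1 over training following τ = 1 − (1 − τ_base)·(cos(πk/K) + 1)/2, where k is the current training step and K the total number of steps. The sketch below implements that formula. This page does not say whether BYOL<T> follows this exact schedule, so `MomentumAt` is a hypothetical helper rather than a library call.

```csharp
using System;

// Hypothetical helper: cosine schedule for the EMA decay, as described in the BYOL paper.
// tau(k) = 1 - (1 - tauBase) * (cos(pi * k / K) + 1) / 2
static double MomentumAt(int step, int totalSteps, double tauBase = 0.996)
{
    if (totalSteps <= 0) throw new ArgumentOutOfRangeException(nameof(totalSteps));
    double progress = (double)step / totalSteps;
    return 1.0 - (1.0 - tauBase) * (Math.Cos(Math.PI * progress) + 1.0) / 2.0;
}

int total = 1000;
Console.WriteLine(MomentumAt(0, total));         // start of training: tauBase
Console.WriteLine(MomentumAt(total / 2, total)); // halfway: midpoint between tauBase and 1
Console.WriteLine(MomentumAt(total, total));     // end of training: target stops moving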
Methods
Create(INeuralNetwork<T>, Func<INeuralNetwork<T>, INeuralNetwork<T>>, int, int, int)
Creates a BYOL instance with default configuration.
public static BYOL<T> Create(INeuralNetwork<T> encoder, Func<INeuralNetwork<T>, INeuralNetwork<T>> createEncoderCopy, int encoderOutputDim, int projectionDim = 256, int hiddenDim = 4096)
Parameters
encoder (INeuralNetwork<T>): The backbone encoder.
createEncoderCopy (Func<INeuralNetwork<T>, INeuralNetwork<T>>): A function that creates a copy of the encoder to serve as the target network.
encoderOutputDim (int): The output dimension of the encoder.
projectionDim (int): The dimension of the projection space (default: 256).
hiddenDim (int): The hidden dimension of the projector MLP (default: 4096).
Returns
BYOL<T>: A configured BYOL instance.
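The three dimension parameters determine the shapes of the projection and prediction MLPs. Assuming each MLP follows the two-layer design in the BYOL paper (a linear layer to hiddenDim, then a linear layer to projectionDim, with biases; normalization and activation layers ignored), the sketch below counts the linear-layer parameters for a backbone with encoderOutputDim = 2048, the feature width of a ResNet-50. Whether SymmetricProjector<T> uses exactly this layout is an assumption, and `LinearParams` and `MlpParams` are hypothetical helpers.

```csharp
using System;

// Weights plus biases of one fully connected layer.
static long LinearParams(int inDim, int outDim) => (long)inDim * outDim + outDim;

// Two-layer MLP: inDim -> midDim -> outDim.
static long MlpParams(int inDim, int midDim, int outDim) =>
    LinearParams(inDim, midDim) + LinearParams(midDim, outDim);

int encoderOutputDim = 2048; // example: ResNet-50 feature width
int hiddenDim = 4096;        // Create default
int projectionDim = 256;     // Create default

long projector = MlpParams(encoderOutputDim, hiddenDim, projectionDim); // online and target each have one
long predictor = MlpParams(projectionDim, hiddenDim, projectionDim);    // online network only
Console.WriteLine($"Projector: {projector:N0}, predictor: {predictor:N0}");
```

Under these assumptions the projector is dominated by its first layer (2048 × 4096 weights), so encoderOutputDim and hiddenDim drive most of the added parameter count.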
OnEpochStart(int)
Signals the start of a new epoch.
public override void OnEpochStart(int epochNumber)
Parameters
epochNumber (int): The current epoch number.
TrainStepCore(Tensor<T>, SSLAugmentationContext<T>?)
Implements the BYOL-specific logic for a single training step.
protected override SSLStepResult<T> TrainStepCore(Tensor<T> batch, SSLAugmentationContext<T>? augmentationContext)
Parameters
batch (Tensor<T>): The input batch tensor.
augmentationContext (SSLAugmentationContext<T>?): Optional augmentation context.
Returns
SSLStepResult<T>: The result of the training step.
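The order of operations in a BYOL training step, following the architecture described in the class remarks, can be sketched with plain delegates. This is an illustration under stated assumptions, not the code of TrainStepCore: `ByolStep`, the `Func<double[], double[]>` network stand-ins, and the `optimizerStep` and `emaUpdate` callbacks are all hypothetical. The sketch computes online predictions and stop-gradient target projections for two augmented views, sums the two cross-view loss terms, hands the loss to the optimizer for the online network, and only then updates the target by EMA.

```csharp
using System;

// 2 - 2 * cosine similarity: the normalized squared distance used by BYOL.
static double PairLoss(double[] p, double[] z)
{
    double dot = 0, pp = 0, zz = 0;
    for (int i = 0; i < p.Length; i++) { dot += p[i] * z[i]; pp += p[i] * p[i]; zz += z[i] * z[i]; }
    return 2.0 - 2.0 * dot / (Math.Sqrt(pp * zz) + 1e-12);
}

// Hypothetical BYOL step over two augmented views of the same input.
// onlineForward: encoder -> projector -> predictor (trainable).
// targetForward: EMA encoder -> EMA projector (treated as constants, i.e. stop-gradient).
static double ByolStep(
    double[] view1, double[] view2,
    Func<double[], double[]> onlineForward,
    Func<double[], double[]> targetForward,
    Action<double> optimizerStep,
    Action emaUpdate)
{
    double[] p1 = onlineForward(view1), p2 = onlineForward(view2);
    double[] z1 = targetForward(view1), z2 = targetForward(view2);

    // Symmetric loss: each view's prediction targets the other view's projection.
    double loss = PairLoss(p1, z2) + PairLoss(p2, z1);

    optimizerStep(loss); // hand the loss to the optimizer, which should update only online weights
    emaUpdate();         // then move the target weights toward the online weights
    return loss;
}

// Toy usage: identity "networks" and callbacks that just count their calls.
int optimizerCalls = 0, emaCalls = 0;
double stepLoss = ByolStep(
    new[] { 1.0, 0.0 }, new[] { 0.0, 1.0 },
    x => x, x => x,
    _ => optimizerCalls++, () => emaCalls++);
Console.WriteLine(stepLoss);
```

Keeping the EMA update as a separate call after the optimizer step reflects the stop-gradient on the target branch: target weights never receive gradients and change only through the moving average.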