Class iBOT<T>
- Namespace: AiDotNet.SelfSupervisedLearning
- Assembly: AiDotNet.dll
iBOT (Image BERT Pre-Training with Online Tokenizer): combines DINO-style self-distillation with masked image modeling.
public class iBOT<T> : TeacherStudentSSL<T>, ISSLMethod<T>
Type Parameters
T: The numeric type used for computations.
- Inheritance: TeacherStudentSSL<T> → iBOT<T>
- Implements: ISSLMethod<T>
Remarks
For Beginners: iBOT combines DINO-style self-distillation with masked image modeling (as in MAE). It masks patches in the student's view and trains the student to match the teacher's outputs for both the CLS token (as in DINO) and the masked patches (a BERT-style objective for images).
Key innovations:
- Dual objective: CLS-token distillation plus masked patch prediction
- Online tokenizer: the teacher supplies the targets for masked patches, so no separately pre-trained tokenizer is needed
- Shared architecture: a single network handles both objectives
- Global and local learning: the CLS objective learns image-level (global) features, while the patch objective learns patch-level (local) features
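Loss formula:
L = L_cls + λ * L_mim
where L_cls is the DINO loss on the CLS token, L_mim is the masked patch prediction loss, and λ is the MIM weight (the mimWeight constructor parameter, exposed as MIMWeight).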
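The combined objective can be sketched in plain C# as follows. This is a minimal illustration under stated assumptions, not the AiDotNet implementation: it assumes each term is a cross-entropy between a temperature-sharpened teacher softmax and a student softmax (as in DINO), uses hypothetical names (`IbotLossSketch`, `Softmax`, `CrossEntropy`, `Loss`) and illustrative temperatures, and omits details such as teacher centering. `lambda` plays the role of `mimWeight`, and the MIM term is averaged only over masked patch positions.

```csharp
using System;
using System.Linq;

// Hypothetical sketch of L = L_cls + lambda * L_mim; not the AiDotNet implementation.
public static class IbotLossSketch
{
    // Softmax with temperature; subtracting the max keeps Math.Exp numerically stable.
    public static double[] Softmax(double[] logits, double temperature)
    {
        double max = logits.Max();
        double[] exps = logits.Select(z => Math.Exp((z - max) / temperature)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }

    // Cross-entropy H(p_teacher, p_student) = -sum_k p_teacher[k] * log p_student[k].
    public static double CrossEntropy(double[] teacherProbs, double[] studentProbs)
    {
        double h = 0.0;
        for (int k = 0; k < teacherProbs.Length; k++)
            h -= teacherProbs[k] * Math.Log(studentProbs[k] + 1e-12);
        return h;
    }

    // teacherCls/studentCls: CLS-token logits.
    // teacherPatches/studentPatches: per-patch logits; mask[i] is true where patch i was masked for the student.
    // The temperatures are illustrative assumptions, not library defaults.
    public static double Loss(
        double[] teacherCls, double[] studentCls,
        double[][] teacherPatches, double[][] studentPatches, bool[] mask,
        double lambda, double teacherTemp = 0.04, double studentTemp = 0.1)
    {
        double lCls = CrossEntropy(Softmax(teacherCls, teacherTemp), Softmax(studentCls, studentTemp));

        double lMim = 0.0;
        int maskedCount = 0;
        for (int i = 0; i < mask.Length; i++)
        {
            if (!mask[i]) continue; // only masked patches contribute to L_mim
            lMim += CrossEntropy(Softmax(teacherPatches[i], teacherTemp), Softmax(studentPatches[i], studentTemp));
            maskedCount++;
        }
        if (maskedCount > 0) lMim /= maskedCount;

        return lCls + lambda * lMim;
    }
}
```

Setting `lambda` to 0 reduces the sketch to a DINO-style CLS-only loss, which is one way to see how the MIM weight trades off global and local learning.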
Reference: Zhou et al., "iBOT: Image BERT Pre-Training with Online Tokenizer" (ICLR 2022)
Constructors
iBOT(INeuralNetwork<T>, IMomentumEncoder<T>, IProjectorHead<T>, IProjectorHead<T>, int, double, double, SSLConfig?)
Initializes a new instance of the iBOT class.
public iBOT(INeuralNetwork<T> studentEncoder, IMomentumEncoder<T> teacherEncoder, IProjectorHead<T> studentProjector, IProjectorHead<T> teacherProjector, int outputDim = 8192, double mimWeight = 1, double maskRatio = 0.4, SSLConfig? config = null)
Parameters
studentEncoder (INeuralNetwork<T>): The student encoder (must be a ViT).
teacherEncoder (IMomentumEncoder<T>): The teacher encoder, a momentum-updated copy of the student.
studentProjector (IProjectorHead<T>): Projection head for the student.
teacherProjector (IProjectorHead<T>): Projection head for the teacher.
outputDim (int): Output dimension of the projection heads (default: 8192).
mimWeight (double): Weight λ for the masked image modeling loss (default: 1.0).
maskRatio (double): Fraction of patches to mask (default: 0.4).
config (SSLConfig?): Optional SSL configuration (default: null).
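The teacher is described above as a momentum-updated copy of the student. A common form of that update is an exponential moving average, θ_teacher ← m · θ_teacher + (1 − m) · θ_student. The sketch below shows it over flat parameter arrays; it is a minimal illustration, not the `IMomentumEncoder<T>` API, and the class name `MomentumUpdateSketch` and the flat-array representation are hypothetical.

```csharp
using System;

// Hypothetical sketch of a momentum (EMA) teacher update; not the IMomentumEncoder<T> API.
public static class MomentumUpdateSketch
{
    // teacher[i] <- momentum * teacher[i] + (1 - momentum) * student[i], applied in place.
    public static void Update(double[] teacher, double[] student, double momentum)
    {
        if (teacher.Length != student.Length)
            throw new ArgumentException("Teacher and student must have the same number of parameters.");
        if (momentum < 0.0 || momentum > 1.0)
            throw new ArgumentOutOfRangeException(nameof(momentum));

        for (int i = 0; i < teacher.Length; i++)
            teacher[i] = momentum * teacher[i] + (1.0 - momentum) * student[i];
    }
}
```

A momentum close to 1 makes the teacher change slowly, which keeps its targets stable while the student is trained by gradient descent; a momentum of 1 freezes the teacher entirely.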
Properties
Category
Gets the category of this SSL method.
public override SSLMethodCategory Category { get; }
Property Value
SSLMethodCategory
Remarks
Categories include Contrastive, NonContrastive, Generative, and SelfDistillation.
MIMWeight
Gets the weight λ applied to the masked image modeling loss.
public double MIMWeight { get; }
Property Value
double
MaskRatio
Gets the fraction of patches masked in the student's view.
public double MaskRatio { get; }
Property Value
double
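To make the mask ratio concrete, the sketch below selects round(maskRatio × numPatches) patch positions uniformly at random. It is a simplified assumption, not necessarily how this class masks: the iBOT paper describes blockwise masking, and the names `PatchMaskSketch` and `RandomMask` are hypothetical.

```csharp
using System;

// Hypothetical sketch: mask a fraction of patch positions, chosen uniformly at random.
public static class PatchMaskSketch
{
    public static bool[] RandomMask(int numPatches, double maskRatio, Random rng)
    {
        if (maskRatio < 0.0 || maskRatio > 1.0)
            throw new ArgumentOutOfRangeException(nameof(maskRatio));

        int numMasked = (int)Math.Round(maskRatio * numPatches);

        // Partial Fisher-Yates shuffle: the first numMasked entries form a uniform random subset.
        int[] indices = new int[numPatches];
        for (int i = 0; i < numPatches; i++) indices[i] = i;
        for (int i = 0; i < numMasked; i++)
        {
            int j = rng.Next(i, numPatches); // j is in [i, numPatches)
            (indices[i], indices[j]) = (indices[j], indices[i]);
        }

        bool[] mask = new bool[numPatches];
        for (int i = 0; i < numMasked; i++) mask[indices[i]] = true;
        return mask;
    }
}
```

For example, with 196 patches (a 14×14 ViT patch grid) and the default ratio of 0.4, this masks 78 patches.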
Name
Gets the name of this SSL method.
public override string Name { get; }
Property Value
string
Remarks
Examples: "SimCLR", "MoCo v2", "BYOL", "DINO", "MAE"
Methods
Create(INeuralNetwork<T>, Func<INeuralNetwork<T>, INeuralNetwork<T>>, int, int, int, double, double)
Creates an iBOT instance with default configuration.
public static iBOT<T> Create(INeuralNetwork<T> encoder, Func<INeuralNetwork<T>, INeuralNetwork<T>> createEncoderCopy, int encoderOutputDim, int outputDim = 8192, int hiddenDim = 2048, double mimWeight = 1, double maskRatio = 0.4)
Parameters
encoder (INeuralNetwork<T>)
createEncoderCopy (Func<INeuralNetwork<T>, INeuralNetwork<T>>)
encoderOutputDim (int)
outputDim (int, default: 8192)
hiddenDim (int, default: 2048)
mimWeight (double, default: 1.0)
maskRatio (double, default: 0.4)
Returns
- iBOT<T>
TrainStepCore(Tensor<T>, SSLAugmentationContext<T>?)
Implementation-specific training step logic.
protected override SSLStepResult<T> TrainStepCore(Tensor<T> batch, SSLAugmentationContext<T>? augmentationContext)
Parameters
batch (Tensor<T>): The input batch tensor.
augmentationContext (SSLAugmentationContext<T>?): Optional augmentation context.
Returns
- SSLStepResult<T>
The result of the training step.