Class MAE<T>
- Namespace: AiDotNet.SelfSupervisedLearning
- Assembly: AiDotNet.dll
MAE: Masked Autoencoder for Self-Supervised Vision Learning.
public class MAE<T> : SSLMethodBase<T>, ISSLMethod<T>
Type Parameters
T: The numeric type used for computations.
- Inheritance
- SSLMethodBase<T> → MAE<T>
- Implements
- ISSLMethod<T>
Remarks
For Beginners: MAE is a simple yet powerful self-supervised method. It randomly masks a large fraction of the image patches (75% by default), encodes only the visible patches, and trains a decoder to reconstruct the original pixels of the masked patches.
Key innovations:
- High masking ratio: 75% of patches are masked by default (vs. ~15% of tokens in BERT)
- Asymmetric encoder-decoder: the encoder sees only the visible patches
- Efficient training: at the default ratio, the encoder processes only 25% of the patches
- Reconstruction target: Normalized pixel values of masked patches
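The random masking step described above can be sketched in plain C#. This is a minimal illustration, not AiDotNet's implementation: the class `MaeMaskingSketch`, its `RandomMask` method, and the choice to truncate `numPatches * (1 - maskRatio)` when counting visible patches are all assumptions made for this example.

```csharp
using System;
using System.Linq;

// Hypothetical helper for illustration; not part of AiDotNet.
public static class MaeMaskingSketch
{
    // Splits patch indices 0..numPatches-1 into a visible set (fed to the encoder)
    // and a masked set (reconstructed by the decoder), both sorted ascending.
    public static (int[] Visible, int[] Masked) RandomMask(int numPatches, double maskRatio, Random rng)
    {
        if (numPatches <= 0) throw new ArgumentOutOfRangeException(nameof(numPatches));
        if (maskRatio < 0.0 || maskRatio >= 1.0)
            throw new ArgumentOutOfRangeException(nameof(maskRatio), "Expected 0 <= maskRatio < 1.");

        // Assumption: the visible count is truncated toward zero.
        int numVisible = (int)(numPatches * (1.0 - maskRatio));

        // Fisher-Yates shuffle gives a uniformly random permutation of patch indices.
        int[] order = Enumerable.Range(0, numPatches).ToArray();
        for (int i = numPatches - 1; i > 0; i--)
        {
            int j = rng.Next(i + 1);
            (order[i], order[j]) = (order[j], order[i]);
        }

        // The first numVisible shuffled indices stay visible; the rest are masked.
        int[] visible = order.Take(numVisible).OrderBy(x => x).ToArray();
        int[] masked = order.Skip(numVisible).OrderBy(x => x).ToArray();
        return (visible, masked);
    }
}
```

With 196 patches (a 224×224 image cut into 16×16 patches) and a mask ratio of 0.75, this leaves 49 visible and 147 masked patches, so the encoder handles a quarter of the sequence.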
Architecture:
Input → Patchify → Random Mask → Visible Encoder → Add mask tokens → Decoder → Reconstruct
Loss: mean squared error (MSE), computed on masked patches only
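The loss described above, mean squared error restricted to the masked patches, can be sketched as follows, together with an optional per-patch normalization of the reconstruction target. The class and method names, the use of population variance, and the epsilon of 1e-6 are assumptions for illustration; they are not taken from AiDotNet's source.

```csharp
using System;

// Hypothetical helpers for illustration; not part of AiDotNet.
public static class MaeLossSketch
{
    // Per-patch normalization: subtract the patch mean and divide by the patch
    // standard deviation. Assumptions: population variance and eps = 1e-6.
    public static double[] NormalizePatch(double[] patch, double eps = 1e-6)
    {
        double mean = 0.0;
        foreach (double v in patch) mean += v;
        mean /= patch.Length;

        double variance = 0.0;
        foreach (double v in patch) variance += (v - mean) * (v - mean);
        variance /= patch.Length;

        double std = Math.Sqrt(variance + eps);
        var result = new double[patch.Length];
        for (int i = 0; i < patch.Length; i++) result[i] = (patch[i] - mean) / std;
        return result;
    }

    // Mean squared error over the elements of masked patches only;
    // visible patches contribute nothing to the loss.
    public static double MaskedMse(double[][] target, double[][] prediction, bool[] isMasked)
    {
        if (target.Length != prediction.Length || target.Length != isMasked.Length)
            throw new ArgumentException("target, prediction, and isMasked must have the same length.");

        double sum = 0.0;
        int count = 0;
        for (int p = 0; p < target.Length; p++)
        {
            if (!isMasked[p]) continue;
            for (int k = 0; k < target[p].Length; k++)
            {
                double diff = prediction[p][k] - target[p][k];
                sum += diff * diff;
                count++;
            }
        }
        if (count == 0) throw new ArgumentException("At least one patch must be masked.");
        return sum / count;
    }
}
```

Because the average is taken only over masked elements, a perfect reconstruction of the visible patches earns nothing; the model is rewarded solely for predicting content it never saw.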
Reference: He et al., "Masked Autoencoders Are Scalable Vision Learners" (CVPR 2022)
Constructors
MAE(INeuralNetwork<T>, INeuralNetwork<T>?, int, int, double, SSLConfig?)
Initializes a new instance of the MAE class.
public MAE(INeuralNetwork<T> encoder, INeuralNetwork<T>? decoder = null, int patchSize = 16, int imageSize = 224, double maskRatio = 0.75, SSLConfig? config = null)
Parameters
encoder (INeuralNetwork<T>): The encoder (ViT) that processes the visible patches.
decoder (INeuralNetwork<T>?): The decoder that reconstructs the masked patches. Optional (default: null).
patchSize (int): Size (side length) of each patch, in pixels (default: 16).
imageSize (int): Size (side length) of the input images, in pixels (default: 224).
maskRatio (double): Fraction of patches to mask (default: 0.75).
config (SSLConfig?): Optional SSL configuration (default: null).
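The constructor parameters determine how many patches the model works with. The sketch below derives those counts and shows one way to cut an image into flattened patches. It assumes square images, square non-overlapping patches, a `[channel, row, col]` layout, and a (row, column, channel) value order inside each patch; `MaePatchSketch`, `Counts`, and `Patchify` are hypothetical names, not AiDotNet APIs, and AiDotNet may validate or round differently.

```csharp
using System;

// Hypothetical helpers for illustration; not part of AiDotNet.
public static class MaePatchSketch
{
    // Derives patch counts from constructor-style parameters.
    public static (int NumPatches, int NumVisible, int NumMasked) Counts(int imageSize, int patchSize, double maskRatio)
    {
        if (patchSize <= 0 || imageSize <= 0 || imageSize % patchSize != 0)
            throw new ArgumentException("imageSize must be a positive multiple of patchSize.");
        if (maskRatio < 0.0 || maskRatio >= 1.0)
            throw new ArgumentOutOfRangeException(nameof(maskRatio), "Expected 0 <= maskRatio < 1.");

        int side = imageSize / patchSize;                        // patches per row and per column
        int numPatches = side * side;
        int numVisible = (int)(numPatches * (1.0 - maskRatio)); // truncation is an assumption
        return (numPatches, numVisible, numPatches - numVisible);
    }

    // Cuts an image indexed [channel, row, col] into flattened patches in row-major
    // grid order; values inside a patch are ordered (row, col, channel).
    public static double[][] Patchify(double[,,] image, int patchSize)
    {
        int channels = image.GetLength(0), height = image.GetLength(1), width = image.GetLength(2);
        if (patchSize <= 0 || height % patchSize != 0 || width % patchSize != 0)
            throw new ArgumentException("Image height and width must be multiples of patchSize.");

        int gridH = height / patchSize, gridW = width / patchSize;
        var patches = new double[gridH * gridW][];
        for (int gy = 0; gy < gridH; gy++)
        {
            for (int gx = 0; gx < gridW; gx++)
            {
                var patch = new double[patchSize * patchSize * channels];
                int k = 0;
                for (int py = 0; py < patchSize; py++)
                    for (int px = 0; px < patchSize; px++)
                        for (int c = 0; c < channels; c++)
                            patch[k++] = image[c, gy * patchSize + py, gx * patchSize + px];
                patches[gy * gridW + gx] = patch;
            }
        }
        return patches;
    }
}
```

With the defaults (imageSize = 224, patchSize = 16, maskRatio = 0.75), `Counts` returns 196 patches, 49 visible and 147 masked, and each patch of a 3-channel image flattens to 16 × 16 × 3 = 768 values.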
Properties
Category
Gets the category of this SSL method.
public override SSLMethodCategory Category { get; }
Property Value
SSLMethodCategory
Remarks
Categories include Contrastive, NonContrastive, Generative, and SelfDistillation.
MaskRatio
Gets the mask ratio (proportion of patches masked).
public double MaskRatio { get; }
Property Value
double
Name
Gets the name of this SSL method.
public override string Name { get; }
Property Value
string
Remarks
Examples: "SimCLR", "MoCo v2", "BYOL", "DINO", "MAE"
PatchSize
Gets the patch size (the side length of each patch, in pixels).
public int PatchSize { get; }
Property Value
int
RequiresMemoryBank
Indicates whether this method requires a memory bank for negative samples.
public override bool RequiresMemoryBank { get; }
Property Value
bool
Remarks
For Beginners: Memory banks store embeddings from previous batches to use as negative samples in contrastive learning. MoCo uses one; SimCLR does not.
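A memory bank is commonly implemented as a fixed-capacity, first-in first-out queue of embeddings, as in MoCo. The sketch below is a hypothetical illustration; `MemoryBankSketch` is not an AiDotNet type. MAE itself learns by pixel reconstruction and uses no negative samples.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical FIFO memory bank for illustration; not part of AiDotNet.
public sealed class MemoryBankSketch
{
    private readonly Queue<double[]> _queue = new Queue<double[]>();

    public MemoryBankSketch(int capacity)
    {
        if (capacity <= 0) throw new ArgumentOutOfRangeException(nameof(capacity));
        Capacity = capacity;
    }

    public int Capacity { get; }
    public int Count => _queue.Count;

    // Adds a batch of embeddings, evicting the oldest entries once capacity is reached.
    public void Enqueue(IEnumerable<double[]> embeddings)
    {
        foreach (var embedding in embeddings)
        {
            if (_queue.Count == Capacity) _queue.Dequeue();
            _queue.Enqueue(embedding);
        }
    }

    // Snapshot of the stored embeddings, oldest first, for use as negatives.
    public double[][] Negatives() => _queue.ToArray();
}
```

The queue decouples the number of negatives from the batch size: a small batch can still be contrasted against many embeddings from earlier steps.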
UsesMomentumEncoder
Indicates whether this method uses a momentum-updated encoder.
public override bool UsesMomentumEncoder { get; }
Property Value
bool
Remarks
For Beginners: A momentum encoder is a slowly-updated copy of the main encoder. Methods like MoCo, BYOL, and DINO use this to provide stable targets.
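In MoCo, BYOL, and DINO, the slow update is an exponential moving average (EMA) of the online encoder's parameters. A minimal sketch over a flat parameter array follows; the class name and the flat-array representation are assumptions for illustration, not AiDotNet APIs.

```csharp
using System;

// Hypothetical EMA update for illustration; not part of AiDotNet.
public static class MomentumUpdateSketch
{
    // In-place update: target = momentum * target + (1 - momentum) * online.
    public static void Update(double[] targetParams, double[] onlineParams, double momentum)
    {
        if (targetParams.Length != onlineParams.Length)
            throw new ArgumentException("Parameter arrays must have the same length.");
        if (momentum < 0.0 || momentum > 1.0)
            throw new ArgumentOutOfRangeException(nameof(momentum), "Expected 0 <= momentum <= 1.");

        for (int i = 0; i < targetParams.Length; i++)
            targetParams[i] = momentum * targetParams[i] + (1.0 - momentum) * onlineParams[i];
    }
}
```

A momentum close to 1 (MoCo uses 0.999) moves the copy only slightly per step, which is what keeps its targets stable.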
Methods
Create(INeuralNetwork<T>, INeuralNetwork<T>?, int, int, double)
Creates an MAE instance with default configuration.
public static MAE<T> Create(INeuralNetwork<T> encoder, INeuralNetwork<T>? decoder = null, int patchSize = 16, int imageSize = 224, double maskRatio = 0.75)
Parameters
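encoder (INeuralNetwork<T>)
decoder (INeuralNetwork<T>?)
patchSize (int)
imageSize (int)
maskRatio (double)
These parameters have the same meanings and defaults as the corresponding constructor parameters.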
Returns
- MAE<T>
TrainStepCore(Tensor<T>, SSLAugmentationContext<T>?)
Implementation-specific training step logic.
protected override SSLStepResult<T> TrainStepCore(Tensor<T> batch, SSLAugmentationContext<T>? augmentationContext)
Parameters
batch (Tensor<T>): The input batch tensor.
augmentationContext (SSLAugmentationContext<T>?): Optional augmentation context.
Returns
- SSLStepResult<T>
The result of the training step.
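To illustrate what a masked-autoencoder training step computes, here is a sketch of the forward pass and loss for one patchified image. It is not the body of `TrainStepCore`: the encoder and decoder are stand-in delegates rather than `INeuralNetwork<T>`, the name `StepLoss` is hypothetical, and the gradient computation and optimizer update that a real step performs are omitted.

```csharp
using System;
using System.Linq;

// Hypothetical sketch of one MAE step's loss; not AiDotNet code.
public static class MaeStepSketch
{
    public static double StepLoss(
        double[][] patches,                            // patchified image: [numPatches][patchDim]
        double maskRatio,
        Random rng,
        Func<double[][], double[][]> encode,           // visible patches -> latents
        Func<double[][], bool[], double[][]> decode)   // latents + mask layout -> predictions for all patches
    {
        if (maskRatio < 0.0 || maskRatio >= 1.0)
            throw new ArgumentOutOfRangeException(nameof(maskRatio));
        int n = patches.Length;
        int numVisible = (int)(n * (1.0 - maskRatio)); // truncation is an assumption

        // Random permutation; every index after the first numVisible is masked.
        int[] order = Enumerable.Range(0, n).OrderBy(_ => rng.NextDouble()).ToArray();
        var isMasked = new bool[n];
        for (int i = numVisible; i < n; i++) isMasked[order[i]] = true;

        // The encoder only ever sees the visible patches.
        double[][] visible = patches.Where((_, idx) => !isMasked[idx]).ToArray();
        double[][] predictions = decode(encode(visible), isMasked);

        // MSE over masked patches only.
        double sum = 0.0;
        int count = 0;
        for (int p = 0; p < n; p++)
        {
            if (!isMasked[p]) continue;
            for (int k = 0; k < patches[p].Length; k++)
            {
                double diff = predictions[p][k] - patches[p][k];
                sum += diff * diff;
                count++;
            }
        }
        return count == 0 ? 0.0 : sum / count;
    }
}
```

Passing the mask layout to the decoder mirrors the "add mask tokens" stage of the architecture: the decoder needs to know which positions it must fill in.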