Class MAE<T>

Namespace
AiDotNet.SelfSupervisedLearning
Assembly
AiDotNet.dll

MAE: Masked Autoencoder for Self-Supervised Vision Learning.

public class MAE<T> : SSLMethodBase<T>, ISSLMethod<T>

Type Parameters

T

The numeric type used for computations.

Inheritance
SSLMethodBase<T>
MAE<T>
Implements
ISSLMethod<T>

Remarks

For Beginners: MAE is a simple yet powerful self-supervised method. It randomly masks a large portion of the image patches (75% by default), encodes only the visible patches, and trains a decoder to reconstruct the original pixels of the masked patches.

Key innovations:

  • High masking ratio: 75% of patches are masked (vs ~15% in BERT)
  • Asymmetric encoder-decoder: Encoder only sees visible patches
  • Efficient training: Encoder processes only 25% of patches
  • Reconstruction target: Normalized pixel values of masked patches
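The percentages above can be checked with a little arithmetic. The sketch below is illustrative only: MaePatchMath is a hypothetical helper, not part of AiDotNet, and it assumes square images, square non-overlapping patches, and rounding toward zero when the masked count is not a whole number.

```csharp
using System;

// Hypothetical helper for illustration; not part of AiDotNet.
public static class MaePatchMath
{
    // Number of patches when a square image is cut into square, non-overlapping patches.
    public static int PatchCount(int imageSize, int patchSize)
    {
        if (patchSize <= 0 || imageSize <= 0 || imageSize % patchSize != 0)
            throw new ArgumentException("imageSize must be a positive multiple of patchSize.");

        int perSide = imageSize / patchSize;
        return perSide * perSide;
    }

    // Number of masked patches. Truncating toward zero is an assumption of this sketch.
    public static int MaskedCount(int patchCount, double maskRatio)
    {
        return (int)(patchCount * maskRatio);
    }
}
```

With the defaults documented on this page (imageSize = 224, patchSize = 16, maskRatio = 0.75), this gives 14 × 14 = 196 patches, of which 147 are masked and 49 reach the encoder.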

Architecture:

Input → Patchify → Random Mask → Visible Encoder → Add mask tokens → Decoder → Reconstruct
Loss: MSE on masked patches only
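The masking and loss steps of this pipeline can be sketched in plain C#. This is a minimal illustration under stated assumptions, not the library's implementation: MaeMaskingSketch and all of its members are hypothetical names, patches are represented as flat double arrays, and the per-patch normalization uses the population variance with a small epsilon.

```csharp
using System;

// Hypothetical sketch for illustration; not part of AiDotNet.
public static class MaeMaskingSketch
{
    // Returns a mask over patches: true means the patch is hidden from the encoder.
    public static bool[] RandomMask(int patchCount, double maskRatio, Random rng)
    {
        int numMasked = (int)(patchCount * maskRatio);
        int[] order = new int[patchCount];
        for (int i = 0; i < patchCount; i++) order[i] = i;

        // Fisher-Yates shuffle, so every subset of numMasked patches is equally likely.
        for (int i = patchCount - 1; i > 0; i--)
        {
            int j = rng.Next(i + 1);
            int tmp = order[i];
            order[i] = order[j];
            order[j] = tmp;
        }

        bool[] masked = new bool[patchCount];
        for (int i = 0; i < numMasked; i++) masked[order[i]] = true;
        return masked;
    }

    // Normalizes one patch to zero mean and (approximately) unit variance.
    public static double[] NormalizePatch(double[] patch, double eps = 1e-6)
    {
        if (patch.Length == 0) return new double[0];

        double mean = 0.0;
        foreach (double v in patch) mean += v;
        mean /= patch.Length;

        double variance = 0.0;
        foreach (double v in patch) variance += (v - mean) * (v - mean);
        variance /= patch.Length;

        double scale = Math.Sqrt(variance + eps);
        double[] result = new double[patch.Length];
        for (int k = 0; k < patch.Length; k++) result[k] = (patch[k] - mean) / scale;
        return result;
    }

    // Mean squared error averaged over masked patches only; visible patches are ignored.
    public static double MaskedMse(double[][] predicted, double[][] target, bool[] masked)
    {
        double total = 0.0;
        int count = 0;
        for (int p = 0; p < masked.Length; p++)
        {
            if (!masked[p]) continue;

            double patchError = 0.0;
            for (int k = 0; k < target[p].Length; k++)
            {
                double d = predicted[p][k] - target[p][k];
                patchError += d * d;
            }
            total += patchError / target[p].Length;
            count++;
        }
        return count == 0 ? 0.0 : total / count;
    }
}
```

In a full training step, the targets passed to MaskedMse would be the normalized original patches, and the predictions would come from the decoder.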

Reference: He et al., "Masked Autoencoders Are Scalable Vision Learners" (CVPR 2022)

Constructors

MAE(INeuralNetwork<T>, INeuralNetwork<T>?, int, int, double, SSLConfig?)

Initializes a new instance of the MAE class.

public MAE(INeuralNetwork<T> encoder, INeuralNetwork<T>? decoder = null, int patchSize = 16, int imageSize = 224, double maskRatio = 0.75, SSLConfig? config = null)

Parameters

encoder INeuralNetwork<T>

The encoder (ViT) that processes visible patches.

decoder INeuralNetwork<T>

The decoder that reconstructs masked patches. Optional (default: null).

patchSize int

Size of each patch, in pixels (default: 16).

imageSize int

Size of the input images, in pixels (default: 224).

maskRatio double

Ratio of patches to mask (default: 0.75).

config SSLConfig

Optional SSL configuration.
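This page does not say which argument checks the constructor performs, so a caller may want to validate values up front. The sketch below shows one way to do that. MaeArgumentChecks is a hypothetical helper, not part of AiDotNet, and the constraints it enforces (image size divisible by patch size, mask ratio strictly between 0 and 1) are assumptions that follow from how patch-based masking works rather than documented requirements.

```csharp
using System;

// Hypothetical helper for illustration; not part of AiDotNet.
public static class MaeArgumentChecks
{
    // Throws if the arguments cannot describe a valid patch grid and mask ratio.
    public static void Validate(int patchSize, int imageSize, double maskRatio)
    {
        if (patchSize <= 0)
            throw new ArgumentOutOfRangeException(nameof(patchSize));

        // The image must split into a whole number of patches per side.
        if (imageSize <= 0 || imageSize % patchSize != 0)
            throw new ArgumentException(
                "imageSize must be a positive multiple of patchSize.", nameof(imageSize));

        // A ratio of 0 masks nothing and a ratio of 1 leaves nothing for the encoder.
        if (!(maskRatio > 0.0 && maskRatio < 1.0))
            throw new ArgumentOutOfRangeException(nameof(maskRatio));
    }
}
```

After a call such as MaeArgumentChecks.Validate(16, 224, 0.75) succeeds, the same values can be passed to the constructor documented above.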

Properties

Category

Gets the category of this SSL method.

public override SSLMethodCategory Category { get; }

Property Value

SSLMethodCategory

Remarks

Categories include Contrastive, NonContrastive, Generative, and SelfDistillation.

MaskRatio

Gets the mask ratio (proportion of patches masked).

public double MaskRatio { get; }

Property Value

double

Name

Gets the name of this SSL method.

public override string Name { get; }

Property Value

string

Remarks

Examples: "SimCLR", "MoCo v2", "BYOL", "DINO", "MAE"

PatchSize

Gets the patch size.

public int PatchSize { get; }

Property Value

int

RequiresMemoryBank

Indicates whether this method requires a memory bank for negative samples.

public override bool RequiresMemoryBank { get; }

Property Value

bool

Remarks

For Beginners: Memory banks store embeddings from previous batches to use as negative samples in contrastive learning. MoCo uses one; SimCLR does not. MAE, as described by He et al., is reconstruction-based and does not use negative samples.
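For readers unfamiliar with the concept, a memory bank can be pictured as a fixed-size first-in, first-out queue of embeddings. The sketch below is a generic illustration of that idea. EmbeddingMemoryBank is a hypothetical class, not an AiDotNet type, and MAE does not need one.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical class for illustration; not part of AiDotNet.
public sealed class EmbeddingMemoryBank
{
    private readonly Queue<float[]> _queue = new Queue<float[]>();
    private readonly int _capacity;

    public EmbeddingMemoryBank(int capacity)
    {
        if (capacity <= 0) throw new ArgumentOutOfRangeException(nameof(capacity));
        _capacity = capacity;
    }

    public int Count => _queue.Count;

    // Adds the embeddings of the current batch, evicting the oldest entries when full.
    public void Enqueue(IEnumerable<float[]> batchEmbeddings)
    {
        foreach (float[] embedding in batchEmbeddings)
        {
            if (_queue.Count == _capacity) _queue.Dequeue();
            _queue.Enqueue(embedding);
        }
    }

    // Returns the stored embeddings, oldest first, for use as negative samples.
    public float[][] Negatives()
    {
        return _queue.ToArray();
    }
}
```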

UsesMomentumEncoder

Indicates whether this method uses a momentum-updated encoder.

public override bool UsesMomentumEncoder { get; }

Property Value

bool

Remarks

For Beginners: A momentum encoder is a slowly updated copy of the main encoder. Methods like MoCo, BYOL, and DINO use this to provide stable targets. MAE, as described by He et al., trains a single encoder without a momentum copy.
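The "slowly updated copy" is usually maintained as an exponential moving average of the main encoder's weights. The sketch below shows that update on flat weight arrays. MomentumUpdate is a hypothetical helper, not an AiDotNet type, and the flat-array representation is a simplification for illustration.

```csharp
using System;

// Hypothetical helper for illustration; not part of AiDotNet.
public static class MomentumUpdate
{
    // In-place update: target = momentum * target + (1 - momentum) * online.
    // A momentum close to 1 makes the target encoder change slowly.
    public static void Apply(double[] target, double[] online, double momentum)
    {
        if (target.Length != online.Length)
            throw new ArgumentException("Weight arrays must have the same length.");

        for (int i = 0; i < target.Length; i++)
        {
            target[i] = momentum * target[i] + (1.0 - momentum) * online[i];
        }
    }
}
```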

Methods

Create(INeuralNetwork<T>, INeuralNetwork<T>?, int, int, double)

Creates an MAE instance with default configuration.

public static MAE<T> Create(INeuralNetwork<T> encoder, INeuralNetwork<T>? decoder = null, int patchSize = 16, int imageSize = 224, double maskRatio = 0.75)

Parameters

encoder INeuralNetwork<T>
decoder INeuralNetwork<T>
patchSize int
imageSize int
maskRatio double

Returns

MAE<T>

TrainStepCore(Tensor<T>, SSLAugmentationContext<T>?)

Implementation-specific training step logic.

protected override SSLStepResult<T> TrainStepCore(Tensor<T> batch, SSLAugmentationContext<T>? augmentationContext)

Parameters

batch Tensor<T>

The input batch tensor.

augmentationContext SSLAugmentationContext<T>

Optional augmentation context.

Returns

SSLStepResult<T>

The result of the training step.