
Class StandardVAE<T>

Namespace
AiDotNet.Diffusion.VAE
Assembly
AiDotNet.dll

Standard Variational Autoencoder for latent diffusion models.

public class StandardVAE<T> : VAEModelBase<T>, IVAEModel<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>

Type Parameters

T

The numeric type used for calculations.

Inheritance
VAEModelBase<T>
StandardVAE<T>
Implements
IVAEModel<T>
IFullModel<T, Tensor<T>, Tensor<T>>
IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>
IModelSerializer
ICheckpointableModel
IParameterizable<T, Tensor<T>, Tensor<T>>
IFeatureAware
IFeatureImportance<T>
ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>
IGradientComputable<T, Tensor<T>, Tensor<T>>
IJitCompilable<T>

Remarks

This implements a standard VAE architecture similar to Stable Diffusion's VAE, with an encoder that compresses images to latent space and a decoder that reconstructs images from latents.

For Beginners: The StandardVAE is like a very smart image compressor:

How it works:

  1. Encoder: Takes a 512x512x3 image and compresses it to a 64x64x4 latent

    • That's 48x compression! (786,432 values -> 16,384 values)
    • Uses multiple layers of convolutions and downsampling
  2. Decoder: Takes the 64x64x4 latent and reconstructs a 512x512x3 image

    • Uses upsampling and convolutions to expand back to full size
    • The reconstruction isn't perfect but preserves important visual features
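The 48x figure follows from counting values before and after encoding. The short C# sketch below is standalone code, not part of AiDotNet; `LatentShape` and `CompressionRatio` are hypothetical helpers, and it assumes the encoder divides each spatial side by the downsampling factor.

```csharp
using System;

// Hypothetical helper: latent shape for an image, assuming each spatial side
// is divided by the downsampling factor.
static (int Channels, int Height, int Width) LatentShape(
    int height, int width, int latentChannels, int downsampleFactor) =>
    (latentChannels, height / downsampleFactor, width / downsampleFactor);

// Hypothetical helper: number of image values divided by number of latent values.
static double CompressionRatio(int height, int width, int imageChannels,
    int latentChannels, int downsampleFactor)
{
    long imageValues = (long)height * width * imageChannels;
    var (c, h, w) = LatentShape(height, width, latentChannels, downsampleFactor);
    long latentValues = (long)c * h * w;
    return (double)imageValues / latentValues;
}

var shape = LatentShape(512, 512, latentChannels: 4, downsampleFactor: 8);
Console.WriteLine($"Latent shape: {shape.Channels}x{shape.Height}x{shape.Width}"); // Latent shape: 4x64x64
Console.WriteLine($"Compression: {CompressionRatio(512, 512, 3, 4, 8)}x");         // Compression: 48x
```

In closed form the ratio is f² × C_in / C_latent = 64 × 3 / 4 = 48, so it does not depend on the image size.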

Why 4 latent channels?

  • The VAE learns to pack image information into 4 channels
  • Loosely speaking, the channels jointly encode aspects such as color, edges, and texture
  • More channels = better quality but larger latent space

Why 8x downsampling?

  • Each side is reduced by 8 (512 -> 64)
  • This is the trade-off Stable Diffusion settled on between compression and reconstruction quality
  • Smaller latents = faster diffusion, but potentially lower quality
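Because the factor applies to both height and width, doubling it shrinks the latent by 4x. A minimal C# sketch (hypothetical helper `LatentValues`, 4 latent channels assumed) compares factors 4, 8, and 16 for a 512x512 image. It only counts values; reconstruction quality has to be measured separately.

```csharp
using System;

// Hypothetical helper: number of latent values for a square image.
static long LatentValues(int imageSide, int latentChannels, int downsampleFactor)
{
    int latentSide = imageSide / downsampleFactor;
    return (long)latentSide * latentSide * latentChannels;
}

foreach (int factor in new[] { 4, 8, 16 })
{
    // Doubling the factor halves each side, so the latent gets 4x smaller.
    Console.WriteLine($"factor {factor,2}: {LatentValues(512, 4, factor),6} latent values");
}
```

For factors 4, 8, and 16 it reports 65,536, 16,384, and 4,096 latent values respectively.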

Architecture details:

  • Input: [batch, 3, H, W] RGB image normalized to [-1, 1]
  • Encoder: ResBlocks with GroupNorm, downsampling via strided conv
  • Latent: [batch, 4, H/8, W/8] with mean and variance for sampling
  • Decoder: ResBlocks with GroupNorm, upsampling via transpose conv
  • Output: [batch, 3, H, W] reconstructed image
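The [-1, 1] input range is typically produced by rescaling 8-bit pixel values. The C# sketch below shows one way to do this. It is an assumption about preprocessing, not a documented AiDotNet helper. `ToModelRange` and `ToPixel` are hypothetical names, and the sketch assumes decoder outputs use the same [-1, 1] range.

```csharp
using System;

// Hypothetical: map an 8-bit pixel value in [0, 255] to [-1, 1].
static float ToModelRange(byte pixel) => pixel / 127.5f - 1f;

// Hypothetical: map a [-1, 1] value back to an 8-bit pixel, clamping
// decoder outputs that overshoot the range.
static byte ToPixel(float value)
{
    float scaled = (value + 1f) * 127.5f;
    return (byte)Math.Clamp(MathF.Round(scaled), 0f, 255f);
}

Console.WriteLine(ToModelRange(0));   // -1
Console.WriteLine(ToModelRange(255)); // 1
Console.WriteLine(ToPixel(1.2f));     // 255 (clamped)
```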

Constructors

StandardVAE(int, int, int, int[]?, int, double?, ILossFunction<T>?, int?)

Initializes a new instance of the StandardVAE class. The default arguments reproduce the Stable Diffusion VAE configuration.

public StandardVAE(int inputChannels = 3, int latentChannels = 4, int baseChannels = 128, int[]? channelMultipliers = null, int numResBlocksPerLevel = 2, double? latentScaleFactor = null, ILossFunction<T>? lossFunction = null, int? seed = null)

Parameters

inputChannels int

Number of input image channels (default: 3 for RGB).

latentChannels int

Number of latent channels (default: 4).

baseChannels int

Base channel count (default: 128).

channelMultipliers int[]

Channel multipliers per level (default: [1, 2, 4, 4]).

numResBlocksPerLevel int

Residual blocks per level (default: 2).

latentScaleFactor double?

Scale factor for latents (default when null: 0.18215, the Stable Diffusion value).

lossFunction ILossFunction<T>

Optional loss function (default: MSE).

seed int?

Optional random seed for reproducibility.
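With the default arguments, baseChannels and channelMultipliers determine the width of each encoder level. The C# sketch below assumes the Stable Diffusion convention that every level except the last halves the resolution. Under that convention, four multipliers give the 8x DownsampleFactor. `LevelSchedule` and `DownsampleFactorFor` are hypothetical helpers, not library members.

```csharp
using System;

// Hypothetical: per-level (channels, resolution) schedule, assuming every
// level except the last is followed by a 2x downsample.
static (int Channels, int Resolution)[] LevelSchedule(
    int baseChannels, int[] channelMultipliers, int imageSize)
{
    var levels = new (int Channels, int Resolution)[channelMultipliers.Length];
    int current = imageSize;
    for (int i = 0; i < channelMultipliers.Length; i++)
    {
        levels[i] = (baseChannels * channelMultipliers[i], current);
        if (i < channelMultipliers.Length - 1) current /= 2; // downsample between levels
    }
    return levels;
}

// Under the same assumption the overall factor is 2^(levels - 1).
static int DownsampleFactorFor(int[] channelMultipliers) =>
    1 << (channelMultipliers.Length - 1);

int[] multipliers = { 1, 2, 4, 4 }; // documented default
foreach (var (channels, resolution) in LevelSchedule(128, multipliers, 512))
    Console.WriteLine($"{channels} channels at {resolution}x{resolution}");
Console.WriteLine($"Downsample factor: {DownsampleFactorFor(multipliers)}"); // Downsample factor: 8
```

With a 512x512 input it lists 128, 256, 512, and 512 channels at 512, 256, 128, and 64 pixels per side.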

Properties

DownsampleFactor

Gets the spatial downsampling factor.

public override int DownsampleFactor { get; }

Property Value

int

Remarks

The factor by which the VAE reduces spatial dimensions. Stable Diffusion uses 8x downsampling, so a 512x512 image becomes a 64x64 latent.
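The shape relationship between Encode and Decode can be expressed as a pair of helpers. In this C# sketch the divisibility check is an assumption that guarantees encoding followed by decoding restores the original size. This page does not state how StandardVAE handles dimensions that are not multiples of DownsampleFactor. `LatentShapeFor` and `ImageShapeFor` are hypothetical.

```csharp
using System;

// Hypothetical: [batch, channels, H, W] -> [batch, latentChannels, H/f, W/f].
// Rejecting non-multiples is an assumption made for this sketch.
static int[] LatentShapeFor(int[] imageShape, int latentChannels, int factor)
{
    int batch = imageShape[0], height = imageShape[2], width = imageShape[3];
    if (height % factor != 0 || width % factor != 0)
        throw new ArgumentException($"Height and width must be multiples of {factor}.");
    return new[] { batch, latentChannels, height / factor, width / factor };
}

// Hypothetical: [batch, latentChannels, h, w] -> [batch, imageChannels, h*f, w*f].
static int[] ImageShapeFor(int[] latentShape, int imageChannels, int factor) =>
    new[] { latentShape[0], imageChannels, latentShape[2] * factor, latentShape[3] * factor };

int[] latent = LatentShapeFor(new[] { 1, 3, 512, 768 }, latentChannels: 4, factor: 8);
Console.WriteLine(string.Join("x", latent));                      // 1x4x64x96
Console.WriteLine(string.Join("x", ImageShapeFor(latent, 3, 8))); // 1x3x512x768
```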

InputChannels

Gets the number of input channels (image channels).

public override int InputChannels { get; }

Property Value

int

Remarks

Typically 3 for RGB images. Could be 1 for grayscale or 4 for RGBA.

LatentChannels

Gets the number of latent channels.

public override int LatentChannels { get; }

Property Value

int

Remarks

Standard Stable Diffusion VAEs use 4 latent channels. Some newer VAEs use more (e.g., 16 channels in the Stable Diffusion 3 VAE).

LatentScaleFactor

Gets the scale factor for latent values.

public override double LatentScaleFactor { get; }

Property Value

double

Remarks

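A normalization factor applied to latent values: latents are typically multiplied by it before diffusion and divided by it before decoding. For Stable Diffusion, the value is 0.18215, chosen so that the scaled latents have approximately unit variance.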
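Scaling is a multiply before diffusion and a divide before decoding. The C# sketch below also estimates a factor as the reciprocal of the latents' standard deviation. That is the usual explanation for 0.18215 (about 1 / 5.49), but treat the estimator as an assumption, not as how AiDotNet derives the constant. All names are hypothetical.

```csharp
using System;
using System.Linq;

const double StableDiffusionScale = 0.18215;

// Multiply raw VAE latents by the factor before handing them to diffusion.
static double[] ScaleForDiffusion(double[] latent, double scale) =>
    latent.Select(v => v * scale).ToArray();

// Divide by the same factor before decoding, undoing the scaling.
static double[] UnscaleForDecoding(double[] latent, double scale) =>
    latent.Select(v => v / scale).ToArray();

// Assumed estimator: 1 / (population standard deviation of raw latents),
// so that the scaled latents have unit variance.
static double EstimateScale(double[] rawLatents)
{
    double mean = rawLatents.Average();
    double variance = rawLatents.Select(v => (v - mean) * (v - mean)).Average();
    return 1.0 / Math.Sqrt(variance);
}

double[] raw = { 5.49, -5.49, 5.49, -5.49 }; // toy latents with standard deviation 5.49
Console.WriteLine($"Estimated factor: {EstimateScale(raw):F5}");

double[] forDiffusion = ScaleForDiffusion(raw, StableDiffusionScale);
double[] forDecoding = UnscaleForDecoding(forDiffusion, StableDiffusionScale);
double maxError = forDecoding.Zip(raw, (a, b) => Math.Abs(a - b)).Max();
Console.WriteLine($"Max round-trip error: {maxError}");
```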

ParameterCount

Gets the number of parameters in the model.

public override int ParameterCount { get; }

Property Value

int

Remarks

This property returns the total count of trainable parameters in the model. It's useful for understanding model complexity and memory requirements.
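ParameterCount is the sum over all layers. As a rough guide to where parameters come from, this C# sketch counts them for a 2D convolution and a GroupNorm layer using the standard formulas. The example layer, a 3x3 convolution from 3 to 128 channels matching the default inputChannels and baseChannels, is an assumption about the encoder's first layer. The helper names are hypothetical.

```csharp
using System;

// Parameters in a 2D convolution: one kernelSize x kernelSize filter per
// (input channel, output channel) pair, plus one bias per output channel.
static long ConvParameters(int inChannels, int outChannels, int kernelSize, bool bias = true) =>
    (long)inChannels * outChannels * kernelSize * kernelSize + (bias ? outChannels : 0);

// GroupNorm has a learnable scale and shift per channel.
static long GroupNormParameters(int channels) => 2L * channels;

// Memory for the weights alone, assuming 4-byte float parameters.
static double ParameterMegabytes(long parameterCount) =>
    parameterCount * 4 / (1024.0 * 1024.0);

long stem = ConvParameters(3, 128, 3);
long norm = GroupNormParameters(128);
Console.WriteLine($"3x3 conv, 3 -> 128 channels: {stem} parameters"); // ... 3584 parameters
Console.WriteLine($"GroupNorm over 128 channels: {norm} parameters"); // ... 256 parameters
Console.WriteLine($"Weights alone at 4 bytes each: {ParameterMegabytes(stem + norm):F4} MB");
```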

SupportsSlicing

Gets whether this VAE uses slicing for sequential processing.

public override bool SupportsSlicing { get; }

Property Value

bool

Remarks

Slicing processes the batch one sample at a time to reduce peak memory usage, trading speed for memory efficiency.
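Slicing amounts to a sequential loop over the batch. This generic C# sketch shows the pattern. `ProcessSliced` is hypothetical, and a stand-in per-sample function replaces the real encoder or decoder.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical: apply a per-sample function to each batch element in turn,
// so only one sample's intermediate data needs to exist at a time.
static List<TOut> ProcessSliced<TIn, TOut>(IReadOnlyList<TIn> batch, Func<TIn, TOut> processOne)
{
    var results = new List<TOut>(batch.Count);
    foreach (TIn sample in batch)
        results.Add(processOne(sample)); // sequential: lower peak memory, no batch parallelism
    return results;
}

var inputBatch = new List<float[]> { new[] { 1f, 2f }, new[] { 3f, 4f } };
var doubled = ProcessSliced(inputBatch, s => Array.ConvertAll(s, v => v * 2f));
Console.WriteLine(string.Join(" | ", doubled.ConvertAll(s => string.Join(",", s)))); // 2,4 | 6,8
```

Peak memory is governed by one sample's activations instead of the whole batch's, at the cost of batch-level parallelism.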

SupportsTiling

Gets whether this VAE uses tiling for memory-efficient encoding/decoding.

public override bool SupportsTiling { get; }

Property Value

bool

Remarks

Tiling processes the image in overlapping patches to reduce memory usage when handling large images. Useful for high-resolution generation.
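The core of tiling is choosing overlapping tile positions that cover the image. This C# sketch computes start offsets along one axis. The tile size, the overlap, and the `TileStarts` name are hypothetical. Blending the overlapping regions back together, for example with linearly ramped weights, is left out.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical: start offsets of overlapping tiles along one axis.
// Consecutive tiles share at least `overlap` pixels; the last tile is pinned
// to the edge, so it may overlap its neighbor by more than `overlap`.
static List<int> TileStarts(int size, int tileSize, int overlap)
{
    if (overlap >= tileSize) throw new ArgumentException("overlap must be smaller than tileSize");
    var starts = new List<int>();
    if (size <= tileSize) { starts.Add(0); return starts; }
    int stride = tileSize - overlap;
    for (int start = 0; start + tileSize < size; start += stride)
        starts.Add(start);
    starts.Add(size - tileSize); // final tile flush with the edge
    return starts;
}

// A 1024-pixel side with 512-pixel tiles overlapping by 64 pixels.
Console.WriteLine(string.Join(", ", TileStarts(1024, 512, 64))); // 0, 448, 512
```

Applying the same offsets to rows and columns gives the 2D tile grid.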

Methods

Clone()

Creates a deep copy of the VAE model.

public override IVAEModel<T> Clone()

Returns

IVAEModel<T>

A new instance with the same parameters.

ComputeVAELoss(Tensor<T>, Tensor<T>, double)

Computes the VAE loss (reconstruction + KL divergence).

public T ComputeVAELoss(Tensor<T> image, Tensor<T> reconstruction, double klWeight = 1E-06)

Parameters

image Tensor<T>

Original input image.

reconstruction Tensor<T>

Reconstructed image.

klWeight double

Weight for KL divergence term (default: 1e-6).

Returns

T

Combined loss value: the reconstruction loss plus klWeight times the KL divergence.
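A common formulation of this loss is mean squared error plus klWeight times the KL divergence between the encoder's Gaussian and a standard normal. The C# sketch below is not AiDotNet's implementation. It passes the mean and log variance explicitly, whereas ComputeVAELoss takes only the image and reconstruction. Averaging over elements, rather than summing, is also an assumption. All names are hypothetical.

```csharp
using System;

// KL divergence between N(mean, exp(logVar)) and N(0, 1), averaged over elements:
// KL = -0.5 * mean(1 + logVar - mean^2 - exp(logVar)).
static double KlDivergence(double[] mean, double[] logVar)
{
    double sum = 0;
    for (int i = 0; i < mean.Length; i++)
        sum += 1 + logVar[i] - mean[i] * mean[i] - Math.Exp(logVar[i]);
    return -0.5 * sum / mean.Length;
}

// Mean squared reconstruction error.
static double Mse(double[] image, double[] reconstruction)
{
    double sum = 0;
    for (int i = 0; i < image.Length; i++)
    {
        double d = image[i] - reconstruction[i];
        sum += d * d;
    }
    return sum / image.Length;
}

// Combined objective: reconstruction + klWeight * KL (default weight 1e-6).
static double VaeLoss(double[] image, double[] reconstruction,
    double[] mean, double[] logVar, double klWeight = 1e-6) =>
    Mse(image, reconstruction) + klWeight * KlDivergence(mean, logVar);

double loss = VaeLoss(new[] { 0.5, -0.5 }, new[] { 0.4, -0.5 },
    mean: new[] { 0.0, 1.0 }, logVar: new[] { 0.0, 0.0 });
Console.WriteLine(loss);
```

The small default klWeight keeps the reconstruction term dominant. The KL term only nudges the latent distribution toward a standard normal.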

Decode(Tensor<T>)

Decodes a latent representation back to image space.

public override Tensor<T> Decode(Tensor<T> latent)

Parameters

latent Tensor<T>

The latent tensor [batch, latentChannels, latentHeight, latentWidth].

Returns

Tensor<T>

The decoded image [batch, channels, latentHeight * downFactor, latentWidth * downFactor].

Remarks

For Beginners: This decompresses the latent back to an image:

  • Input: Small latent (64x64x4)
  • Output: Full-size image (512x512x3)
  • The image looks like the original, but with minor differences due to compression

DecodeFromDiffusion(Tensor<T>)

Decodes a diffusion latent back to image space.

public Tensor<T> DecodeFromDiffusion(Tensor<T> latent)

Parameters

latent Tensor<T>

The latent produced by diffusion, already scaled by LatentScaleFactor.

Returns

Tensor<T>

The decoded image.

DeepCopy()

Creates a deep copy of this object.

public override IFullModel<T, Tensor<T>, Tensor<T>> DeepCopy()

Returns

IFullModel<T, Tensor<T>, Tensor<T>>

Encode(Tensor<T>, bool)

Encodes an image into the latent space.

public override Tensor<T> Encode(Tensor<T> image, bool sampleMode = true)

Parameters

image Tensor<T>

The input image tensor [batch, channels, height, width].

sampleMode bool

If true, samples from the latent distribution. If false, returns the mean.

Returns

Tensor<T>

The latent representation [batch, latentChannels, height/downFactor, width/downFactor].

Remarks

The VAE encoder outputs a distribution (mean and log variance). When sampleMode is true, we sample from this distribution using the reparameterization trick. When false, we just return the mean for deterministic encoding.

For Beginners: This compresses the image:

  • Input: Full-size image (512x512x3)
  • Output: Small latent representation (64x64x4)
  • The latent keeps the visually important information in compressed form; some fine detail is lost
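The reparameterization trick writes a sample as the mean plus the standard deviation times standard-normal noise. This lets gradients flow through the mean and log variance. The C# sketch below uses plain arrays in place of Tensor<T> and a Box-Muller transform for the noise. `SampleLatent` and `StandardNormal` are hypothetical helpers.

```csharp
using System;

// Reparameterization trick: z = mean + exp(0.5 * logVar) * eps, with eps ~ N(0, 1).
// With sampleMode == false the mean is returned unchanged (deterministic encoding).
static double[] SampleLatent(double[] mean, double[] logVar, bool sampleMode, Random rng)
{
    var latent = new double[mean.Length];
    for (int i = 0; i < mean.Length; i++)
    {
        if (!sampleMode)
        {
            latent[i] = mean[i];
            continue;
        }
        double std = Math.Exp(0.5 * logVar[i]); // log variance -> standard deviation
        latent[i] = mean[i] + std * StandardNormal(rng);
    }
    return latent;
}

// Box-Muller transform: two uniform samples -> one standard normal sample.
static double StandardNormal(Random rng)
{
    double u1 = 1.0 - rng.NextDouble(); // in (0, 1], so Math.Log(u1) is finite
    double u2 = rng.NextDouble();
    return Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Cos(2.0 * Math.PI * u2);
}

var random = new Random(42);
double[] mu = { 0.0, 1.0 };
double[] logVariance = { 0.0, -2.0 };
double[] sampled = SampleLatent(mu, logVariance, sampleMode: true, random);
double[] deterministic = SampleLatent(mu, logVariance, sampleMode: false, random);
Console.WriteLine(string.Join(", ", deterministic)); // 0, 1
```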

EncodeForDiffusion(Tensor<T>, bool)

Encodes an image and applies latent scaling for use in diffusion.

public Tensor<T> EncodeForDiffusion(Tensor<T> image, bool sampleMode = true)

Parameters

image Tensor<T>

The input image tensor.

sampleMode bool

If true, samples from the latent distribution. If false, uses the mean.

Returns

Tensor<T>

Latent representation scaled by LatentScaleFactor.

EncodeWithDistribution(Tensor<T>)

Encodes and returns both mean and log variance (for training).

public override (Tensor<T> Mean, Tensor<T> LogVariance) EncodeWithDistribution(Tensor<T> image)

Parameters

image Tensor<T>

The input image tensor [batch, channels, height, width].

Returns

(Tensor<T> Mean, Tensor<T> LogVariance)

Tuple of (mean, logVariance) tensors.

Remarks

Used during VAE training where we need both the mean and variance for computing the KL divergence loss.

GetParameters()

Gets the parameters that can be optimized.

public override Vector<T> GetParameters()

Returns

Vector<T>

SetParameters(Vector<T>)

Sets the model parameters.

public override void SetParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

The parameter vector to set.

Remarks

This method allows direct modification of the model's internal parameters, which is useful for optimization algorithms that update parameters iteratively. If the length of parameters does not match ParameterCount, an ArgumentException is thrown.

Exceptions

ArgumentException

Thrown when the length of parameters does not match ParameterCount.
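The documented contract, rejecting a vector whose length differs from ParameterCount, can be sketched over a flat array. In this C# sketch the backing array, the copy returned by `GetParameters`, and the update step are illustrative assumptions, not AiDotNet's implementation.

```csharp
using System;

// Hypothetical flat parameter store standing in for the model's weights.
double[] weights = new double[4];

// Returning a copy (an assumption here) keeps callers from mutating state directly.
double[] GetParameters() => (double[])weights.Clone();

// Mirrors the documented rule: a length mismatch throws ArgumentException.
void SetParameters(double[] parameters)
{
    if (parameters.Length != weights.Length)
        throw new ArgumentException(
            $"Expected {weights.Length} parameters but got {parameters.Length}.", nameof(parameters));
    Array.Copy(parameters, weights, weights.Length);
}

// An optimizer-style update: read, modify, write back.
double[] current = GetParameters();
for (int i = 0; i < current.Length; i++) current[i] -= 0.1;
SetParameters(current);
Console.WriteLine(string.Join(", ", GetParameters()));
```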