Class AutoencoderKL<T>

Namespace
AiDotNet.Diffusion.VAE
Assembly
AiDotNet.dll

KL-regularized Variational Autoencoder for latent diffusion models.

public class AutoencoderKL<T> : VAEModelBase<T>, IVAEModel<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>

Type Parameters

T

The numeric type used for calculations.

Inheritance
object
VAEModelBase<T>
AutoencoderKL<T>
Implements
IVAEModel<T>
IFullModel<T, Tensor<T>, Tensor<T>>
IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>
IModelSerializer
ICheckpointableModel
IParameterizable<T, Tensor<T>, Tensor<T>>
IFeatureAware
IFeatureImportance<T>
ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>
IGradientComputable<T, Tensor<T>, Tensor<T>>
IJitCompilable<T>

Remarks

AutoencoderKL is the standard VAE architecture used in Stable Diffusion and other latent diffusion models. It compresses high-resolution images to a compact latent representation while maintaining perceptual quality through KL-regularization.

For Beginners: AutoencoderKL is the "image compressor" used by Stable Diffusion.

Why use KL-regularized VAE?

  1. Compression: 512x512x3 image -> 64x64x4 latent (48x smaller!)
  2. KL-regularization: Keeps the latent space well-organized (Gaussian distribution)
  3. This organization makes diffusion work better in latent space

The "KL" in AutoencoderKL refers to Kullback-Leibler divergence, which measures how different the encoder's output distribution is from a standard normal. By minimizing KL divergence, we ensure the latent space is smooth and continuous.

Architecture:

    Image (512x512x3)
          │
          ├─→ VAEEncoder ─→ [mean, logvar] (64x64x8)
          │                        │
          │               Sample using reparameterization
          │                        │
          │                        ↓
          │              Latent z (64x64x4)
          │                        │
          │                 [Scale by 0.18215]
          │                        │
          │                        ↓
          │              Scaled latent (for diffusion)
          │                        │
          │                 [Unscale by 1/0.18215]
          │                        │
          │                        ↓
          │              Latent z (64x64x4)
          │                        │
          └────────────────→ VAEDecoder
                                   │
                                   ↓
                         Reconstructed Image (512x512x3)
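
Example (illustrative sketch): the snippet below walks through the typical encode -> diffuse -> decode round trip in a latent diffusion pipeline. The image-loading and diffusion calls (LoadImageTensor, RunDiffusion) are hypothetical placeholders; only the AutoencoderKL members come from this page.

// Create a VAE with the default Stable Diffusion configuration.
var vae = new AutoencoderKL<float>();

// Image tensor in [batch, channels, height, width] layout, values in [-1, 1].
Tensor<float> image = LoadImageTensor("input.png");    // hypothetical helper

// Encode and apply the 0.18215 latent scale for diffusion.
Tensor<float> latent = vae.EncodeForDiffusion(image);

// Run the diffusion model in latent space (outside the scope of this class).
Tensor<float> denoised = RunDiffusion(latent);          // hypothetical helper

// Unscale and decode back to an image in [-1, 1].
Tensor<float> output = vae.DecodeFromDiffusion(denoised);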

Constructors

AutoencoderKL(int, int, int, int[]?, int, int, double?, int, ILossFunction<T>?, int?)

Initializes a new instance of the AutoencoderKL class with default Stable Diffusion configuration.

public AutoencoderKL(int inputChannels = 3, int latentChannels = 4, int baseChannels = 128, int[]? channelMults = null, int numResBlocks = 2, int numGroups = 32, double? latentScaleFactor = null, int inputSpatialSize = 512, ILossFunction<T>? lossFunction = null, int? seed = null)

Parameters

inputChannels int

Number of input image channels (default: 3 for RGB).

latentChannels int

Number of latent channels (default: 4).

baseChannels int

Base channel count (default: 128).

channelMults int[]

Channel multipliers per level (default: [1, 2, 4, 4]).

numResBlocks int

Number of residual blocks per level (default: 2).

numGroups int

Number of groups for GroupNorm (default: 32).

latentScaleFactor double?

Scale factor for latents (default: 0.18215).

inputSpatialSize int

Spatial size of input images (default: 512).

lossFunction ILossFunction<T>

Optional loss function (default: MSE).

seed int?

Optional random seed for reproducibility.

Remarks

For Beginners: Create an AutoencoderKL with sensible defaults for most use cases.

Default configuration matches Stable Diffusion v1.5/v2.1 VAE:

  • 3 RGB channels in/out
  • 4 latent channels
  • 8x spatial downsampling (512x512 -> 64x64)
  • Channel progression: 128 -> 256 -> 512 -> 512

For custom configurations:

  • Smaller latentChannels = more compression, potentially lower quality
  • Larger baseChannels = more capacity, but slower and more memory
  • More channelMults levels = more downsampling, smaller latents
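
Example (illustrative sketch): the snippet below constructs the default configuration and a smaller, assumed custom variant to show how the parameters combine; the specific values are illustrative, not recommendations.

// Default configuration (matches the Stable Diffusion v1.5/v2.1 VAE).
var vae = new AutoencoderKL<float>();

// Smaller assumed configuration: fewer base channels and only three levels,
// trading quality for speed and memory in quick experiments.
var smallVae = new AutoencoderKL<float>(
    inputChannels: 3,
    latentChannels: 4,
    baseChannels: 64,
    channelMults: new[] { 1, 2, 4 },
    numResBlocks: 2,
    inputSpatialSize: 256,
    seed: 42);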

Properties

Decoder

Gets the decoder component for direct access.

public VAEDecoder<T> Decoder { get; }

Property Value

VAEDecoder<T>

DownsampleFactor

Gets the spatial downsampling factor.

public override int DownsampleFactor { get; }

Property Value

int

Remarks

The factor by which the VAE reduces spatial dimensions. Stable Diffusion uses 8x downsampling, so a 512x512 image becomes 64x64 latents.

Encoder

Gets the encoder component for direct access.

public VAEEncoder<T> Encoder { get; }

Property Value

VAEEncoder<T>

InputChannels

Gets the number of input channels (image channels).

public override int InputChannels { get; }

Property Value

int

Remarks

Typically 3 for RGB images. Could be 1 for grayscale or 4 for RGBA.

LatentChannels

Gets the number of latent channels.

public override int LatentChannels { get; }

Property Value

int

Remarks

Standard Stable Diffusion VAEs use 4 latent channels. Some newer VAEs may use different values (e.g., 16 for certain architectures).

LatentScaleFactor

Gets the scale factor for latent values.

public override double LatentScaleFactor { get; }

Property Value

double

Remarks

A normalization factor applied to latent values. For Stable Diffusion, this is 0.18215, which normalizes the latent distribution to unit variance.

ParameterCount

Gets the number of parameters in the model.

public override int ParameterCount { get; }

Property Value

int

Remarks

This property returns the total count of trainable parameters in the model. It's useful for understanding model complexity and memory requirements.

SupportsSlicing

Gets whether this VAE uses slicing for sequential processing.

public override bool SupportsSlicing { get; }

Property Value

bool

Remarks

Slicing processes the batch one sample at a time to reduce memory. Trades speed for memory efficiency.

SupportsTiling

Gets whether this VAE uses tiling for memory-efficient encoding/decoding.

public override bool SupportsTiling { get; }

Property Value

bool

Remarks

Tiling processes the image in overlapping patches to reduce memory usage when handling large images. Useful for high-resolution generation.

Methods

Clone()

Creates a deep copy of the VAE model.

public override IVAEModel<T> Clone()

Returns

IVAEModel<T>

A new instance with the same parameters.

ComputeVAELoss(Tensor<T>, Tensor<T>, double)

Computes the VAE loss (reconstruction + KL divergence).

public T ComputeVAELoss(Tensor<T> image, Tensor<T> reconstruction, double klWeight = 1E-06)

Parameters

image Tensor<T>

Original input image.

reconstruction Tensor<T>

Reconstructed image from Forward().

klWeight double

Weight for KL divergence term (default: 1e-6).

Returns

T

Combined loss value.

Remarks

For Beginners: The VAE loss has two parts:

  1. Reconstruction loss: How different is the output from the input?

    • Uses MSE (mean squared error) by default
    • Lower = better reconstruction
  2. KL divergence loss: How different is the latent distribution from N(0,1)?

    • Regularizes the latent space to be smooth
    • Lower = more organized latent space

The klWeight controls the trade-off:

  • Higher klWeight = more regularized latent space, potentially blurrier reconstructions
  • Lower klWeight = sharper reconstructions, but less organized latent space

The default of 1e-6 is very small because reconstruction quality is prioritized for diffusion applications.
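
Example (illustrative sketch): a single loss computation using Forward and ComputeVAELoss; LoadImageTensor is a hypothetical helper for producing the input tensor.

var vae = new AutoencoderKL<float>();
Tensor<float> image = LoadImageTensor("photo.png");    // hypothetical helper

// Full pass: encode -> sample -> decode.
Tensor<float> reconstruction = vae.Forward(image);

// Reconstruction (MSE) + KL divergence with the default small KL weight.
float loss = vae.ComputeVAELoss(image, reconstruction, klWeight: 1e-6);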

Decode(Tensor<T>)

Decodes a latent representation back to image space.

public override Tensor<T> Decode(Tensor<T> latent)

Parameters

latent Tensor<T>

The latent tensor [batch, latentChannels, latentHeight, latentWidth].

Returns

Tensor<T>

The decoded image [batch, channels, height*downFactor, width*downFactor].

Remarks

For Beginners: This decompresses the latent back to an image:

  • Input: Small latent (64x64x4)
  • Output: Full-size image (512x512x3)
  • The image looks like the original but with minor differences due to compression

DecodeFromDiffusion(Tensor<T>)

Decodes a diffusion latent back to image space.

public Tensor<T> DecodeFromDiffusion(Tensor<T> latent)

Parameters

latent Tensor<T>

The scaled latent from diffusion.

Returns

Tensor<T>

The decoded image.

Remarks

For Beginners: Use this method to convert diffusion output to images.

Steps:

  1. Unscale the latent (divide by scale factor)
  2. Decode through the VAE decoder
  3. Result is an image in [-1, 1] range

To display/save, convert from [-1, 1] to [0, 255]: pixel = (value + 1) * 127.5
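
Example (illustrative sketch): decoding a sampler's output and mapping it to displayable pixel values. Here latentFromDiffusion is assumed to come from a diffusion sampler, and element-wise access on Tensor<T> is assumed for the pixel conversion.

var vae = AutoencoderKL<float>.StableDiffusionV1();

// 'latentFromDiffusion' is the scaled latent produced by the sampler.
Tensor<float> image = vae.DecodeFromDiffusion(latentFromDiffusion);

// To display or save, map each value from [-1, 1] to [0, 255]:
//   byte pixel = (byte)Math.Clamp((value + 1f) * 127.5f, 0f, 255f);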

DeepCopy()

Creates a deep copy of this object.

public override IFullModel<T, Tensor<T>, Tensor<T>> DeepCopy()

Returns

IFullModel<T, Tensor<T>, Tensor<T>>

Encode(Tensor<T>, bool)

Encodes an image into the latent space.

public override Tensor<T> Encode(Tensor<T> image, bool sampleMode = true)

Parameters

image Tensor<T>

The input image tensor [batch, channels, height, width].

sampleMode bool

If true, samples from the latent distribution. If false, returns the mean.

Returns

Tensor<T>

The latent representation [batch, latentChannels, height/downFactor, width/downFactor].

Remarks

The VAE encoder outputs a distribution (mean and log variance). When sampleMode is true, we sample from this distribution using the reparameterization trick. When false, we just return the mean for deterministic encoding.

For Beginners: This compresses the image:

  • Input: Full-size image (512x512x3)
  • Output: Small latent representation (64x64x4)
  • The latent contains all the important information in a compressed form
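
Example (illustrative sketch): the two encoding modes side by side; the image tensor is assumed to already be in [batch, channels, height, width] layout.

// Stochastic encoding (default): samples via the reparameterization trick.
Tensor<float> sampledLatent = vae.Encode(image);

// Deterministic encoding: returns the distribution mean, so the same image
// always maps to the same latent.
Tensor<float> meanLatent = vae.Encode(image, sampleMode: false);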

EncodeForDiffusion(Tensor<T>, bool)

Encodes an image and applies latent scaling for use in diffusion.

public Tensor<T> EncodeForDiffusion(Tensor<T> image, bool sampleMode = true)

Parameters

image Tensor<T>

The input image tensor.

sampleMode bool

Whether to sample from the distribution (default: true).

Returns

Tensor<T>

Scaled latent representation ready for diffusion.

Remarks

For Beginners: Use this method when preparing images for diffusion.

The latent scaling is important because:

  1. It normalizes the latent distribution to unit variance
  2. This helps the diffusion model work with consistent noise levels
  3. The scale factor (0.18215) was empirically determined for SD VAE
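
Example (illustrative sketch): EncodeForDiffusion is conceptually Encode followed by multiplication with LatentScaleFactor; the scaling line is shown as a comment because the tensor arithmetic API is not documented on this page.

// Scaled latent, ready to feed into the diffusion model.
Tensor<float> scaled = vae.EncodeForDiffusion(image);

// Conceptually equivalent two-step version:
Tensor<float> raw = vae.Encode(image);
// scaled ≈ raw * vae.LatentScaleFactor   (0.18215 for the SD VAE)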

EncodeWithDistribution(Tensor<T>)

Encodes and returns both mean and log variance (for training).

public override (Tensor<T> Mean, Tensor<T> LogVariance) EncodeWithDistribution(Tensor<T> image)

Parameters

image Tensor<T>

The input image tensor [batch, channels, height, width].

Returns

(Tensor<T> Mean, Tensor<T> LogVariance)

Tuple of (mean, logVariance) tensors.

Remarks

Used during VAE training where we need both the mean and variance for computing the KL divergence loss.
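
For a diagonal Gaussian posterior against a standard normal prior, the KL term is KL = -0.5 * sum(1 + logvar - mean^2 - exp(logvar)). Example (illustrative sketch): the element-wise tensor operations in the comment below are assumed and not part of this page.

// Distribution parameters needed for the KL term during training.
var (mean, logVar) = vae.EncodeWithDistribution(image);

// KL against N(0, 1), written as pseudocode:
//   kl = -0.5 * Sum(1 + logVar - mean^2 - Exp(logVar))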

Forward(Tensor<T>)

Performs a full forward pass: encode -> sample -> decode.

public Tensor<T> Forward(Tensor<T> image)

Parameters

image Tensor<T>

Input image.

Returns

Tensor<T>

Reconstructed image.

GetParameters()

Gets the parameters that can be optimized.

public override Vector<T> GetParameters()

Returns

Vector<T>

Lightweight()

Creates a lightweight AutoencoderKL for testing/experimentation.

public static AutoencoderKL<T> Lightweight()

Returns

AutoencoderKL<T>

LoadState(Stream)

Loads the model's state (parameters and configuration) from a stream.

public override void LoadState(Stream stream)

Parameters

stream Stream

The stream to read the model state from.

Remarks

This method deserializes model state that was previously saved with SaveState, restoring all parameters and configuration to recreate the saved model state.

For Beginners: This is like loading a saved game.

When you call LoadState:

  • All the parameters are read from the stream
  • The model is configured to match the saved architecture
  • The model is restored to exactly the state it was in when SaveState was called

After loading, the model can make predictions using the restored parameters.

Stream Handling:

  • The stream position will be advanced by the number of bytes read
  • The stream is not closed (caller must dispose)
  • Stream data must match the format written by SaveState

Versioning: Implementations should consider:

  • Including a format version number in the serialized data
  • Validating compatibility before deserialization
  • Providing migration paths for old formats when possible

Usage:

// Load from file
using var stream = File.OpenRead("model.bin");
model.LoadState(stream);

Important: The stream must contain state data saved by SaveState from a compatible model (same architecture and numeric type).

Exceptions

ArgumentNullException

Thrown when stream is null.

ArgumentException

Thrown when stream is not readable or contains invalid data.

InvalidOperationException

Thrown when deserialization fails or data is incompatible with model architecture.

ResetState()

Resets the internal state of encoder and decoder.

public void ResetState()

SDXL()

Creates an AutoencoderKL matching SDXL configuration.

public static AutoencoderKL<T> SDXL()

Returns

AutoencoderKL<T>

SaveState(Stream)

Saves the model's current state (parameters and configuration) to a stream.

public override void SaveState(Stream stream)

Parameters

stream Stream

The stream to write the model state to.

Remarks

This method serializes all the information needed to recreate the model's current state, including trained parameters, layer configurations, and any internal state variables.

For Beginners: This is like creating a snapshot of your trained model.

When you call SaveState:

  • All the learned parameters (weights and biases) are written to the stream
  • The model's architecture information is saved
  • Any other internal state (like normalization statistics) is preserved

You can later use LoadState to restore the model to this exact state.

Stream Handling:

  • The stream position will be advanced by the number of bytes written
  • The stream is flushed but not closed (caller must dispose)
  • For file-based persistence, wrap in File.Create/FileStream

Usage:

// Save to file
using var stream = File.Create("model.bin");
model.SaveState(stream);

Exceptions

ArgumentNullException

Thrown when stream is null.

ArgumentException

Thrown when stream is not writable.

InvalidOperationException

Thrown when model state cannot be serialized (e.g., uninitialized model).

SetParameters(Vector<T>)

Sets the model parameters.

public override void SetParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

The parameter vector to set.

Remarks

This method allows direct modification of the model's internal parameters, which is useful for optimization algorithms that update parameters iteratively. If the length of parameters does not match ParameterCount, an ArgumentException is thrown.
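
Example (illustrative sketch): the typical optimizer round trip; OptimizerStep is a hypothetical routine that must return a vector of length ParameterCount.

// Read the current parameters, update them externally, write them back.
Vector<float> parameters = vae.GetParameters();
Vector<float> updated = OptimizerStep(parameters);     // hypothetical helper
vae.SetParameters(updated);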

Exceptions

ArgumentException

Thrown when the length of parameters does not match ParameterCount.

StableDiffusionV1()

Creates a default AutoencoderKL matching Stable Diffusion v1.5 configuration.

public static AutoencoderKL<T> StableDiffusionV1()

Returns

AutoencoderKL<T>

Train(Tensor<T>, Tensor<T>)

Trains the VAE on a single image.

public override void Train(Tensor<T> input, Tensor<T> expectedOutput)

Parameters

input Tensor<T>

Input image to reconstruct.

expectedOutput Tensor<T>

Target output (usually same as input for VAE).
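
Example (illustrative sketch): a minimal reconstruction-training loop; trainingImages is a hypothetical collection of image tensors, and for an autoencoder the expected output is the input itself.

// Train the VAE to reconstruct its inputs.
var vae = AutoencoderKL<float>.Lightweight();

foreach (Tensor<float> image in trainingImages)    // hypothetical collection
{
    // For a VAE, the target is the input image itself.
    vae.Train(image, image);
}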