Interface IVAEModel<T>
- Namespace
- AiDotNet.Interfaces
- Assembly
- AiDotNet.dll
Interface for Variational Autoencoder (VAE) models used in latent diffusion.
public interface IVAEModel<T> : IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>
Type Parameters
T: The numeric type used for calculations.
- Inherited Members
- Extension Methods
Remarks
VAEs are used in latent diffusion models to compress images into a lower-dimensional latent space where the diffusion process operates. This makes training and generation much more efficient than operating in pixel space.
For Beginners: A VAE is like a very smart image compressor and decompressor.
How it works:
- Encoder: Takes a full-size image (e.g., 512x512x3) and compresses it to a small latent (e.g., 64x64x4)
- Decoder: Takes the small latent and reconstructs a full-size image
- The compression is lossy but learned to preserve important visual information
Why use a VAE in diffusion?
- Full images are huge (512x512x3 = 786,432 values)
- Latents are small (64x64x4 = 16,384 values) - 48x smaller!
- Diffusion in latent space is much faster
- Quality remains high because the VAE learns what matters
Different VAE types:
- Standard VAE: Original Stable Diffusion VAE, 4 latent channels
- Tiny VAE: Faster but lower quality, good for previews
- Temporal VAE: Video-aware VAE that handles frame consistency
This interface extends IFullModel<T, TInput, TOutput> to provide all standard model capabilities (training, saving, loading, gradients, checkpointing, etc.).
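The compression arithmetic above can be checked directly. A minimal sketch in plain Python (illustrative only, not the AiDotNet API), using Stable Diffusion's standard 8x downsampling and 4 latent channels:

```python
# Shapes used by a standard Stable Diffusion VAE.
image_h, image_w, image_c = 512, 512, 3   # full-size RGB image
down, latent_c = 8, 4                     # DownsampleFactor, LatentChannels

pixel_values = image_h * image_w * image_c                        # 786,432
latent_values = (image_h // down) * (image_w // down) * latent_c  # 16,384

print(pixel_values, latent_values, pixel_values // latent_values)  # 786432 16384 48
```

This is where the "48x smaller" figure comes from: the spatial area shrinks by 8x8 = 64x, while the channel count grows from 3 to 4, giving 64 * 3 / 4 = 48.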
Properties
DownsampleFactor
Gets the spatial downsampling factor.
int DownsampleFactor { get; }
Property Value
Remarks
The factor by which the VAE reduces spatial dimensions. Stable Diffusion uses 8x downsampling, so a 512x512 image becomes 64x64 latents.
InputChannels
Gets the number of input channels (image channels).
int InputChannels { get; }
Property Value
Remarks
Typically 3 for RGB images. Could be 1 for grayscale or 4 for RGBA.
LatentChannels
Gets the number of latent channels.
int LatentChannels { get; }
Property Value
Remarks
Standard Stable Diffusion VAEs use 4 latent channels. Some newer VAEs may use different values (e.g., 16 for certain architectures).
LatentScaleFactor
Gets the scale factor for latent values.
double LatentScaleFactor { get; }
Property Value
Remarks
A normalization factor applied to latent values. For Stable Diffusion, this is 0.18215, which normalizes the latent distribution to unit variance.
SupportsSlicing
Gets whether this VAE uses slicing for sequential processing.
bool SupportsSlicing { get; }
Property Value
Remarks
Slicing processes the batch one sample at a time to reduce memory. Trades speed for memory efficiency.
SupportsTiling
Gets whether this VAE uses tiling for memory-efficient encoding/decoding.
bool SupportsTiling { get; }
Property Value
Remarks
Tiling processes the image in overlapping patches to reduce memory usage when handling large images. Useful for high-resolution generation.
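How overlapping tiles cover a large image can be illustrated with a small coordinate calculation. This is a plain-Python sketch of the idea, not the AiDotNet implementation; the tile size and overlap values are illustrative:

```python
def tile_starts(size, tile, overlap):
    """Start offsets of overlapping tiles covering a dimension of length `size`."""
    stride = tile - overlap
    starts = list(range(0, max(size - tile, 0) + 1, stride))
    if starts[-1] + tile < size:
        # Add a final tile flush with the edge so the whole axis is covered.
        starts.append(size - tile)
    return starts

# A 2048-pixel axis covered by 512-pixel tiles with 64 pixels of overlap:
print(tile_starts(2048, 512, 64))  # [0, 448, 896, 1344, 1536]
```

Each tile is encoded or decoded independently, so peak memory depends on the tile size rather than the full image size; the overlap regions are blended to hide seams between tiles.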
Methods
Decode(Tensor<T>)
Decodes a latent representation back to image space.
Tensor<T> Decode(Tensor<T> latent)
Parameters
latent (Tensor<T>): The latent tensor [batch, latentChannels, latentHeight, latentWidth].
Returns
- Tensor<T>
The decoded image [batch, channels, latentHeight * downFactor, latentWidth * downFactor].
Remarks
For Beginners: This decompresses the latent back to an image:
- Input: Small latent (64x64x4)
- Output: Full-size image (512x512x3)
- The image looks like the original but with minor differences due to compression
Encode(Tensor<T>, bool)
Encodes an image into the latent space.
Tensor<T> Encode(Tensor<T> image, bool sampleMode = true)
Parameters
image (Tensor<T>): The input image tensor [batch, channels, height, width].
sampleMode (bool): If true, samples from the latent distribution. If false, returns the mean.
Returns
- Tensor<T>
The latent representation [batch, latentChannels, height/downFactor, width/downFactor].
Remarks
The VAE encoder outputs a distribution (mean and log variance). When sampleMode is true, we sample from this distribution using the reparameterization trick. When false, we just return the mean for deterministic encoding.
For Beginners: This compresses the image:
- Input: Full-size image (512x512x3)
- Output: Small latent representation (64x64x4)
- The latent contains all the important information in a compressed form
EncodeWithDistribution(Tensor<T>)
Encodes and returns both mean and log variance (for training).
(Tensor<T> Mean, Tensor<T> LogVariance) EncodeWithDistribution(Tensor<T> image)
Parameters
image (Tensor<T>): The input image tensor [batch, channels, height, width].
Returns
- (Tensor<T> Mean, Tensor<T> LogVariance)
A tuple containing the mean and log variance of the latent distribution.
Remarks
Used during VAE training where we need both the mean and variance for computing the KL divergence loss.
Sample(Tensor<T>, Tensor<T>, int?)
Samples from the latent distribution using the reparameterization trick.
Tensor<T> Sample(Tensor<T> mean, Tensor<T> logVariance, int? seed = null)
Parameters
mean (Tensor<T>): The mean of the latent distribution.
logVariance (Tensor<T>): The log variance of the latent distribution.
seed (int?): Optional random seed for reproducibility.
Returns
- Tensor<T>
A sample from the distribution: mean + std * epsilon.
Remarks
The reparameterization trick allows gradients to flow through the sampling operation: z = mean + exp(0.5 * logVariance) * epsilon, where epsilon ~ N(0, 1)
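The formula above can be sketched in NumPy. This illustrates the math only, not the AiDotNet implementation:

```python
import numpy as np

def sample_latent(mean, log_variance, seed=None):
    """Reparameterization trick: z = mean + std * epsilon, epsilon ~ N(0, 1)."""
    rng = np.random.default_rng(seed)
    std = np.exp(0.5 * log_variance)          # log variance -> standard deviation
    epsilon = rng.standard_normal(mean.shape)
    return mean + std * epsilon               # gradients can flow through mean and std

# A latent-shaped distribution with mean 0 and variance 1 (log variance 0):
mean = np.zeros((1, 4, 64, 64))
log_var = np.zeros((1, 4, 64, 64))
z = sample_latent(mean, log_var, seed=0)
```

Because the randomness is isolated in `epsilon`, the output is a deterministic function of `mean` and `log_variance`, which is what lets backpropagation treat sampling as differentiable.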
ScaleLatent(Tensor<T>)
Scales latent values for use in diffusion (applies LatentScaleFactor).
Tensor<T> ScaleLatent(Tensor<T> latent)
Parameters
latent (Tensor<T>): The raw latent from encoding.
Returns
- Tensor<T>
Scaled latent values.
Remarks
Multiplies by LatentScaleFactor to normalize the latent distribution. This is necessary because VAE latents have a specific variance that diffusion models expect to be normalized.
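A sketch of the scaling and its inverse, using Stable Diffusion's 0.18215 factor (plain Python, illustrative only; the raw latent value is made up):

```python
LATENT_SCALE_FACTOR = 0.18215  # Stable Diffusion's normalization constant

def scale_latent(latent):
    # Applied after encoding, before running diffusion in latent space.
    return latent * LATENT_SCALE_FACTOR

def unscale_latent(latent):
    # Applied after diffusion, before decoding back to an image.
    return latent / LATENT_SCALE_FACTOR

raw = 5.49  # a hypothetical raw VAE latent value
assert abs(unscale_latent(scale_latent(raw)) - raw) < 1e-12
```

ScaleLatent and UnscaleLatent are exact inverses, so encoding, scaling, unscaling, and decoding round-trips the latent unchanged (up to floating-point error).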
SetSlicingEnabled(bool)
Enables or disables slicing mode.
void SetSlicingEnabled(bool enabled)
Parameters
enabled (bool): Whether to enable slicing.
SetTilingEnabled(bool)
Enables or disables tiling mode.
void SetTilingEnabled(bool enabled)
Parameters
enabled (bool): Whether to enable tiling.
UnscaleLatent(Tensor<T>)
Unscales latent values before decoding (inverts LatentScaleFactor).
Tensor<T> UnscaleLatent(Tensor<T> latent)
Parameters
latent (Tensor<T>): The scaled latent from diffusion.
Returns
- Tensor<T>
Unscaled latent values ready for decoding.
Remarks
Divides by LatentScaleFactor to reverse the scaling before decoding.