
Class VAEModelBase<T>

Namespace
AiDotNet.Diffusion.VAE
Assembly
AiDotNet.dll

Base class for Variational Autoencoder (VAE) models used in latent diffusion.

public abstract class VAEModelBase<T> : IVAEModel<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>

Type Parameters

T

The numeric type used for calculations.

Inheritance
VAEModelBase<T>
Implements
IVAEModel<T>
IFullModel<T, Tensor<T>, Tensor<T>>
IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>
IModelSerializer
ICheckpointableModel
IParameterizable<T, Tensor<T>, Tensor<T>>
IFeatureAware
IFeatureImportance<T>
ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>
IGradientComputable<T, Tensor<T>, Tensor<T>>
IJitCompilable<T>

Remarks

This abstract base class provides common functionality for all VAE implementations, including encoding, decoding, sampling, and latent scaling operations.

For Beginners: This is the foundation for all VAE models in the library. VAEs compress images to a small latent representation and decompress them back. They are essential for efficient latent diffusion models like Stable Diffusion.

Constructors

VAEModelBase(ILossFunction<T>?, int?)

Initializes a new instance of the VAEModelBase class.

protected VAEModelBase(ILossFunction<T>? lossFunction = null, int? seed = null)

Parameters

lossFunction ILossFunction<T>

Optional custom loss function. Defaults to MSE.

seed int?

Optional random seed for reproducibility.

Fields

LossFunction

The loss function used for training.

protected readonly ILossFunction<T> LossFunction

Field Value

ILossFunction<T>

NumOps

Provides numeric operations for the specific type T.

protected static readonly INumericOperations<T> NumOps

Field Value

INumericOperations<T>

RandomGenerator

Random number generator for sampling operations.

protected Random RandomGenerator

Field Value

Random

SlicingEnabled

Whether slicing mode is enabled for sequential processing.

protected bool SlicingEnabled

Field Value

bool

TilingEnabled

Whether tiling mode is enabled for memory-efficient processing.

protected bool TilingEnabled

Field Value

bool

Properties

DefaultLossFunction

Gets the default loss function used by this model for gradient computation.

public ILossFunction<T> DefaultLossFunction { get; }

Property Value

ILossFunction<T>

Remarks

This loss function is used when calling ComputeGradients(TInput, TOutput, ILossFunction<T>?) without explicitly providing a loss function. It represents the model's primary training objective.

For Beginners: The loss function tells the model "what counts as a mistake". For example:

  • For regression (predicting numbers): Mean Squared Error measures how far predictions are from actual values
  • For classification (predicting categories): Cross Entropy measures how confident the model is in the right category

This property provides a sensible default so you don't have to specify the loss function every time, but you can still override it if needed for special cases.

Distributed Training: In distributed training, all workers use the same loss function to ensure consistent gradient computation. The default loss function is automatically used when workers compute local gradients.

Exceptions

InvalidOperationException

Thrown if accessed before the model has been configured with a loss function.

DownsampleFactor

Gets the spatial downsampling factor.

public abstract int DownsampleFactor { get; }

Property Value

int

Remarks

The factor by which the VAE reduces spatial dimensions. Stable Diffusion uses 8x downsampling, so a 512x512 image becomes 64x64 latents.
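The shape arithmetic behind the 512x512 to 64x64 example can be sketched as follows. This is a standalone illustration on plain integers, not the library's API: the helper `LatentShape` is hypothetical, and it assumes the image height and width are exact multiples of the downsampling factor.

```csharp
using System;

// Hypothetical helper: maps an image shape [batch, channels, height, width] to the
// latent shape produced by a VAE with the given downsampling factor and latent channels.
static int[] LatentShape(int[] imageShape, int downsampleFactor, int latentChannels)
{
    if (imageShape.Length != 4)
        throw new ArgumentException("Expected [batch, channels, height, width].");
    int height = imageShape[2], width = imageShape[3];
    if (height % downsampleFactor != 0 || width % downsampleFactor != 0)
        throw new ArgumentException("Height and width must be multiples of the downsampling factor.");
    return new[] { imageShape[0], latentChannels, height / downsampleFactor, width / downsampleFactor };
}

// Stable Diffusion-style settings: 8x downsampling, 4 latent channels.
int[] latentShape = LatentShape(new[] { 1, 3, 512, 512 }, 8, 4);
Console.WriteLine(string.Join("x", latentShape)); // 1x4x64x64
```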

InputChannels

Gets the number of input channels (image channels).

public abstract int InputChannels { get; }

Property Value

int

Remarks

Typically 3 for RGB images. Could be 1 for grayscale or 4 for RGBA.

LatentChannels

Gets the number of latent channels.

public abstract int LatentChannels { get; }

Property Value

int

Remarks

Standard Stable Diffusion VAEs use 4 latent channels. Some newer VAEs may use different values (e.g., 16 for certain architectures).

LatentScaleFactor

Gets the scale factor for latent values.

public abstract double LatentScaleFactor { get; }

Property Value

double

Remarks

A normalization factor applied to latent values. For Stable Diffusion, this is 0.18215, which normalizes the latent distribution to unit variance.
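If the factor's job is to bring latents to unit variance, it can be estimated as the reciprocal of the standard deviation of raw latents over a sample of encoded images. The sketch below assumes latents flattened into a `double[]`; `EstimateScaleFactor` is a hypothetical helper, and the sample values are made up for illustration rather than real VAE output.

```csharp
using System;
using System.Linq;

// Hypothetical helper: returns 1 / std of the sample, so that
// (latent * factor) has unit (population) variance.
static double EstimateScaleFactor(double[] latents)
{
    double mean = latents.Average();
    double variance = latents.Select(v => (v - mean) * (v - mean)).Average();
    if (variance == 0) throw new ArgumentException("Latents have zero variance.");
    return 1.0 / Math.Sqrt(variance);
}

// Made-up values: raw latents spread with standard deviation 2 around 0.
double[] sampleLatents = { 2.0, -2.0, 2.0, -2.0 };
double factor = EstimateScaleFactor(sampleLatents);
Console.WriteLine(factor); // 0.5
```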

ParameterCount

Gets the number of parameters in the model.

public abstract int ParameterCount { get; }

Property Value

int

Remarks

This property returns the total count of trainable parameters in the model. It's useful for understanding model complexity and memory requirements.

SupportsJitCompilation

Gets whether this model currently supports JIT compilation.

public virtual bool SupportsJitCompilation { get; }

Property Value

bool

True if the model can be JIT compiled, false otherwise.

Remarks

Some models may not support JIT compilation due to:

  • Dynamic graph structure (changes based on input)
  • Lack of computation graph representation
  • Use of operations not yet supported by the JIT compiler

For Beginners: This tells you whether this specific model can benefit from JIT compilation.

Models return false if they:

  • Use layer-based architecture without graph export (e.g., current neural networks)
  • Have control flow that changes based on input data
  • Use operations the JIT compiler doesn't understand yet

In these cases, the model will still work normally, just without JIT acceleration.

SupportsSlicing

Gets whether this VAE uses slicing for sequential processing.

public virtual bool SupportsSlicing { get; }

Property Value

bool

Remarks

Slicing processes the batch one sample at a time to reduce memory. Trades speed for memory efficiency.
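The slicing pattern can be sketched in isolation: the batch is split along its first dimension and each sample goes through a per-sample function on its own, so only one sample's intermediate work is live at a time. Samples are plain `double[]` arrays here, and `EncodeSliced` and the toy encoder lambda are hypothetical stand-ins, not the library's implementation.

```csharp
using System;

// Hypothetical: run a per-sample function over a batch one sample at a time.
// Peak memory is bounded by a single sample's work instead of the whole batch.
static double[][] EncodeSliced(double[][] batch, Func<double[], double[]> encodeOne)
{
    var results = new double[batch.Length][];
    for (int i = 0; i < batch.Length; i++)
    {
        results[i] = encodeOne(batch[i]); // one sample in flight at a time
    }
    return results;
}

// Toy "encoder" that halves every value, standing in for a real VAE encoder.
double[][] inputBatch = { new[] { 2.0, 4.0 }, new[] { 6.0, 8.0 } };
double[][] encoded = EncodeSliced(inputBatch, sample => Array.ConvertAll(sample, v => v / 2));
Console.WriteLine(string.Join(", ", encoded[1])); // 3, 4
```

The trade-off stated above shows up directly: the loop runs the encoder once per sample instead of once per batch.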

SupportsTiling

Gets whether this VAE uses tiling for memory-efficient encoding/decoding.

public virtual bool SupportsTiling { get; }

Property Value

bool

Remarks

Tiling processes the image in overlapping patches to reduce memory usage when handling large images. Useful for high-resolution generation.
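Choosing tile positions is the core bookkeeping step of tiling. The sketch computes tile start offsets along one axis so that neighbouring tiles overlap by a fixed amount and the last tile is aligned to the image edge; it would be applied to height and width separately. `TileStarts` and its parameters are hypothetical, and blending the overlapping regions back together is omitted.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper: start offsets of tiles of length `tile` covering `size`
// positions, with neighbouring tiles overlapping by `overlap` positions.
static List<int> TileStarts(int size, int tile, int overlap)
{
    if (tile <= 0 || overlap < 0 || overlap >= tile)
        throw new ArgumentException("Require tile > 0 and 0 <= overlap < tile.");
    int stride = tile - overlap;
    var starts = new List<int>();
    for (int s = 0; ; s += stride)
    {
        if (s + tile >= size)
        {
            // Last tile: align to the far edge (or start at 0 if the image is smaller than a tile).
            starts.Add(Math.Max(0, size - tile));
            break;
        }
        starts.Add(s);
    }
    return starts;
}

Console.WriteLine(string.Join(", ", TileStarts(1024, 512, 64))); // 0, 448, 512
```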

Methods

ApplyGradients(Vector<T>, T)

Applies pre-computed gradients to update the model parameters.

public virtual void ApplyGradients(Vector<T> gradients, T learningRate)

Parameters

gradients Vector<T>

The gradient vector to apply.

learningRate T

The learning rate for the update.

Remarks

Updates parameters using: θ = θ - learningRate * gradients
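Applied element by element, the update rule looks like this. The sketch uses `double[]` in place of `Vector<T>` and a hypothetical `ApplyGradients` helper; it illustrates the formula only, not the library's implementation.

```csharp
using System;

// Hypothetical: in-place gradient descent step, theta_i <- theta_i - learningRate * g_i.
static void ApplyGradients(double[] parameters, double[] gradients, double learningRate)
{
    if (parameters.Length != gradients.Length)
        throw new ArgumentException("Gradient length must match parameter count.");
    for (int i = 0; i < parameters.Length; i++)
        parameters[i] -= learningRate * gradients[i];
}

double[] theta = { 1.0, -2.0 };
ApplyGradients(theta, new[] { 0.5, -1.0 }, 0.1);
Console.WriteLine(string.Join(", ", theta));
```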

For Beginners: After computing gradients (seeing which direction to move), this method actually moves the model in that direction. The learning rate controls how big of a step to take.

Distributed Training: In DDP/ZeRO-2, this applies the synchronized (averaged) gradients after communication across workers. Each worker applies the same averaged gradients to keep parameters consistent.

Clone()

Creates a deep copy of the VAE model.

public abstract IVAEModel<T> Clone()

Returns

IVAEModel<T>

A new instance with the same parameters.

ComputeGradients(Tensor<T>, Tensor<T>, ILossFunction<T>?)

Computes gradients of the loss function with respect to model parameters for the given data, WITHOUT updating the model parameters.

public virtual Vector<T> ComputeGradients(Tensor<T> input, Tensor<T> target, ILossFunction<T>? lossFunction = null)

Parameters

input Tensor<T>

The input data.

target Tensor<T>

The target/expected output.

lossFunction ILossFunction<T>

The loss function to use for gradient computation. If null, uses the model's default loss function.

Returns

Vector<T>

A vector containing gradients with respect to all model parameters.

Remarks

This method performs a forward pass, computes the loss, and back-propagates to compute gradients, but does NOT update the model's parameters. The parameters remain unchanged after this call.

Distributed Training: In DDP/ZeRO-2, each worker calls this to compute local gradients on its data batch. These gradients are then synchronized (averaged) across workers before applying updates. This ensures all workers compute the same parameter updates despite having different data.
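The synchronization step can be illustrated on a single machine: the workers' local gradient vectors are summed element-wise and divided by the number of workers, giving one averaged gradient that every worker then applies. Real DDP/ZeRO-2 setups perform this with an all-reduce across processes; the `AverageGradients` helper and plain `double[]` vectors below are hypothetical stand-ins.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical: element-wise mean of per-worker gradient vectors,
// i.e. what an all-reduce (sum) followed by division by the worker count produces.
static double[] AverageGradients(IReadOnlyList<double[]> workerGradients)
{
    if (workerGradients.Count == 0) throw new ArgumentException("No gradients supplied.");
    int length = workerGradients[0].Length;
    var averaged = new double[length];
    foreach (var g in workerGradients)
    {
        if (g.Length != length) throw new ArgumentException("Gradient lengths differ.");
        for (int i = 0; i < length; i++) averaged[i] += g[i];
    }
    for (int i = 0; i < length; i++) averaged[i] /= workerGradients.Count;
    return averaged;
}

var local = new List<double[]> { new[] { 1.0, 4.0 }, new[] { 3.0, 0.0 } };
Console.WriteLine(string.Join(", ", AverageGradients(local))); // 2, 2
```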

For Meta-Learning: After adapting a model on a support set, you can use this method to compute gradients on the query set. These gradients become the meta-gradients for updating the meta-parameters.

For Beginners: Think of this as "dry run" training:

  • The model sees what direction it should move (the gradients)
  • But it doesn't actually move (parameters stay the same)
  • You get to decide what to do with this information (average with others, inspect, modify, etc.)

Exceptions

InvalidOperationException

If lossFunction is null and the model has no default loss function.

ComputeKLDivergence(Tensor<T>, Tensor<T>)

Computes the KL divergence loss for VAE training.

protected virtual T ComputeKLDivergence(Tensor<T> mean, Tensor<T> logVariance)

Parameters

mean Tensor<T>

The mean of the latent distribution.

logVariance Tensor<T>

The log variance of the latent distribution.

Returns

T

The KL divergence loss value.

Remarks

KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2), where log(sigma^2) is the logVariance input and the sum runs over the latent elements.
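Written in terms of logVariance (so sigma^2 = exp(logVariance)), the formula can be evaluated element-wise and summed. This is a plain `double[]` sketch with a hypothetical `KlDivergence` helper; it sums over all elements, whereas a given implementation might also average over the batch.

```csharp
using System;

// Hypothetical: KL(N(mu, sigma^2) || N(0, 1)) summed over all latent elements,
// using sigma^2 = exp(logVariance).
static double KlDivergence(double[] mean, double[] logVariance)
{
    if (mean.Length != logVariance.Length)
        throw new ArgumentException("mean and logVariance must have the same length.");
    double sum = 0.0;
    for (int i = 0; i < mean.Length; i++)
        sum += 1.0 + logVariance[i] - mean[i] * mean[i] - Math.Exp(logVariance[i]);
    return -0.5 * sum;
}

// mu = 1, sigma^2 = 1: -0.5 * (1 + 0 - 1 - 1) = 0.5
Console.WriteLine(KlDivergence(new[] { 1.0 }, new[] { 0.0 })); // 0.5
```

A posterior that already equals N(0, 1) (mean 0, logVariance 0) gives zero divergence, which is what the KL term pushes the encoder towards.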

Decode(Tensor<T>)

Decodes a latent representation back to image space.

public abstract Tensor<T> Decode(Tensor<T> latent)

Parameters

latent Tensor<T>

The latent tensor [batch, latentChannels, latentHeight, latentWidth].

Returns

Tensor<T>

The decoded image [batch, channels, latentHeight * downFactor, latentWidth * downFactor].

Remarks

For Beginners: This decompresses the latent back to an image:

  • Input: Small latent (64x64x4)
  • Output: Full-size image (512x512x3)
  • The image looks like the original but with minor differences due to compression

DeepCopy()

Creates a deep copy of this object.

public abstract IFullModel<T, Tensor<T>, Tensor<T>> DeepCopy()

Returns

IFullModel<T, Tensor<T>, Tensor<T>>

Deserialize(byte[])

Loads a previously serialized model from binary data.

public virtual void Deserialize(byte[] data)

Parameters

data byte[]

The byte array containing the serialized model data.

Remarks

This method takes binary data created by the Serialize method and uses it to restore a model to its previous state.

For Beginners: This is like opening a saved file to continue your work.

When you call this method:

  • You provide the binary data (bytes) that was previously created by Serialize
  • The model rebuilds itself using this data
  • After deserializing, the model is exactly as it was when serialized
  • It's ready to make predictions without needing to be trained again

For example:

  • You download a pre-trained model file for detecting spam emails
  • You deserialize this file into your application
  • Immediately, your application can detect spam without any training
  • The model has all the knowledge that was built into it by its original creator

This is particularly useful when:

  • You want to use a model that took days to train
  • You need to deploy the same model across multiple devices
  • You're creating an application that non-technical users will use

Think of it like installing the brain of a trained expert directly into your application.

Encode(Tensor<T>, bool)

Encodes an image into the latent space.

public abstract Tensor<T> Encode(Tensor<T> image, bool sampleMode = true)

Parameters

image Tensor<T>

The input image tensor [batch, channels, height, width].

sampleMode bool

If true, samples from the latent distribution. If false, returns the mean.

Returns

Tensor<T>

The latent representation [batch, latentChannels, height/downFactor, width/downFactor].

Remarks

The VAE encoder outputs a distribution (mean and log variance). When sampleMode is true, we sample from this distribution using the reparameterization trick. When false, we just return the mean for deterministic encoding.

For Beginners: This compresses the image:

  • Input: Full-size image (512x512x3)
  • Output: Small latent representation (64x64x4)
  • The latent keeps the important visual information in compressed form; some fine detail is lost, which is why decoded images differ slightly from the original

EncodeWithDistribution(Tensor<T>)

Encodes and returns both mean and log variance (for training).

public abstract (Tensor<T> Mean, Tensor<T> LogVariance) EncodeWithDistribution(Tensor<T> image)

Parameters

image Tensor<T>

The input image tensor [batch, channels, height, width].

Returns

(Tensor<T> Mean, Tensor<T> LogVariance)

Tuple of (mean, logVariance) tensors.

Remarks

Used during VAE training where we need both the mean and variance for computing the KL divergence loss.

ExportComputationGraph(List<ComputationNode<T>>)

Exports the model's computation graph for JIT compilation.

public virtual ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)

Parameters

inputNodes List<ComputationNode<T>>

List to populate with input computation nodes (parameters).

Returns

ComputationNode<T>

The output computation node representing the model's prediction.

Remarks

This method should construct a computation graph representing the model's forward pass. The graph should use placeholder input nodes that will be filled with actual data during execution.

For Beginners: This method creates a "recipe" of your model's calculations that the JIT compiler can optimize.

The method should:

  1. Create placeholder nodes for inputs (features, parameters)
  2. Build the computation graph using TensorOperations
  3. Return the final output node
  4. Add all input nodes to the inputNodes list (in order)

Example for a simple linear model (y = xW + b, with x as a row vector; InputShape, Weights and Bias are members of that example model):

```csharp
public ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
{
    // Create placeholder inputs
    var x = TensorOperations<T>.Variable(new Tensor<T>(InputShape), "x");
    var W = TensorOperations<T>.Variable(Weights, "W");
    var b = TensorOperations<T>.Variable(Bias, "b");

    // Add inputs in order
    inputNodes.Add(x);
    inputNodes.Add(W);
    inputNodes.Add(b);

    // Build graph: y = xW + b (MatMul(x, W) multiplies x on the left)
    var matmul = TensorOperations<T>.MatMul(x, W);
    var output = TensorOperations<T>.Add(matmul, b);

    return output;
}
```

The JIT compiler will then:

  • Optimize the graph (fuse operations, eliminate dead code)
  • Compile it to fast native code
  • Cache the compiled version for reuse

GetActiveFeatureIndices()

Gets the indices of features that are actively used by this model.

public virtual IEnumerable<int> GetActiveFeatureIndices()

Returns

IEnumerable<int>

GetFeatureImportance()

Gets the feature importance scores.

public virtual Dictionary<string, T> GetFeatureImportance()

Returns

Dictionary<string, T>

GetModelMetadata()

Retrieves metadata and performance metrics about the trained model.

public virtual ModelMetadata<T> GetModelMetadata()

Returns

ModelMetadata<T>

An object containing metadata and performance metrics about the trained model.

Remarks

This method provides information about the model's structure, parameters, and performance metrics.

For Beginners: Model metadata is like a report card for your machine learning model.

Just as a report card shows how well a student is performing in different subjects, model metadata shows how well your model is performing and provides details about its structure.

This information typically includes:

  • Accuracy measures: How closely do the model's predictions match actual values?
  • Error metrics: How far off are the model's predictions on average?
  • Model parameters: What patterns did the model learn from the data?
  • Training information: How long did training take? How many iterations were needed?

For example, in a house price prediction model, metadata might include:

  • Average prediction error (e.g., off by $15,000 on average)
  • How strongly each feature (bedrooms, location) influences the prediction
  • How well the model fits the training data

This information helps you understand your model's strengths and weaknesses, and decide if it's ready to use or needs more training.

GetParameters()

Gets the parameters that can be optimized.

public abstract Vector<T> GetParameters()

Returns

Vector<T>

IsFeatureUsed(int)

Checks if a specific feature is used by this model.

public virtual bool IsFeatureUsed(int featureIndex)

Parameters

featureIndex int

Returns

bool

LoadModel(string)

Loads the model from a file.

public virtual void LoadModel(string filePath)

Parameters

filePath string

The path to the file containing the saved model.

Remarks

This method provides a convenient way to load a model directly from disk. It combines file I/O operations with deserialization.

For Beginners: This is like clicking "Open" in a document editor. Instead of manually reading from a file and then calling Deserialize(), this method does both steps for you.

Exceptions

FileNotFoundException

Thrown when the specified file does not exist.

IOException

Thrown when an I/O error occurs while reading from the file or when the file contains corrupted or invalid model data.

LoadState(Stream)

Loads the model's state (parameters and configuration) from a stream.

public virtual void LoadState(Stream stream)

Parameters

stream Stream

The stream to read the model state from.

Remarks

This method deserializes model state that was previously saved with SaveState, restoring all parameters and configuration to recreate the saved model state.

For Beginners: This is like loading a saved game.

When you call LoadState:

  • All the parameters are read from the stream
  • The model is configured to match the saved architecture
  • The model becomes identical to when SaveState was called

After loading, the model can make predictions using the restored parameters.

Stream Handling:

  • The stream position will be advanced by the number of bytes read
  • The stream is not closed (caller must dispose)
  • Stream data must match the format written by SaveState

Versioning: Implementations should consider:

  • Including a format version number in serialized data
  • Validating compatibility before deserialization
  • Providing migration paths for old formats when possible
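One way to follow these recommendations is to write a small header (a format tag and a version number) before the payload and validate it on load, while leaving the caller's stream open. The layout below (magic number, version, parameter count, values) is a hypothetical illustration using `BinaryWriter`/`BinaryReader`, not the format AiDotNet actually writes; `SaveParameters` and `LoadParameters` are illustrative names.

```csharp
using System;
using System.IO;
using System.Text;

const int Magic = 0x56414531;   // hypothetical format tag
const int FormatVersion = 1;    // bump when the layout changes

void SaveParameters(Stream stream, double[] values)
{
    // leaveOpen: true so the caller keeps ownership of the stream.
    using var writer = new BinaryWriter(stream, Encoding.UTF8, leaveOpen: true);
    writer.Write(Magic);
    writer.Write(FormatVersion);
    writer.Write(values.Length);
    foreach (double p in values) writer.Write(p);
    writer.Flush();
}

double[] LoadParameters(Stream stream)
{
    using var reader = new BinaryReader(stream, Encoding.UTF8, leaveOpen: true);
    if (reader.ReadInt32() != Magic)
        throw new InvalidOperationException("Not a recognized model state stream.");
    int version = reader.ReadInt32();
    if (version != FormatVersion)
        throw new InvalidOperationException($"Unsupported format version {version}.");
    var values = new double[reader.ReadInt32()];
    for (int i = 0; i < values.Length; i++) values[i] = reader.ReadDouble();
    return values;
}

using var ms = new MemoryStream();
SaveParameters(ms, new[] { 0.5, -1.25 });
ms.Position = 0;
double[] restored = LoadParameters(ms);
Console.WriteLine(string.Join(", ", restored)); // 0.5, -1.25
```

Checking the header first means an incompatible stream fails with a clear error instead of silently producing wrong parameters.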

Usage:

```csharp
using System.IO;

// Load from file
using var stream = File.OpenRead("model.bin");
model.LoadState(stream);
```

Important: The stream must contain state data saved by SaveState from a compatible model (same architecture and numeric type).

Exceptions

ArgumentNullException

Thrown when stream is null.

ArgumentException

Thrown when stream is not readable or contains invalid data.

InvalidOperationException

Thrown when deserialization fails or data is incompatible with model architecture.

Predict(Tensor<T>)

Uses the trained model to make predictions for new input data.

public virtual Tensor<T> Predict(Tensor<T> input)

Parameters

input Tensor<T>

A tensor of input data to run through the model.

Returns

Tensor<T>

A tensor containing the model's output for the input.

Remarks

After training, this method applies the learned patterns to new data to predict outcomes.

For Beginners: Prediction is when the model uses what it learned to make educated guesses about new information.

Returning to the fruit identification analogy used under Train:

  • After learning from many examples, the child (model) can now identify new fruits they haven't seen before
  • They look at the color, shape, and size to make their best guess

In machine learning:

  • You give the model new data it hasn't seen during training
  • The model applies the patterns it learned to make predictions
  • The output is the model's best estimate based on its training

For example, in a house price prediction model:

  • You provide features of a new house (square footage, bedrooms, location)
  • The model predicts what price that house might sell for

This method is used after training is complete, when you want to apply your model to real-world data.

Sample(Tensor<T>, Tensor<T>, int?)

Samples from the latent distribution using the reparameterization trick.

public virtual Tensor<T> Sample(Tensor<T> mean, Tensor<T> logVariance, int? seed = null)

Parameters

mean Tensor<T>

The mean of the latent distribution.

logVariance Tensor<T>

The log variance of the latent distribution.

seed int?

Optional random seed for reproducibility.

Returns

Tensor<T>

A sample from the distribution: mean + std * epsilon.

Remarks

The reparameterization trick allows gradients to flow through the sampling operation: z = mean + exp(0.5 * logVariance) * epsilon, where epsilon ~ N(0, 1)
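Passing the noise in as a separate input makes the trick explicit: all randomness lives in epsilon, so z is a deterministic, differentiable function of mean and logVariance. The sketch uses `double[]` and a hypothetical `Reparameterize` helper rather than the library's `Sample` method.

```csharp
using System;

// Hypothetical: z_i = mean_i + exp(0.5 * logVariance_i) * epsilon_i.
// epsilon is drawn from N(0, 1) elsewhere and passed in.
static double[] Reparameterize(double[] mean, double[] logVariance, double[] epsilon)
{
    if (mean.Length != logVariance.Length || mean.Length != epsilon.Length)
        throw new ArgumentException("All inputs must have the same length.");
    var z = new double[mean.Length];
    for (int i = 0; i < z.Length; i++)
        z[i] = mean[i] + Math.Exp(0.5 * logVariance[i]) * epsilon[i];
    return z;
}

// logVariance = ln 4 corresponds to a standard deviation of 2.
double[] sampleZ = Reparameterize(new[] { 1.0, 2.0 }, new[] { 0.0, Math.Log(4.0) }, new[] { 0.5, 1.0 });
Console.WriteLine(string.Join(", ", sampleZ));
```

Setting epsilon to zero returns the mean, which corresponds to the deterministic path described under Encode when sampleMode is false.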

SampleNoise(int[], Random?)

Samples random noise from a standard normal distribution.

protected virtual Tensor<T> SampleNoise(int[] shape, Random? rng = null)

Parameters

shape int[]

The shape of the noise tensor.

rng Random

Optional random number generator.

Returns

Tensor<T>

A tensor of random noise values.
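`System.Random` only produces uniform values, so standard normal noise has to be derived from them. The sketch below uses the Box-Muller transform; which method SampleNoise itself uses is not stated here, and `SampleStandardNormal` is a hypothetical helper that returns a flat array whose length is the product of `shape`.

```csharp
using System;
using System.Linq;

// Hypothetical: fill a flat buffer (length = product of shape) with N(0, 1) samples
// using the Box-Muller transform, which turns two uniforms into two independent normals.
static double[] SampleStandardNormal(int[] shape, Random rng)
{
    int count = shape.Aggregate(1, (acc, d) => acc * d);
    var noise = new double[count];
    for (int i = 0; i < count; i += 2)
    {
        double u1 = 1.0 - rng.NextDouble(); // in (0, 1], avoids Log(0)
        double u2 = rng.NextDouble();
        double radius = Math.Sqrt(-2.0 * Math.Log(u1));
        noise[i] = radius * Math.Cos(2.0 * Math.PI * u2);
        if (i + 1 < count) noise[i + 1] = radius * Math.Sin(2.0 * Math.PI * u2);
    }
    return noise;
}

// A seeded Random makes the noise reproducible.
double[] eps = SampleStandardNormal(new[] { 1, 4, 8, 8 }, new Random(42));
Console.WriteLine(eps.Length); // 256
```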

SaveModel(string)

Saves the model to a file.

public virtual void SaveModel(string filePath)

Parameters

filePath string

The path where the model should be saved.

Remarks

This method provides a convenient way to save the model directly to disk. It combines serialization with file I/O operations.

For Beginners: This is like clicking "Save As" in a document editor. Instead of manually calling Serialize() and then writing to a file, this method does both steps for you.

Exceptions

IOException

Thrown when an I/O error occurs while writing to the file.

UnauthorizedAccessException

Thrown when the caller does not have the required permission to write to the specified file path.

SaveState(Stream)

Saves the model's current state (parameters and configuration) to a stream.

public virtual void SaveState(Stream stream)

Parameters

stream Stream

The stream to write the model state to.

Remarks

This method serializes all the information needed to recreate the model's current state, including trained parameters, layer configurations, and any internal state variables.

For Beginners: This is like creating a snapshot of your trained model.

When you call SaveState:

  • All the learned parameters (weights and biases) are written to the stream
  • The model's architecture information is saved
  • Any other internal state (like normalization statistics) is preserved

You can later use LoadState to restore the model to this exact state.

Stream Handling:

  • The stream position will be advanced by the number of bytes written
  • The stream is flushed but not closed (caller must dispose)
  • For file-based persistence, wrap in File.Create/FileStream

Usage:

```csharp
using System.IO;

// Save to file
using var stream = File.Create("model.bin");
model.SaveState(stream);
```

Exceptions

ArgumentNullException

Thrown when stream is null.

ArgumentException

Thrown when stream is not writable.

InvalidOperationException

Thrown when model state cannot be serialized (e.g., uninitialized model).

ScaleLatent(Tensor<T>)

Scales latent values for use in diffusion (applies LatentScaleFactor).

public virtual Tensor<T> ScaleLatent(Tensor<T> latent)

Parameters

latent Tensor<T>

The raw latent from encoding.

Returns

Tensor<T>

Scaled latent values.

Remarks

Multiplies by LatentScaleFactor to normalize the latent distribution. Raw VAE latents do not have unit variance, while the diffusion model expects inputs with roughly unit variance; scaling brings them into that range.
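The scale/unscale pair is a multiply followed later by the matching divide, so the decoder receives latents in the range the encoder produced. The sketch works on a flat `double[]` with hypothetical `Scale`/`Unscale` helpers and made-up latent values; it mirrors the described behavior of ScaleLatent and UnscaleLatent, not their actual code.

```csharp
using System;

const double StableDiffusionScale = 0.18215; // value quoted under LatentScaleFactor

// Hypothetical helpers mirroring ScaleLatent / UnscaleLatent on a flat array.
static double[] Scale(double[] latent, double factor) => Array.ConvertAll(latent, v => v * factor);
static double[] Unscale(double[] latent, double factor) => Array.ConvertAll(latent, v => v / factor);

double[] raw = { 5.49, -2.745 };                                        // made-up encoder output
double[] forDiffusion = Scale(raw, StableDiffusionScale);               // encoder output -> diffusion space
double[] backForDecoder = Unscale(forDiffusion, StableDiffusionScale);  // diffusion space -> decoder input
Console.WriteLine(string.Join(", ", backForDecoder));
```

Forgetting the unscale step before decoding would hand the decoder latents roughly 5.5 times too small.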

Serialize()

Converts the current state of a machine learning model into a binary format.

public virtual byte[] Serialize()

Returns

byte[]

A byte array containing the serialized model data.

Remarks

This method captures all the essential information about a trained model and converts it into a sequence of bytes that can be stored or transmitted.

For Beginners: This is like exporting your work to a file.

When you call this method:

  • The model's current state (all its learned patterns and parameters) is captured
  • This information is converted into a compact binary format (bytes)
  • You can then save these bytes to a file, database, or send them over a network

For example:

  • After training a model to recognize cats vs. dogs in images
  • You can serialize the model to save all its learned knowledge
  • Later, you can use this saved data to recreate the model exactly as it was
  • The recreated model will make the same predictions as the original

Think of it like taking a snapshot of your model's brain at a specific moment in time.

SetActiveFeatureIndices(IEnumerable<int>)

Sets the active feature indices for this model.

public virtual void SetActiveFeatureIndices(IEnumerable<int> featureIndices)

Parameters

featureIndices IEnumerable<int>

SetParameters(Vector<T>)

Sets the model parameters.

public abstract void SetParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

The parameter vector to set.

Remarks

This method allows direct modification of the model's internal parameters. This is useful for optimization algorithms that need to update parameters iteratively. If the length of parameters does not match ParameterCount, an ArgumentException should be thrown.

Exceptions

ArgumentException

Thrown when the length of parameters does not match ParameterCount.
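A common way to expose several weight arrays through one parameter vector is to concatenate them in a fixed order and split them again on the way back, validating the total length first. The sketch assumes two hypothetical weight arrays (`encoderWeights`, `decoderWeights`) held as `double[]`, with local stand-ins named after GetParameters and SetParameters; it illustrates the pattern, not how any VAE in the library stores its weights.

```csharp
using System;
using System.Linq;

// Hypothetical model state: two weight blocks.
double[] encoderWeights = { 0.1, 0.2, 0.3 };
double[] decoderWeights = { -0.4, 0.5 };
int parameterCount = encoderWeights.Length + decoderWeights.Length;

// GetParameters-style: concatenate blocks in a fixed order.
double[] GetParameters() => encoderWeights.Concat(decoderWeights).ToArray();

// SetParameters-style: check the length, then copy each slice back into its block.
void SetParameters(double[] parameters)
{
    if (parameters.Length != parameterCount)
        throw new ArgumentException(
            $"Expected {parameterCount} parameters but got {parameters.Length}.", nameof(parameters));
    Array.Copy(parameters, 0, encoderWeights, 0, encoderWeights.Length);
    Array.Copy(parameters, encoderWeights.Length, decoderWeights, 0, decoderWeights.Length);
}

SetParameters(new[] { 1.0, 2.0, 3.0, 4.0, 5.0 });
Console.WriteLine(string.Join(", ", GetParameters())); // 1, 2, 3, 4, 5
```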

SetSlicingEnabled(bool)

Enables or disables slicing mode.

public virtual void SetSlicingEnabled(bool enabled)

Parameters

enabled bool

Whether to enable slicing.

SetTilingEnabled(bool)

Enables or disables tiling mode.

public virtual void SetTilingEnabled(bool enabled)

Parameters

enabled bool

Whether to enable tiling.

Train(Tensor<T>, Tensor<T>)

Trains the model using input features and their corresponding target values.

public virtual void Train(Tensor<T> input, Tensor<T> expectedOutput)

Parameters

input Tensor<T>

The input data used for training.

expectedOutput Tensor<T>

The target values the model should learn to produce for input.

Remarks

This method takes training data and adjusts the model's internal parameters to learn patterns in the data.

For Beginners: Training is like teaching the model by showing it examples.

Imagine teaching a child to identify fruits:

  • You show them many examples of apples, oranges, and bananas (the input)
  • You tell them the correct name for each fruit (the expected output)
  • Over time, they learn to recognize the patterns that distinguish each fruit

In machine learning:

  • The input parameter contains features (characteristics) of your data
  • The expectedOutput parameter contains the correct answers you want the model to learn
  • During training, the model adjusts its internal calculations to get better at predicting expectedOutput from input

For example, in a house price prediction model:

  • input would contain features like square footage, number of bedrooms, location
  • expectedOutput would contain the actual sale prices of those houses

UnscaleLatent(Tensor<T>)

Unscales latent values before decoding (inverts LatentScaleFactor).

public virtual Tensor<T> UnscaleLatent(Tensor<T> latent)

Parameters

latent Tensor<T>

The scaled latent from diffusion.

Returns

Tensor<T>

Unscaled latent values ready for decoding.

Remarks

Divides by LatentScaleFactor to reverse the scaling before decoding.

WithParameters(Vector<T>)

Creates a new instance with the specified parameters.

public virtual IFullModel<T, Tensor<T>, Tensor<T>> WithParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

Returns

IFullModel<T, Tensor<T>, Tensor<T>>