
Class VAEEncoder<T>

Namespace
AiDotNet.Diffusion.VAE
Assembly
AiDotNet.dll

Convolutional encoder for VAE that compresses images to latent space.

public class VAEEncoder<T> : LayerBase<T>, ILayer<T>, IJitCompilable<T>, IDiagnosticsProvider, IWeightLoadable<T>, IDisposable

Type Parameters

T

The numeric type used for calculations.

Inheritance
LayerBase<T>
VAEEncoder<T>
Implements
ILayer<T>
IJitCompilable<T>
IDiagnosticsProvider
IWeightLoadable<T>
IDisposable

Remarks

This implements the encoder portion of a VAE following the Stable Diffusion architecture:

  • Input convolution to the initial feature channels
  • Multiple DownBlocks with ResBlocks and strided-convolution downsampling
  • Middle blocks with attention at the bottleneck
  • Final convolutions that produce the mean and log variance of the latent distribution

For Beginners: The VAE encoder is like a learned image compressor.

What it does step by step:

  1. Takes a high-resolution image (e.g., 512x512x3 RGB)
  2. Initial conv: Expands channels (3 -> 128) at full resolution
  3. DownBlocks: Each level runs ResBlocks at its own channel width; every level except the last then halves the resolution
    • Block 1: 128 channels, 512x512 -> 256x256
    • Block 2: 256 channels, 256x256 -> 128x128
    • Block 3: 512 channels, 128x128 -> 64x64
    • Block 4: 512 channels, stays at 64x64 (last level, no downsampling)
  4. Middle: ResBlocks with self-attention at the 64x64 bottleneck
  5. Output: Final convolutions produce the mean and log-variance of the 4-channel latent (8 output channels in total)

Sampling from this distribution gives a 64x64x4 latent. It holds 48 times fewer values than the 512x512x3 input (786,432 / 16,384), which makes it a compact representation for the diffusion model to work in.
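The shape progression above can be reproduced with a few lines of arithmetic. The sketch below is plain C# and does not call AiDotNet; it assumes, as the walkthrough describes, that level i works at baseChannels * channelMults[i] channels and that every level except the last halves the spatial size. The stage labels are illustrative, not the encoder's internal names.

```csharp
using System;
using System.Collections.Generic;

// Constructor defaults documented for VAEEncoder<T>.
int inputChannels = 3, latentChannels = 4, baseChannels = 128, inputSpatialSize = 512;
int[] channelMults = { 1, 2, 4, 4 };

// Assumption: level i uses baseChannels * channelMults[i] channels, and every
// level except the last ends with a stride-2 downsample.
var stages = new List<(string Stage, int Channels, int Size)>();
int size = inputSpatialSize;
stages.Add(("input image", inputChannels, size));
stages.Add(("input conv", baseChannels * channelMults[0], size));

for (int level = 0; level < channelMults.Length; level++)
{
    int channels = baseChannels * channelMults[level];
    if (level < channelMults.Length - 1)
        size /= 2;                                   // strided conv halves H and W
    stages.Add(($"DownBlock {level + 1}", channels, size));
}

// Middle blocks keep the bottleneck resolution and the last level's width.
stages.Add(("middle", baseChannels * channelMults[^1], size));

// Final convs emit mean and log-variance, concatenated along channels.
stages.Add(("mean + log-variance", 2 * latentChannels, size));

int downsampleFactor = inputSpatialSize / size;
foreach (var (stage, ch, s) in stages)
    Console.WriteLine($"{stage,-20} {ch,4} channels  {s}x{s}");
Console.WriteLine($"downsample factor: {downsampleFactor}");
```

Under the same assumption, a five-entry channelMults would add a fourth downsampling step and produce a 32x32 latent from a 512x512 input.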

Constructors

VAEEncoder(int, int, int, int[]?, int, int, int)

Initializes a new instance of the VAEEncoder class.

public VAEEncoder(int inputChannels = 3, int latentChannels = 4, int baseChannels = 128, int[]? channelMults = null, int numResBlocks = 2, int numGroups = 32, int inputSpatialSize = 512)

Parameters

inputChannels int

Number of input image channels (default: 3 for RGB).

latentChannels int

Number of latent channels (default: 4).

baseChannels int

Base channel count (default: 128).

channelMults int[]

Channel multipliers per level, applied to baseChannels (default: [1, 2, 4, 4], giving 128, 256, 512 and 512 channels).

numResBlocks int

Number of residual blocks per DownBlock (default: 2).

numGroups int

Number of groups for GroupNorm (default: 32).

inputSpatialSize int

Spatial size of input images (default: 512).

Properties

DownsampleFactor

Gets the downsampling factor: how many times smaller the latent's height and width are than the input's (8 with the default settings, 512x512 -> 64x64).

public int DownsampleFactor { get; }

Property Value

int

InputChannels

Gets the number of input channels.

public int InputChannels { get; }

Property Value

int

LatentChannels

Gets the number of latent channels.

public int LatentChannels { get; }

Property Value

int

NamedParameterCount

Gets the total number of named parameters.

public override int NamedParameterCount { get; }

Property Value

int

SupportsJitCompilation

Gets whether this layer supports JIT compilation.

public override bool SupportsJitCompilation { get; }

Property Value

bool

True if the layer can be JIT compiled, false otherwise.

Remarks

This property indicates whether the layer has implemented ExportComputationGraph() and can benefit from JIT compilation. All layers MUST implement this property.

For Beginners: JIT compilation can make inference 5-10x faster by converting the layer's operations into optimized native code.

Layers should return false if they:

  • Have not yet implemented a working ExportComputationGraph()
  • Use dynamic operations that change based on input data
  • Are too simple to benefit from JIT compilation

When false, the layer will use the standard Forward() method instead.

SupportsTraining

Gets a value indicating whether this layer supports training.

public override bool SupportsTraining { get; }

Property Value

bool

true if the layer has trainable parameters and supports backpropagation; otherwise, false.

Remarks

This property indicates whether the layer can be trained through backpropagation. Layers with trainable parameters such as weights and biases typically return true, while layers that only perform fixed transformations (like pooling or activation layers) typically return false.

For Beginners: This property tells you if the layer can learn from data.

A value of true means:

  • The layer has parameters that can be adjusted during training
  • Its performance can improve as it is trained on more data
  • It participates in the learning process

A value of false means:

  • The layer doesn't have any adjustable parameters
  • It performs the same operation regardless of training
  • It doesn't need to learn (but may still be useful)

Methods

Backward(Tensor<T>)

Performs the backward pass through the encoder.

public override Tensor<T> Backward(Tensor<T> outputGradient)

Parameters

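outputGradient Tensor<T>

Gradient of the loss with respect to the encoder's output (the concatenated mean and log variance returned by Forward).

Returns

Tensor<T>

Gradient of the loss with respect to the encoder's input.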

BuildParameterRegistryPublic()

Builds and returns the parameter registry for external use.

public ParameterRegistry<T> BuildParameterRegistryPublic()

Returns

ParameterRegistry<T>

Deserialize(BinaryReader)

Loads the encoder's state from a binary reader.

public override void Deserialize(BinaryReader reader)

Parameters

reader BinaryReader

EncodeAndSample(Tensor<T>, int?)

Encodes an image and samples from the latent distribution.

public Tensor<T> EncodeAndSample(Tensor<T> input, int? seed = null)

Parameters

input Tensor<T>

Input image tensor.

seed int?

Optional random seed for reproducibility.

Returns

Tensor<T>

Sampled latent tensor.
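EncodeAndSample draws a latent from the predicted distribution. The standard way to do this in a VAE is the reparameterization trick, z = mean + exp(0.5 * logVariance) * eps, with eps drawn from a standard normal. The sketch below shows that technique on plain double[] arrays, using System.Random and a Box-Muller transform for the noise. Whether EncodeAndSample uses the same noise generator is an assumption, so equal seeds here say nothing about equal latents from the library.

```csharp
using System;

// Reparameterization trick: z = mean + exp(0.5 * logVariance) * eps, eps ~ N(0, 1).
// The optional seed plays the role of EncodeAndSample's `seed` parameter.
static double[] SampleLatent(double[] mean, double[] logVariance, int? seed = null)
{
    var rng = seed.HasValue ? new Random(seed.Value) : new Random();
    var z = new double[mean.Length];
    for (int i = 0; i < mean.Length; i++)
    {
        // Box-Muller transform: two uniforms -> one standard normal sample.
        double u1 = 1.0 - rng.NextDouble();          // in (0, 1], avoids Log(0)
        double u2 = rng.NextDouble();
        double eps = Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Cos(2.0 * Math.PI * u2);
        double std = Math.Exp(0.5 * logVariance[i]); // sigma = exp(logVariance / 2)
        z[i] = mean[i] + std * eps;
    }
    return z;
}

double[] mu = { 0.5, -1.0, 2.0 };
double[] logVar = { 0.0, -2.0, -40.0 };   // the last entry has sigma ≈ 2e-9
double[] first = SampleLatent(mu, logVar, seed: 42);
double[] second = SampleLatent(mu, logVar, seed: 42);
Console.WriteLine(string.Join(", ", first));
```

Predicting the log variance rather than the variance keeps the standard deviation positive without constraining the network's raw output.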

EncodeWithDistribution(Tensor<T>)

Encodes and returns mean and log variance separately.

public (Tensor<T> Mean, Tensor<T> LogVariance) EncodeWithDistribution(Tensor<T> input)

Parameters

input Tensor<T>

Input image tensor.

Returns

(Tensor<T> Mean, Tensor<T> LogVariance)

Tuple of (mean, logVariance) tensors.

ExportComputationGraph(List<ComputationNode<T>>)

Exports the layer's computation graph for JIT compilation.

public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)

Parameters

inputNodes List<ComputationNode<T>>

List to populate with input computation nodes.

Returns

ComputationNode<T>

The output computation node representing the layer's operation.

Remarks

This method constructs a computation graph representation of the layer's forward pass that can be JIT compiled for faster inference. All layers MUST implement this method to support JIT compilation.

For Beginners: JIT (Just-In-Time) compilation converts the layer's operations into optimized native code for 5-10x faster inference.

To support JIT compilation, a layer must:

  1. Implement this method to export its computation graph
  2. Set SupportsJitCompilation to true
  3. Use ComputationNode and TensorOperations to build the graph

All layers are required to implement this method, even if they set SupportsJitCompilation = false.

Forward(Tensor<T>)

Encodes an image to latent space, returning concatenated mean and log variance.

public override Tensor<T> Forward(Tensor<T> input)

Parameters

input Tensor<T>

Input image tensor [batch, inputChannels, H, W].

Returns

Tensor<T>

Concatenated mean and log variance [batch, 2*latentChannels, H/f, W/f], where f is DownsampleFactor.
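Forward packs both halves of the distribution into one tensor, while EncodeWithDistribution returns them separately. The sketch below splits a row-major [batch, 2*latentChannels, H, W] buffer into two [batch, latentChannels, H, W] buffers. It assumes the mean occupies the first latentChannels channels, which fits the phrase "concatenated mean and log variance" but is not confirmed on this page, and it works on double[] rather than Tensor<T>.

```csharp
using System;

// Splits a row-major [batch, 2*C, H, W] buffer into two [batch, C, H, W] buffers.
// Assumption: channels 0..C-1 hold the mean and channels C..2C-1 the log-variance.
static (double[] Mean, double[] LogVariance) SplitMoments(
    double[] packed, int batch, int latentChannels, int height, int width)
{
    int half = latentChannels * height * width;   // values per half, per batch item
    if (packed.Length != batch * 2 * half)
        throw new ArgumentException("Buffer size does not match the given shape.");
    var mean = new double[batch * half];
    var logVariance = new double[batch * half];
    for (int b = 0; b < batch; b++)
    {
        int start = b * 2 * half;                 // first value of batch item b
        Array.Copy(packed, start, mean, b * half, half);
        Array.Copy(packed, start + half, logVariance, b * half, half);
    }
    return (mean, logVariance);
}

// batch = 2, latentChannels = 1, 1x2 spatial: packed shape [2, 2, 1, 2].
double[] forwardOutput = { 1, 2, -1, -2, 3, 4, -3, -4 };
var (meanPart, logVarPart) = SplitMoments(forwardOutput, batch: 2, latentChannels: 1, height: 1, width: 2);
Console.WriteLine($"mean: {string.Join(", ", meanPart)}");
Console.WriteLine($"logVar: {string.Join(", ", logVarPart)}");
```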

GetParameterNames()

Gets all parameter names in this layer.

public override IEnumerable<string> GetParameterNames()

Returns

IEnumerable<string>

A collection of parameter names ("weight", "bias", or both depending on layer type).

Remarks

The default implementation returns "weight" and/or "bias" based on whether GetWeights() and GetBiases() return non-null values.

GetParameterShape(string)

Gets the expected shape for a parameter.

public override int[]? GetParameterShape(string name)

Parameters

name string

The parameter name ("weight" or "bias").

Returns

int[]

The expected shape, or null if the parameter doesn't exist.

GetParameters()

Gets all trainable parameters as a single vector.

public override Vector<T> GetParameters()

Returns

Vector<T>

LoadWeights(Dictionary<string, Tensor<T>>, Func<string, string?>?, bool)

Loads weights from a dictionary of tensors using optional name mapping.

public override WeightLoadResult LoadWeights(Dictionary<string, Tensor<T>> weights, Func<string, string?>? mapping = null, bool strict = false)

Parameters

weights Dictionary<string, Tensor<T>>

Dictionary of weight name to tensor.

mapping Func<string, string?>

Optional function to map source names to target names.

strict bool

If true, fails when any mapped weight fails to load.

Returns

WeightLoadResult

Load result with statistics.
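The mapping argument lets a checkpoint whose names differ from the encoder's own parameter names still be loaded. The sketch below imitates the documented behavior on plain dictionaries: each source name passes through the mapping, is matched against the layer's parameter names, and is counted as loaded, failed or skipped. The target names, the "encoder." prefix and the rule that a null mapping result skips a weight are hypothetical. The real method reports its outcome in a WeightLoadResult, and in strict mode it may signal failure there rather than by throwing as this sketch does.

```csharp
#nullable enable
using System;
using System.Collections.Generic;

// Hypothetical target names; the real ones would come from GetParameterNames().
var targetNames = new HashSet<string> { "conv_in.weight", "conv_in.bias" };

// Hypothetical mapping: strip an "encoder." prefix, return null to skip a weight.
Func<string, string?> mapping = name =>
    name.StartsWith("encoder.", StringComparison.Ordinal) ? name.Substring("encoder.".Length) : null;

var source = new Dictionary<string, double[]>
{
    ["encoder.conv_in.weight"] = new double[9],
    ["encoder.conv_in.bias"] = new double[1],
    ["encoder.unknown.weight"] = new double[1],
    ["decoder.conv_in.weight"] = new double[9],
};

// Sketch of the documented semantics, not the library's implementation.
static (List<string> Loaded, List<string> Failed, List<string> Skipped) Plan(
    Dictionary<string, double[]> weights, Func<string, string?> map,
    HashSet<string> targets, bool strict)
{
    var loaded = new List<string>();
    var failed = new List<string>();
    var skipped = new List<string>();
    foreach (string sourceName in weights.Keys)
    {
        string? target = map(sourceName);
        if (target is null) skipped.Add(sourceName);          // mapping chose to skip
        else if (targets.Contains(target)) loaded.Add(target); // name matched
        else failed.Add(sourceName);                           // mapped, but no such parameter
    }
    if (strict && failed.Count > 0)
        throw new InvalidOperationException("Unmatched weights: " + string.Join(", ", failed));
    return (loaded, failed, skipped);
}

var result = Plan(source, mapping, targetNames, strict: false);
Console.WriteLine($"loaded={result.Loaded.Count} failed={result.Failed.Count} skipped={result.Skipped.Count}");
// loaded=2 failed=1 skipped=1
```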

ResetState()

Resets the internal state of the encoder.

public override void ResetState()

Serialize(BinaryWriter)

Saves the encoder's state to a binary writer.

public override void Serialize(BinaryWriter writer)

Parameters

writer BinaryWriter
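Serialize and Deserialize must write and read the same values in the same order. This page does not document the binary layout, so the sketch below uses a hypothetical one (an Int32 count followed by that many doubles) to show the round trip with BinaryWriter, BinaryReader and a MemoryStream.

```csharp
using System;
using System.IO;
using System.Text;

// Hypothetical layout: an Int32 count followed by that many doubles.
static void WriteParameters(BinaryWriter output, double[] values)
{
    output.Write(values.Length);
    foreach (double v in values)
        output.Write(v);
}

static double[] ReadParameters(BinaryReader input)
{
    int count = input.ReadInt32();
    var values = new double[count];
    for (int i = 0; i < count; i++)
        values[i] = input.ReadDouble();   // read back in the order written
    return values;
}

double[] original = { 0.25, -1.5, 3.0 };
using var stream = new MemoryStream();
using (var writer = new BinaryWriter(stream, Encoding.UTF8, leaveOpen: true))
    WriteParameters(writer, original);

stream.Position = 0;                      // rewind before reading
using var reader = new BinaryReader(stream);
double[] restored = ReadParameters(reader);
Console.WriteLine(restored.Length);
```

Passing leaveOpen: true keeps the MemoryStream usable after the writer is disposed, so it can be rewound and read back.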

SetParameter(string, Tensor<T>)

Sets a parameter tensor by name.

public override bool SetParameter(string name, Tensor<T> value)

Parameters

name string

The parameter name ("weight" or "bias").

value Tensor<T>

The tensor value to set.

Returns

bool

True if the parameter was set successfully, false if the name was not found.

Exceptions

ArgumentException

Thrown when the tensor shape doesn't match the expected shape.
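SetParameter and TryGetParameter work by name, and SetParameter rejects a tensor whose shape differs from the expected one that GetParameterShape reports. The sketch below mirrors that contract with a hypothetical name-to-shape registry on plain arrays: an unknown name returns false and a shape mismatch throws ArgumentException. The parameter name is illustrative only.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical registry: parameter name -> expected shape and current values.
var registry = new Dictionary<string, (int[] Shape, double[] Values)>
{
    ["conv_in.bias"] = (new[] { 4 }, new double[4]),
};

// Mirrors the documented contract: false for an unknown name,
// ArgumentException when the shape differs from the expected one.
static bool SetParameter(Dictionary<string, (int[] Shape, double[] Values)> reg,
                         string name, int[] shape, double[] values)
{
    if (!reg.TryGetValue(name, out var entry))
        return false;
    if (!entry.Shape.SequenceEqual(shape))
        throw new ArgumentException(
            $"Shape [{string.Join(", ", shape)}] does not match expected " +
            $"[{string.Join(", ", entry.Shape)}] for '{name}'.", nameof(shape));
    reg[name] = (entry.Shape, values);
    return true;
}

// TryGetParameter-style lookup.
static bool TryGetParameter(Dictionary<string, (int[] Shape, double[] Values)> reg,
                            string name, out double[] values)
{
    bool found = reg.TryGetValue(name, out var entry);
    values = found ? entry.Values : Array.Empty<double>();
    return found;
}

bool updated = SetParameter(registry, "conv_in.bias", new[] { 4 }, new double[] { 1, 2, 3, 4 });
bool unknown = SetParameter(registry, "no.such.param", new[] { 4 }, new double[4]);
Console.WriteLine($"updated={updated} unknown={unknown}");   // updated=True unknown=False
```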

SetParameters(Vector<T>)

Sets all trainable parameters from a single vector.

public override void SetParameters(Vector<T> parameters)

Parameters

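parameters Vector<T>

A vector containing every trainable parameter, in the same order that GetParameters() returns them.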
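GetParameters and SetParameters exchange all trainable values as one flat vector, so both must agree on the order in which the individual tensors are laid out. The sketch below shows the flatten/unflatten round trip with two hypothetical tensors stored as double[]; the names and ordering are assumptions, not the encoder's actual layout.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical parameters, visited in the same fixed order on both paths.
var order = new List<string> { "conv_in.weight", "conv_in.bias" };
var tensors = new Dictionary<string, double[]>
{
    ["conv_in.weight"] = new double[] { 0.1, 0.2, 0.3, 0.4 },
    ["conv_in.bias"] = new double[] { 0.5 },
};

// GetParameters-style: concatenate every tensor into one vector.
static double[] Flatten(List<string> names, Dictionary<string, double[]> store) =>
    names.SelectMany(n => store[n]).ToArray();

// SetParameters-style: copy consecutive slices of the vector back into each tensor.
static void Unflatten(double[] vector, List<string> names, Dictionary<string, double[]> store)
{
    int expected = names.Sum(n => store[n].Length);
    if (vector.Length != expected)
        throw new ArgumentException($"Expected {expected} values but got {vector.Length}.");
    int offset = 0;
    foreach (string key in names)
    {
        double[] target = store[key];
        Array.Copy(vector, offset, target, 0, target.Length);
        offset += target.Length;
    }
}

double[] flat = Flatten(order, tensors);   // 0.1, 0.2, 0.3, 0.4, 0.5
flat[4] = 9.0;                             // edit the bias through the flat vector
Unflatten(flat, order, tensors);
Console.WriteLine(tensors["conv_in.bias"][0]);
```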

TryGetParameter(string, out Tensor<T>?)

Tries to get a parameter tensor by name.

public override bool TryGetParameter(string name, out Tensor<T>? tensor)

Parameters

name string

The parameter name ("weight" or "bias").

tensor Tensor<T>

The parameter tensor if found.

Returns

bool

True if the parameter was found, false otherwise.

UpdateParameters(T)

Updates all learnable parameters using gradient descent.

public override void UpdateParameters(T learningRate)

Parameters

learningRate T

Step size that scales each gradient before it is subtracted from its parameter.
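Plain gradient descent moves each parameter a small step against its gradient: p = p - learningRate * dL/dp. The sketch below applies that rule element-wise to double[] buffers. It assumes the gradients were already computed (for example by a prior Backward call) and leaves out momentum or other optimizer state, which this method's description does not mention.

```csharp
using System;

// Plain gradient descent: p <- p - learningRate * dL/dp, applied element-wise.
static void GradientDescentStep(double[] parameters, double[] gradients, double learningRate)
{
    if (parameters.Length != gradients.Length)
        throw new ArgumentException("Parameter and gradient lengths differ.");
    for (int i = 0; i < parameters.Length; i++)
        parameters[i] -= learningRate * gradients[i];
}

double[] weights = { 1.0, -0.5, 2.0 };
double[] grads = { 0.2, -0.4, 0.0 };      // e.g. accumulated during Backward()
GradientDescentStep(weights, grads, learningRate: 0.5);
Console.WriteLine(string.Join(", ", weights));
```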

ValidateWeights(IEnumerable<string>, Func<string, string?>?)

Validates that a set of weight names can be loaded into this layer.

public override WeightLoadValidation ValidateWeights(IEnumerable<string> weightNames, Func<string, string?>? mapping = null)

Parameters

weightNames IEnumerable<string>

Names of weights to validate.

mapping Func<string, string?>

Optional weight name mapping function.

Returns

WeightLoadValidation

Validation result with matched and unmatched names.