Class AudioVAE<T>

Namespace
AiDotNet.Diffusion.VAE
Assembly
AiDotNet.dll

Variational Autoencoder for audio mel-spectrogram encoding and decoding.

public class AudioVAE<T> : VAEModelBase<T>, IVAEModel<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>

Type Parameters

T

The numeric type used for calculations.

Inheritance
object
VAEModelBase<T>
AudioVAE<T>
Implements
IVAEModel<T>
IFullModel<T, Tensor<T>, Tensor<T>>
IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>
IModelSerializer
ICheckpointableModel
IParameterizable<T, Tensor<T>, Tensor<T>>
IFeatureAware
IFeatureImportance<T>
ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>
IGradientComputable<T, Tensor<T>, Tensor<T>>
IJitCompilable<T>

Examples

// Create an AudioVAE
var audioVAE = new AudioVAE<float>(
    melChannels: 64,
    latentChannels: 8,
    baseChannels: 64);

// Encode a mel spectrogram
var melSpec = LoadMelSpectrogram("audio.wav"); // Shape: [1, 64, 256]
var latent = audioVAE.Encode(melSpec);         // Shape: [1, 8, 64]

// Decode back to mel spectrogram
var reconstructed = audioVAE.Decode(latent);   // Shape: [1, 64, 256]

Remarks

The AudioVAE encodes mel spectrograms into a compressed latent representation and decodes latents back to mel spectrograms. This is a key component of audio latent diffusion models like AudioLDM.

For Beginners: Audio cannot be directly processed by diffusion models because raw audio waveforms are very long (e.g., 10 seconds at 16kHz = 160,000 samples). Instead, we use this pipeline:

Audio -> Mel Spectrogram -> VAE Encode -> Latent -> Diffusion -> VAE Decode -> Mel -> Vocoder -> Audio

The AudioVAE handles the "Mel -> Latent" and "Latent -> Mel" steps.
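The pipeline above can be sketched end to end using the conversion helpers documented on this page. `LoadWaveform` is a hypothetical loader, and the shape comments follow the class example; only the `AudioVAE` calls are taken from the documented API:

```csharp
// End-to-end sketch: waveform -> mel -> latent -> mel -> waveform.
// LoadWaveform is a hypothetical helper returning a [batch, samples] tensor.
var vae = new AudioVAE<float>(melChannels: 64, latentChannels: 8);

Tensor<float> waveform = LoadWaveform("speech.wav");    // [1, 160000]

// Mel -> Latent (the part a latent diffusion model operates between)
var mel = vae.AudioToMelSpectrogram(waveform);          // [1, 64, timeFrames]
var latent = vae.Encode(mel);                           // [1, 8, timeFrames / TimeDownsampleFactor]

// ... run diffusion in latent space here ...

// Latent -> Mel -> Audio
var melOut = vae.Decode(latent);                        // [1, 64, timeFrames]
var audioOut = vae.MelSpectrogramToAudio(melOut);       // [1, samples]
```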

What is a mel spectrogram?

  • A visual representation of sound
  • X-axis: time, Y-axis: frequency (mel scale), Color: intensity
  • Looks like an image, so we can use image-like networks!

Example dimensions:

  • Mel spectrogram: [1, 64, 256] = batch of 1, 64 mel bins, 256 time frames
  • Latent: [1, 8, 64] = batch of 1, 8 latent channels, 64 time frames (time axis compressed)

Architecture:

  • Encoder: 1D convolutions with downsampling along the time axis
  • Latent: compressed representation with 8 channels
  • Decoder: 1D transposed convolutions to reconstruct the spectrogram
  • Uses KL divergence for regularization

Constructors

AudioVAE()

Initializes a new AudioVAE with default parameters.

public AudioVAE()

AudioVAE(int, int, int, int[]?, int, ILossFunction<T>?, int?)

Initializes a new AudioVAE with custom parameters.

public AudioVAE(int melChannels = 64, int latentChannels = 8, int baseChannels = 64, int[]? channelMultipliers = null, int numResBlocks = 2, ILossFunction<T>? lossFunction = null, int? seed = null)

Parameters

melChannels int

Number of mel spectrogram channels.

latentChannels int

Number of latent channels.

baseChannels int

Base channel count for conv layers.

channelMultipliers int[]

Channel multipliers for each level.

numResBlocks int

Number of residual blocks per level.

lossFunction ILossFunction<T>

Optional custom loss function.

seed int?

Optional random seed.
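As an illustration of the custom constructor, the sketch below configures a deeper encoder via channelMultipliers. The multiplier values and the 80-bin mel size are assumptions for the example, not documented defaults:

```csharp
// Illustrative configuration; the multiplier values here are assumptions,
// not documented defaults. More levels generally mean more downsampling
// along the time axis.
var vae = new AudioVAE<float>(
    melChannels: 80,                          // e.g., an 80-bin mel spectrogram
    latentChannels: 8,
    baseChannels: 64,
    channelMultipliers: new[] { 1, 2, 4 },    // channel width per encoder level
    numResBlocks: 2,                          // residual blocks per level
    seed: 42);                                // reproducible initialization
```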

Properties

DownsampleFactor

Gets the spatial downsampling factor.

public override int DownsampleFactor { get; }

Property Value

int

Remarks

The factor by which the VAE reduces spatial dimensions. Stable Diffusion uses 8x downsampling, so a 512x512 image becomes 64x64 latents.

InputChannels

Gets the number of input channels (image channels).

public override int InputChannels { get; }

Property Value

int

Remarks

Typically 3 for RGB images. Could be 1 for grayscale or 4 for RGBA.

LatentChannels

Gets the number of latent channels.

public override int LatentChannels { get; }

Property Value

int

Remarks

Standard Stable Diffusion VAEs use 4 latent channels. Some newer VAEs may use different values (e.g., 16 for certain architectures).

LatentScaleFactor

Gets the scale factor for latent values.

public override double LatentScaleFactor { get; }

Property Value

double

Remarks

A normalization factor applied to latent values. For Stable Diffusion, this is 0.18215, which normalizes the latent distribution to unit variance.

MelChannels

Gets the number of mel channels.

public int MelChannels { get; }

Property Value

int

ParameterCount

Gets the number of parameters in the model.

public override int ParameterCount { get; }

Property Value

int

Remarks

This property returns the total count of trainable parameters in the model. It's useful for understanding model complexity and memory requirements.

TimeDownsampleFactor

Gets the time downsampling factor.

public int TimeDownsampleFactor { get; }

Property Value

int

Methods

AudioToMelSpectrogram(Tensor<T>, int, int, int)

Converts raw audio waveform to mel spectrogram.

public virtual Tensor<T> AudioToMelSpectrogram(Tensor<T> waveform, int sampleRate = 16000, int hopLength = 512, int fftSize = 2048)

Parameters

waveform Tensor<T>

Audio waveform tensor [batch, samples].

sampleRate int

Sample rate in Hz.

hopLength int

Hop length for STFT.

fftSize int

FFT window size.

Returns

Tensor<T>

Mel spectrogram tensor [batch, melChannels, timeFrames].

Remarks

For Beginners: This converts raw audio (like what comes out of a microphone) into a visual representation that captures both frequency and time:

Raw audio: [160000] samples (10 seconds at 16kHz)
  -> STFT (Short-Time Fourier Transform): frequency analysis in windows
  -> Mel filterbank: maps frequencies to the perceptual mel scale
  -> Log: makes quiet and loud sounds more comparable
  = Mel spectrogram: [64, 256] (64 frequency bins, 256 time frames)
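A minimal call sketch with the documented defaults written out explicitly (`LoadWaveform` is a hypothetical helper; the parameter comments describe the usual STFT trade-offs, not library-specific behavior):

```csharp
// Convert a raw waveform to a mel spectrogram using the documented defaults.
// LoadWaveform is a hypothetical helper returning a [batch, samples] tensor.
var vae = new AudioVAE<float>();
Tensor<float> waveform = LoadWaveform("speech.wav");  // [1, 160000]

var mel = vae.AudioToMelSpectrogram(
    waveform,
    sampleRate: 16000,   // Hz
    hopLength: 512,      // samples between STFT frames; smaller = more frames
    fftSize: 2048);      // STFT window size; larger = finer frequency resolution
// Result: [1, melChannels, timeFrames], where the frame count scales with
// waveform length divided by hopLength.
```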

Clone()

Creates a deep copy of the VAE model.

public override IVAEModel<T> Clone()

Returns

IVAEModel<T>

A new instance with the same parameters.

Decode(Tensor<T>)

Decodes a latent representation back to image space.

public override Tensor<T> Decode(Tensor<T> latent)

Parameters

latent Tensor<T>

The latent tensor [batch, latentChannels, latentHeight, latentWidth].

Returns

Tensor<T>

The decoded image [batch, channels, height * downFactor, width * downFactor].

Remarks

For Beginners: This decompresses the latent back to an image:

  • Input: small latent (64x64x4)
  • Output: full-size image (512x512x3)
  • The image looks like the original but with minor differences due to compression

DeepCopy()

Creates a deep copy of this object.

public override IFullModel<T, Tensor<T>, Tensor<T>> DeepCopy()

Returns

IFullModel<T, Tensor<T>, Tensor<T>>

Encode(Tensor<T>, bool)

Encodes an image into the latent space.

public override Tensor<T> Encode(Tensor<T> input, bool sampleMode = true)

Parameters

input Tensor<T>
sampleMode bool

If true, samples from the latent distribution. If false, returns the mean.

Returns

Tensor<T>

The latent representation [batch, latentChannels, height/downFactor, width/downFactor].

Remarks

The VAE encoder outputs a distribution (mean and log variance). When sampleMode is true, we sample from this distribution using the reparameterization trick. When false, we just return the mean for deterministic encoding.

For Beginners: This compresses the image:

  • Input: full-size image (512x512x3)
  • Output: small latent representation (64x64x4)
  • The latent contains all the important information in a compressed form
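The sampleMode flag switches between stochastic and deterministic encoding; a sketch (mel is assumed to be a [batch, melChannels, timeFrames] tensor, e.g. from AudioToMelSpectrogram):

```csharp
// Stochastic vs. deterministic encoding.
// mel: [batch, melChannels, timeFrames], assumed to exist already.
var vae = new AudioVAE<float>(seed: 0);   // seed for reproducible sampling

// sampleMode: true (the default) draws from the latent distribution via the
// reparameterization trick -- two calls can return different latents.
var sampled = vae.Encode(mel, sampleMode: true);

// sampleMode: false returns the distribution mean -- deterministic, useful
// when the same input should always map to the same latent (e.g., caching).
var deterministic = vae.Encode(mel, sampleMode: false);
```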

EncodeWithDistribution(Tensor<T>)

Encodes and returns both mean and log variance (for training).

public override (Tensor<T> Mean, Tensor<T> LogVariance) EncodeWithDistribution(Tensor<T> input)

Parameters

input Tensor<T>

Returns

(Tensor<T> Mean, Tensor<T> LogVariance)

Tuple of (mean, logVariance) tensors.

Remarks

Used during VAE training where we need both the mean and variance for computing the KL divergence loss.
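The KL term for a diagonal Gaussian against a standard normal prior has a closed form. The per-element sketch below is the standard VAE loss term that the (Mean, LogVariance) outputs feed into; it is textbook math, not necessarily the library's exact implementation:

```csharp
// Closed-form KL divergence of N(mean, exp(logVar)) from N(0, 1), per element.
// Summed (or averaged) over the latent tensor, this is the regularization
// term computed from EncodeWithDistribution's outputs during training.
static double KlPerElement(double mean, double logVar) =>
    -0.5 * (1.0 + logVar - mean * mean - Math.Exp(logVar));

// Sanity check: a perfectly standard-normal latent element contributes
// zero KL, i.e. KlPerElement(0.0, 0.0) == 0.0.
```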

GetParameters()

Gets the parameters that can be optimized.

public override Vector<T> GetParameters()

Returns

Vector<T>

MelSpectrogramToAudio(Tensor<T>, int, int)

Converts mel spectrogram back to audio waveform.

public virtual Tensor<T> MelSpectrogramToAudio(Tensor<T> melSpectrogram, int sampleRate = 16000, int hopLength = 512)

Parameters

melSpectrogram Tensor<T>

Mel spectrogram tensor [batch, melChannels, timeFrames].

sampleRate int

Sample rate in Hz.

hopLength int

Hop length used for spectrogram.

Returns

Tensor<T>

Audio waveform tensor [batch, samples].

Remarks

For Beginners: Converting from mel spectrogram back to audio is harder than going the other direction because:

  1. Mel spectrograms lose phase information
  2. The mel filterbank is not perfectly invertible

This method uses a GPU-accelerated Griffin-Lim algorithm for phase reconstruction after inverting the mel spectrogram to a linear magnitude spectrogram.

SetParameters(Vector<T>)

Sets the model parameters.

public override void SetParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

The parameter vector to set.

Remarks

This method allows direct modification of the model's internal parameters, which is useful for optimization algorithms that update parameters iteratively. If the length of parameters does not match ParameterCount, an ArgumentException is thrown.
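A get-modify-set round trip, as an optimizer would perform each step. This is a sketch: indexing and Length on Vector<T>, and the gradient vector, are assumptions for illustration:

```csharp
// Get-modify-set round trip, as an optimizer would do each step.
// Indexing and Length on Vector<T> are assumed here for illustration;
// gradient is a hypothetical vector of the same length.
var parameters = vae.GetParameters();          // length == vae.ParameterCount

for (int i = 0; i < parameters.Length; i++)    // a toy gradient-descent step
    parameters[i] -= 0.01f * gradient[i];

vae.SetParameters(parameters);                 // throws ArgumentException if
                                               // the length does not match
```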

Exceptions

ArgumentException

Thrown when the length of parameters does not match ParameterCount.