Class TemporalVAE<T>
Temporal-aware Variational Autoencoder for video diffusion models.
public class TemporalVAE<T> : VAEModelBase<T>, IVAEModel<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>
Type Parameters
T: The numeric type used for calculations.
- Inheritance: VAEModelBase&lt;T&gt; → TemporalVAE&lt;T&gt;
- Implements: IVAEModel&lt;T&gt;
Remarks
The TemporalVAE extends the standard VAE to handle video data by incorporating temporal awareness into the encoding and decoding process. This helps maintain temporal consistency across frames when used in video diffusion models.
For Beginners: While a standard VAE processes each frame independently, TemporalVAE considers relationships between consecutive frames:
Standard VAE approach (per-frame):
- Frame 1 -> Latent 1 (no knowledge of other frames)
- Frame 2 -> Latent 2 (no knowledge of other frames)
- Result: Possible flickering/inconsistency between frames
TemporalVAE approach:
- Frames 1,2,3,... -> Encode with temporal awareness
- Latent knows about neighboring frames
- Result: Smoother, more consistent video
Key features:
- 3D convolutions that span across time dimension
- Temporal attention for long-range frame relationships
- Optional causal mode for streaming/autoregressive generation
Used in: Stable Video Diffusion, Video LDM, and similar models.
Architecture details:
- Input: [batch, channels, frames, height, width] video tensor
- Encoder: 2D spatial blocks + 1D temporal blocks
- Latent: [batch, latentChannels, frames, height/8, width/8]
- Decoder: 2D spatial blocks + 1D temporal blocks
- Output: [batch, channels, frames, height, width] reconstructed video
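A minimal round-trip sketch (the shape-array Tensor&lt;float&gt; constructor is an assumption about the tensor API; the TemporalVAE members are as documented on this page):

```csharp
// A 2-video batch of 16 RGB frames at 256x256; the shape-array
// constructor is an assumed Tensor<T> API.
var video = new Tensor<float>(new[] { 2, 3, 16, 256, 256 });

var vae = new TemporalVAE<float>(inputChannels: 3, latentChannels: 4);

// Encode: spatial dimensions shrink by DownsampleFactor (8 with the
// default channel multipliers); the frame count is preserved.
Tensor<float> latent = vae.Encode(video);         // [2, 4, 16, 32, 32]

// Decode: back to the original video shape, up to compression artifacts.
Tensor<float> reconstructed = vae.Decode(latent); // [2, 3, 16, 256, 256]
```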
Constructors
TemporalVAE(int, int, int, int[]?, int, int, bool, double?, ILossFunction<T>?, int?)
Initializes a new instance of the TemporalVAE class.
public TemporalVAE(int inputChannels = 3, int latentChannels = 4, int baseChannels = 128, int[]? channelMultipliers = null, int numTemporalLayers = 1, int temporalKernelSize = 3, bool causalMode = false, double? latentScaleFactor = null, ILossFunction<T>? lossFunction = null, int? seed = null)
Parameters
inputChannels (int): Number of input image channels (default: 3 for RGB).
latentChannels (int): Number of latent channels (default: 4).
baseChannels (int): Base channel count (default: 128).
channelMultipliers (int[]?): Channel multipliers per level (default: [1, 2, 4, 4]).
numTemporalLayers (int): Number of temporal layers per spatial block (default: 1).
temporalKernelSize (int): Kernel size for temporal convolutions (default: 3).
causalMode (bool): Whether to use causal convolutions (default: false).
latentScaleFactor (double?): Scale factor for latents (default: 0.18215).
lossFunction (ILossFunction&lt;T&gt;?): Optional loss function (default: MSE).
seed (int?): Optional random seed for reproducibility.
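For example, a causal configuration suited to streaming or autoregressive generation (values other than causalMode and seed mirror the documented defaults) might look like this:

```csharp
var streamingVae = new TemporalVAE<float>(
    inputChannels: 3,
    latentChannels: 4,
    baseChannels: 128,
    channelMultipliers: new[] { 1, 2, 4, 4 },
    numTemporalLayers: 1,
    temporalKernelSize: 3,
    causalMode: true,  // temporal convolutions only look at past frames
    seed: 42);         // reproducible weight initialization

Console.WriteLine(streamingVae.IsCausal);           // True
Console.WriteLine(streamingVae.TemporalKernelSize); // 3
```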
Properties
DownsampleFactor
Gets the spatial downsampling factor.
public override int DownsampleFactor { get; }
Property Value
- int
Remarks
The factor by which the VAE reduces spatial dimensions. Stable Diffusion uses 8x downsampling, so a 512x512 image becomes 64x64 latents.
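The latent spatial size follows directly from this factor, for example:

```csharp
var vae = new TemporalVAE<float>(); // 8x with the default channel multipliers

int height = 512, width = 512;
int latentHeight = height / vae.DownsampleFactor; // 64
int latentWidth  = width  / vae.DownsampleFactor; // 64
```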
InputChannels
Gets the number of input channels (image channels).
public override int InputChannels { get; }
Property Value
- int
Remarks
Typically 3 for RGB images. Could be 1 for grayscale or 4 for RGBA.
IsCausal
Gets whether this VAE uses causal convolutions.
public bool IsCausal { get; }
Property Value
- bool
LatentChannels
Gets the number of latent channels.
public override int LatentChannels { get; }
Property Value
- int
Remarks
Standard Stable Diffusion VAEs use 4 latent channels. Some newer VAEs may use different values (e.g., 16 for certain architectures).
LatentScaleFactor
Gets the scale factor for latent values.
public override double LatentScaleFactor { get; }
Property Value
- double
Remarks
A normalization factor applied to latent values. For Stable Diffusion, this is 0.18215, which normalizes the latent distribution to unit variance.
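Conceptually, the scaled latent handed to the diffusion model is just the raw latent multiplied by this factor, which is what EncodeVideoForDiffusion does in one step. A sketch (the Multiply call and the shape-array constructor are assumptions about the Tensor&lt;T&gt; API):

```csharp
var vae = new TemporalVAE<float>(); // LatentScaleFactor defaults to 0.18215
var video = new Tensor<float>(new[] { 1, 3, 8, 256, 256 }); // assumed ctor

Tensor<float> rawLatent = vae.Encode(video);

// Hypothetical element-wise scaling; Multiply is an assumed Tensor<T> method.
Tensor<float> scaled = rawLatent.Multiply((float)vae.LatentScaleFactor);
```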
ParameterCount
Gets the number of parameters in the model.
public override int ParameterCount { get; }
Property Value
- int
Remarks
This property returns the total count of trainable parameters in the model. It's useful for understanding model complexity and memory requirements.
SupportsSlicing
Gets whether this VAE uses slicing for sequential processing.
public override bool SupportsSlicing { get; }
Property Value
- bool
Remarks
Slicing processes the batch one sample at a time to reduce memory. Trades speed for memory efficiency.
SupportsTiling
Gets whether this VAE uses tiling for memory-efficient encoding/decoding.
public override bool SupportsTiling { get; }
Property Value
- bool
Remarks
Tiling processes the image in overlapping patches to reduce memory usage when handling large images. Useful for high-resolution generation.
TemporalKernelSize
Gets the temporal kernel size.
public int TemporalKernelSize { get; }
Property Value
- int
Methods
Clone()
Creates a deep copy of the VAE model.
public override IVAEModel<T> Clone()
Returns
- IVAEModel<T>
A new instance with the same parameters.
Decode(Tensor<T>)
Decodes a latent representation back to video space.
public override Tensor<T> Decode(Tensor<T> latent)
Parameters
latent (Tensor&lt;T&gt;): The latent tensor [batch, latentChannels, frames, latentHeight, latentWidth].
Returns
- Tensor<T>
The decoded video [batch, channels, frames, height*downFactor, width*downFactor].
Remarks
For Beginners: This decompresses the latent back to an image:
- Input: small latent (64x64x4)
- Output: full-size image (512x512x3)
- The image looks like the original but with minor differences due to compression
DecodeVideoFromDiffusion(Tensor<T>)
Decodes a diffusion video latent back to video space.
public Tensor<T> DecodeVideoFromDiffusion(Tensor<T> latent)
Parameters
latent (Tensor&lt;T&gt;): The latent from diffusion (already scaled).
Returns
- Tensor<T>
The decoded video.
DeepCopy()
Creates a deep copy of this object.
public override IFullModel<T, Tensor<T>, Tensor<T>> DeepCopy()
Returns
- IFullModel&lt;T, Tensor&lt;T&gt;, Tensor&lt;T&gt;&gt;
A deep copy of this model, including its parameters.
Encode(Tensor<T>, bool)
Encodes a video into the latent space.
public override Tensor<T> Encode(Tensor<T> video, bool sampleMode = true)
Parameters
video (Tensor&lt;T&gt;): The input video tensor.
sampleMode (bool): If true, samples from the latent distribution. If false, returns the mean.
Returns
- Tensor<T>
The latent representation [batch, latentChannels, frames, height/downFactor, width/downFactor].
Remarks
The VAE encoder outputs a distribution (mean and log variance). When sampleMode is true, we sample from this distribution using the reparameterization trick. When false, we just return the mean for deterministic encoding.
For Beginners: This compresses the image:
- Input: full-size image (512x512x3)
- Output: small latent representation (64x64x4)
- The latent contains all the important information in a compressed form
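A sketch of the two modes (tensor construction assumed as above):

```csharp
var vae = new TemporalVAE<float>(seed: 123);
var video = new Tensor<float>(new[] { 1, 3, 8, 256, 256 }); // assumed ctor

// Stochastic: samples via the reparameterization trick, so repeated
// calls can return different latents.
Tensor<float> sampled = vae.Encode(video, sampleMode: true);

// Deterministic: returns the distribution mean, so repeated calls agree.
Tensor<float> mean1 = vae.Encode(video, sampleMode: false);
Tensor<float> mean2 = vae.Encode(video, sampleMode: false); // equals mean1
```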
EncodeVideoForDiffusion(Tensor<T>, bool)
Encodes a video and applies latent scaling for use in diffusion.
public Tensor<T> EncodeVideoForDiffusion(Tensor<T> video, bool sampleMode = true)
Parameters
video (Tensor&lt;T&gt;): The input video tensor.
sampleMode (bool): Whether to sample from the distribution.
Returns
- Tensor<T>
Scaled video latent representation.
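This method and DecodeVideoFromDiffusion are intended to be used as a pair around the diffusion process:

```csharp
var vae = new TemporalVAE<float>();
var video = new Tensor<float>(new[] { 1, 3, 16, 256, 256 }); // assumed ctor

// Encode + latent scaling in one call; ready for the diffusion model.
Tensor<float> latent = vae.EncodeVideoForDiffusion(video);

// ... run the diffusion model on `latent` here ...

// Undo the scaling and decode back to pixel space.
Tensor<float> result = vae.DecodeVideoFromDiffusion(latent);
```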
EncodeWithDistribution(Tensor<T>)
Encodes and returns both mean and log variance (for training).
public override (Tensor<T> Mean, Tensor<T> LogVariance) EncodeWithDistribution(Tensor<T> video)
Parameters
video (Tensor&lt;T&gt;): The input video tensor.
Returns
- (Tensor&lt;T&gt; Mean, Tensor&lt;T&gt; LogVariance)
A tuple containing the mean and log variance of the latent distribution.
Remarks
Used during VAE training where we need both the mean and variance for computing the KL divergence loss.
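The KL term has the standard closed form against a unit Gaussian, KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2). A sketch over flat arrays (how flat views are extracted from Tensor&lt;T&gt; is an assumption about the tensor API):

```csharp
// var (mean, logVar) = vae.EncodeWithDistribution(video);
// Closed-form KL divergence between N(mean, exp(logVar)) and N(0, I):
static float KlDivergence(float[] mean, float[] logVar)
{
    float kl = 0f;
    for (int i = 0; i < mean.Length; i++)
        kl += 1f + logVar[i] - mean[i] * mean[i] - MathF.Exp(logVar[i]);
    return -0.5f * kl;
}
```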
GetParameters()
Gets the parameters that can be optimized.
public override Vector<T> GetParameters()
Returns
- Vector&lt;T&gt;
A vector containing all trainable model parameters.
SetParameters(Vector<T>)
Sets the model parameters.
public override void SetParameters(Vector<T> parameters)
Parameters
parametersVector<T>The parameter vector to set.
Remarks
This method allows direct modification of the model's internal parameters.
This is useful for optimization algorithms that need to update parameters iteratively.
If the length of parameters does not match ParameterCount,
an ArgumentException should be thrown.
Exceptions
- ArgumentException
Thrown when the length of parameters does not match ParameterCount.
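A sketch of the get/modify/set round trip an optimizer performs (the Vector&lt;T&gt; indexer and Length member are assumptions about that type's API):

```csharp
var vae = new TemporalVAE<float>();

Vector<float> parameters = vae.GetParameters();

// Hypothetical in-place update; indexer and Length are assumed members.
for (int i = 0; i < parameters.Length; i++)
    parameters[i] *= 0.999f; // e.g. a simple weight-decay step

vae.SetParameters(parameters); // throws ArgumentException on length mismatch
```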