Class DiTNoisePredictor<T>
- Namespace
- AiDotNet.Diffusion.NoisePredictors
- Assembly
- AiDotNet.dll
Diffusion Transformer (DiT) noise predictor for diffusion models.
public class DiTNoisePredictor<T> : NoisePredictorBase<T>, INoisePredictor<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>
Type Parameters
T: The numeric type used for calculations.
- Inheritance
- DiTNoisePredictor<T>
Examples
// Create a DiT predictor for latent diffusion
var dit = new DiTNoisePredictor<float>(
    inputChannels: 4,   // Latent channels
    hiddenSize: 1152,   // DiT-XL/2 hidden size
    numLayers: 28,      // DiT-XL depth
    numHeads: 16,
    patchSize: 2);
// Predict noise. noisyLatent and textEmbedding are Tensor<float> values and
// timestep is an int; all three are assumed to be prepared by the caller.
var noisePrediction = dit.PredictNoise(noisyLatent, timestep, textEmbedding);
Remarks
DiT (Diffusion Transformer) replaces the traditional U-Net architecture with a pure transformer design. This approach leverages the scalability and effectiveness of transformers, enabling better performance at larger scales.
For Beginners: DiT is the "new generation" of noise prediction:
Traditional U-Net approach:
- Uses convolutional neural networks
- Has encoder-decoder structure with skip connections
- Good, but limited scalability
DiT approach (this class):
- Uses transformer architecture (like GPT, but for images)
- Treats image as patches (like words in a sentence)
- Scales better with more compute and data
- Powers recent models such as Sora and Stable Diffusion 3
Key advantages:
- Better quality at large scales
- Simpler architecture (no U-Net-style encoder-decoder skip connections needed)
- More flexible conditioning mechanisms
- Easier to scale training
Architecture details:
- Patchify: Split image into 2x2 or larger patches
- Position embedding: Add spatial information
- Transformer blocks: Self-attention + MLP
- AdaLN: Adaptive layer normalization for timestep/conditioning
- Unpatchify: Reconstruct full resolution output
Used in: DiT (original), Sora, SD3, PixArt-α
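The patchify and unpatchify steps listed in the architecture details can be illustrated with a minimal, self-contained sketch. It assumes a single image stored as a flat row-major [channels, height, width] float array and non-overlapping square patches. PatchifySketch and its methods are hypothetical names for illustration, not part of AiDotNet, and the library's internal layout may differ.

```csharp
using System;

// Illustrative sketch only: hypothetical helper, not part of AiDotNet.
public static class PatchifySketch
{
    // Splits a flat [channels, height, width] image into non-overlapping
    // p x p patches. Returns one token per patch, each of length channels * p * p.
    public static float[][] Patchify(float[] image, int channels, int height, int width, int p)
    {
        if (image.Length != channels * height * width)
            throw new ArgumentException("Image length does not match the given shape.");
        if (height % p != 0 || width % p != 0)
            throw new ArgumentException("Height and width must be divisible by the patch size.");

        int rows = height / p;
        int cols = width / p;
        var tokens = new float[rows * cols][];
        for (int pr = 0; pr < rows; pr++)
        {
            for (int pc = 0; pc < cols; pc++)
            {
                var token = new float[channels * p * p];
                int k = 0;
                for (int c = 0; c < channels; c++)
                {
                    for (int dy = 0; dy < p; dy++)
                    {
                        for (int dx = 0; dx < p; dx++)
                        {
                            int y = pr * p + dy;
                            int x = pc * p + dx;
                            token[k++] = image[(c * height + y) * width + x];
                        }
                    }
                }
                tokens[pr * cols + pc] = token;
            }
        }
        return tokens;
    }

    // Inverse of Patchify: writes every token back to its spatial position.
    public static float[] Unpatchify(float[][] tokens, int channels, int height, int width, int p)
    {
        int cols = width / p;
        if (tokens.Length != (height / p) * cols)
            throw new ArgumentException("Token count does not match the given shape.");

        var image = new float[channels * height * width];
        for (int t = 0; t < tokens.Length; t++)
        {
            int pr = t / cols;
            int pc = t % cols;
            int k = 0;
            for (int c = 0; c < channels; c++)
            {
                for (int dy = 0; dy < p; dy++)
                {
                    for (int dx = 0; dx < p; dx++)
                    {
                        int y = pr * p + dy;
                        int x = pc * p + dx;
                        image[(c * height + y) * width + x] = tokens[t][k++];
                    }
                }
            }
        }
        return image;
    }
}
```

Under these assumptions, a 4-channel 32x32 latent with a patch size of 2 becomes 256 tokens of 16 values each. In a DiT-style model each token would then typically be projected linearly to the hidden size before the transformer blocks.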
Constructors
DiTNoisePredictor()
Initializes a new DiT noise predictor with default DiT-XL/2 parameters.
public DiTNoisePredictor()
DiTNoisePredictor(int, int, int, int, int, int, double, int, int?)
Initializes a new DiT noise predictor with custom parameters.
public DiTNoisePredictor(int inputChannels = 4, int hiddenSize = 1152, int numLayers = 28, int numHeads = 16, int patchSize = 2, int contextDim = 1024, double mlpRatio = 4, int numClasses = 0, int? seed = null)
Parameters
inputChannels (int): Number of input channels.
hiddenSize (int): Hidden dimension size.
numLayers (int): Number of transformer layers.
numHeads (int): Number of attention heads.
patchSize (int): Patch size for tokenization.
contextDim (int): Conditioning context dimension.
mlpRatio (double): MLP hidden dimension ratio.
numClasses (int): Number of classes for class conditioning (0 for text-only).
seed (int?): Random seed for initialization.
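To make the relationship between these parameters concrete, the sketch below derives two quantities that follow from them in a conventional transformer: the per-head dimension and the MLP hidden width. It assumes the usual conventions (hiddenSize divided evenly across heads, MLP width equal to hiddenSize times mlpRatio). Whether this class computes or validates them in exactly this way is an assumption, and DiTShapeSketch is a hypothetical helper, not part of AiDotNet.

```csharp
using System;

// Illustrative sketch only: hypothetical helper, not part of AiDotNet.
public static class DiTShapeSketch
{
    // Per-head dimension when the hidden size is split evenly across heads.
    public static int HeadDim(int hiddenSize, int numHeads)
    {
        if (numHeads <= 0 || hiddenSize % numHeads != 0)
            throw new ArgumentException("hiddenSize must be divisible by numHeads.");
        return hiddenSize / numHeads;
    }

    // Width of the MLP hidden layer inside each transformer block.
    public static int MlpHiddenDim(int hiddenSize, double mlpRatio)
    {
        return (int)(hiddenSize * mlpRatio);
    }
}
```

With the default values from the signature above (hiddenSize 1152, numHeads 16, mlpRatio 4), this gives a head dimension of 72 and an MLP hidden width of 4608.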
Properties
BaseChannels
Gets the base channel count used in the network architecture.
public override int BaseChannels { get; }
Property Value
- int
Remarks
This determines the model capacity. Common values:
- 320 for Stable Diffusion 1.x, 2.x, and XL (base)
- 384 for the Stable Diffusion XL refiner
- 1024 for large DiT models
ContextDimension
Gets the expected context dimension for cross-attention conditioning.
public override int ContextDimension { get; }
Property Value
- int
Remarks
For CLIP-conditioned models, this is typically 768 or 1024. For T5-XXL-conditioned models (like SD3), this is typically 4096. Returns 0 if cross-attention is not supported.
HiddenSize
Gets the hidden size.
public int HiddenSize { get; }
Property Value
- int
InputChannels
Gets the number of input channels the predictor expects.
public override int InputChannels { get; }
Property Value
- int
Remarks
For image models, this is typically:
- 4 for latent diffusion models (VAE latent channels)
- 3 for pixel-space RGB models
- Higher for models with additional conditioning channels
NumLayers
Gets the number of layers.
public int NumLayers { get; }
Property Value
- int
OutputChannels
Gets the number of output channels the predictor produces.
public override int OutputChannels { get; }
Property Value
- int
Remarks
Usually matches InputChannels since we predict noise of the same shape as input. Some architectures may predict additional outputs like variance.
ParameterCount
Gets the number of parameters in the model.
public override int ParameterCount { get; }
Property Value
- int
Remarks
This property returns the total count of trainable parameters in the model. It's useful for understanding model complexity and memory requirements.
PatchSize
Gets the patch size.
public int PatchSize { get; }
Property Value
- int
SupportsCFG
Gets whether this noise predictor supports classifier-free guidance.
public override bool SupportsCFG { get; }
Property Value
- bool
Remarks
Classifier-free guidance allows steering generation toward the conditioning (e.g., text prompt) without a separate classifier. Most modern models support this.
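As a rough illustration of how classifier-free guidance combines two noise predictions, the sketch below applies the standard formula eps = eps_uncond + scale * (eps_cond - eps_uncond) to flat float arrays. CfgSketch is a hypothetical helper, not part of AiDotNet. In practice the two inputs would come from two forward passes, one with conditioning and one without; how this library represents the unconditional pass is not specified here.

```csharp
using System;

// Illustrative sketch only: hypothetical helper, not part of AiDotNet.
public static class CfgSketch
{
    // Standard classifier-free guidance combination:
    // result = uncond + guidanceScale * (cond - uncond)
    public static float[] Combine(float[] uncond, float[] cond, float guidanceScale)
    {
        if (uncond.Length != cond.Length)
            throw new ArgumentException("Both predictions must have the same length.");

        var result = new float[cond.Length];
        for (int i = 0; i < result.Length; i++)
        {
            result[i] = uncond[i] + guidanceScale * (cond[i] - uncond[i]);
        }
        return result;
    }
}
```

A guidance scale of 1 reproduces the conditional prediction, 0 reproduces the unconditional one, and values above 1 push the result further toward the conditioning.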
SupportsCrossAttention
Gets whether this noise predictor supports cross-attention conditioning.
public override bool SupportsCrossAttention { get; }
Property Value
- bool
Remarks
Cross-attention allows the model to attend to conditioning tokens (like text embeddings). This is how text-to-image models incorporate the prompt.
TimeEmbeddingDim
Gets the dimension of the time/timestep embedding.
public override int TimeEmbeddingDim { get; }
Property Value
- int
Remarks
The timestep is embedded into a high-dimensional vector before being injected into the network. Typical values: 256, 512, 1024.
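One common way to embed a timestep is the sinusoidal scheme used by many diffusion models, sketched below. This is an assumption for illustration: the document does not state which embedding this class uses, and TimestepEmbeddingSketch, its method, and the maxPeriod default are hypothetical names and values, not part of AiDotNet.

```csharp
using System;

// Illustrative sketch only: hypothetical helper, not part of AiDotNet.
public static class TimestepEmbeddingSketch
{
    // Sinusoidal embedding: the first half holds cosines and the second half
    // holds sines, evaluated at geometrically spaced frequencies.
    public static float[] Embed(int timestep, int dim, double maxPeriod = 10000.0)
    {
        if (dim <= 0 || dim % 2 != 0)
            throw new ArgumentException("dim must be a positive even number.");

        int half = dim / 2;
        var embedding = new float[dim];
        for (int i = 0; i < half; i++)
        {
            // Frequencies decay from 1 toward 1 / maxPeriod as i grows.
            double freq = Math.Exp(-Math.Log(maxPeriod) * i / half);
            double angle = timestep * freq;
            embedding[i] = (float)Math.Cos(angle);
            embedding[half + i] = (float)Math.Sin(angle);
        }
        return embedding;
    }
}
```

In DiT-style models a vector like this is usually passed through a small MLP before it conditions the transformer blocks, so the raw sinusoidal dimension and the final embedding dimension need not be the same.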
Methods
Clone()
Creates a deep copy of the noise predictor.
public override INoisePredictor<T> Clone()
Returns
- INoisePredictor<T>
A new instance with the same parameters.
DeepCopy()
Creates a deep copy of this object.
public override IFullModel<T, Tensor<T>, Tensor<T>> DeepCopy()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
GetParameters()
Gets the parameters that can be optimized.
public override Vector<T> GetParameters()
Returns
- Vector<T>
PredictNoise(Tensor<T>, int, Tensor<T>?)
Predicts the noise in a noisy sample at a given timestep.
public override Tensor<T> PredictNoise(Tensor<T> noisySample, int timestep, Tensor<T>? conditioning = null)
Parameters
noisySample (Tensor<T>): The noisy input sample [batch, channels, height, width].
timestep (int): The current timestep in the diffusion process.
conditioning (Tensor<T>): Optional conditioning tensor (e.g., text embeddings).
Returns
- Tensor<T>
The predicted noise tensor with the same shape as noisySample.
Remarks
This is the main forward pass of the noise predictor. Given a noisy sample at timestep t, it predicts what noise was added.
For Beginners: This is where the actual denoising happens:
1. The network looks at the noisy image
2. It considers how noisy it should be at this timestep
3. It predicts the noise pattern
4. This prediction is subtracted to get a cleaner image
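The final step, subtracting the predicted noise, can be sketched under the standard DDPM assumption that a noisy sample is x_t = sqrt(alphaBar) * x_0 + sqrt(1 - alphaBar) * eps. Solving for x_0 gives the estimate below. The alphaBar value would come from a noise scheduler, which is outside this class; DenoiseSketch and its method are hypothetical names, not part of AiDotNet.

```csharp
using System;

// Illustrative sketch only: hypothetical helper, not part of AiDotNet.
public static class DenoiseSketch
{
    // Estimates the clean sample from a noisy sample and the predicted noise:
    // x0 = (x_t - sqrt(1 - alphaBar) * eps) / sqrt(alphaBar)
    public static float[] EstimateCleanSample(float[] noisy, float[] predictedNoise, double alphaBar)
    {
        if (noisy.Length != predictedNoise.Length)
            throw new ArgumentException("Sample and noise must have the same length.");
        if (alphaBar <= 0.0 || alphaBar > 1.0)
            throw new ArgumentOutOfRangeException(nameof(alphaBar), "alphaBar must be in (0, 1].");

        double signalScale = Math.Sqrt(alphaBar);
        double noiseScale = Math.Sqrt(1.0 - alphaBar);
        var clean = new float[noisy.Length];
        for (int i = 0; i < clean.Length; i++)
        {
            clean[i] = (float)((noisy[i] - noiseScale * predictedNoise[i]) / signalScale);
        }
        return clean;
    }
}
```

Real samplers usually do not jump straight to this estimate. They use it, or the predicted noise directly, to take a smaller step toward a less noisy sample and repeat over many timesteps.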
PredictNoiseWithEmbedding(Tensor<T>, Tensor<T>, Tensor<T>?)
Predicts noise from an explicit timestep embedding (for batches whose samples have different timesteps).
public override Tensor<T> PredictNoiseWithEmbedding(Tensor<T> noisySample, Tensor<T> timeEmbedding, Tensor<T>? conditioning = null)
Parameters
noisySample (Tensor<T>): The noisy input sample [batch, channels, height, width].
timeEmbedding (Tensor<T>): Pre-computed timestep embeddings [batch, timeEmbeddingDim].
conditioning (Tensor<T>): Optional conditioning tensor (e.g., text embeddings).
Returns
- Tensor<T>
The predicted noise tensor with the same shape as noisySample.
Remarks
This overload is useful when you want to use different timesteps per sample in a batch, or when you have pre-computed timestep embeddings for efficiency.
SetParameters(Vector<T>)
Sets the model parameters.
public override void SetParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): The parameter vector to set.
Remarks
This method allows direct modification of the model's internal parameters. This is useful for optimization algorithms that need to update parameters iteratively. If the length of parameters does not match ParameterCount, an ArgumentException is thrown.
Exceptions
- ArgumentException
Thrown when the length of parameters does not match ParameterCount.