Class NoiseSchedulerBase<T>
- Namespace: AiDotNet.NeuralNetworks.Diffusion.Schedulers
- Assembly: AiDotNet.dll
Base class for diffusion model noise schedulers providing common functionality.
public abstract class NoiseSchedulerBase<T> : INoiseScheduler<T>
Type Parameters
T: The numeric type used for calculations.
- Inheritance: object → NoiseSchedulerBase<T>
- Implements: INoiseScheduler<T>
Remarks
This abstract base class implements the common behavior for all noise schedulers, including beta schedule computation, alpha cumulative product calculation, noise addition, and state management for checkpointing.
Note: This class was renamed from StepSchedulerBase to NoiseSchedulerBase to avoid confusion with learning rate schedulers. Noise schedulers are specific to diffusion models.
For Beginners: This is the foundation that all noise schedulers build upon. It handles the common math and state management that every scheduler needs:
- Computing the noise schedule (how much noise is added at each step)
- Tracking the current state for saving and loading
- Adding noise during training
Specific schedulers like DDIM, PNDM, and DPM-Solver extend this base to implement their unique denoising strategies.
Constructors
NoiseSchedulerBase(SchedulerConfig<T>)
Initializes a new instance of the NoiseSchedulerBase class.
protected NoiseSchedulerBase(SchedulerConfig<T> config)
Parameters
config (SchedulerConfig<T>): Configuration for the scheduler, including the beta schedule parameters.
Exceptions
- ArgumentNullException
Thrown when config is null.
- NotSupportedException
Thrown when an unsupported beta schedule is specified.
Fields
Alphas
Alpha values (1 - beta) representing signal retention at each timestep.
protected Vector<T> Alphas
Field Value
- Vector<T>
AlphasCumulativeProduct
Cumulative product of alphas representing total signal retention at each timestep.
protected Vector<T> AlphasCumulativeProduct
Field Value
- Vector<T>
Remarks
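This is the key value for diffusion: alpha_cumprod[t] = alpha[0] * alpha[1] * ... * alpha[t] is the fraction of the original signal's variance that remains at timestep t (the signal itself is scaled by sqrt(alpha_cumprod[t])). It is close to 1 at t = 0 and close to 0 at the final training timestep.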
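The relationship between Betas, Alphas, and AlphasCumulativeProduct can be illustrated with a small sketch. It uses plain double arrays instead of the library's Vector<T> and NumOps, the three beta values are made up so the products are easy to check by hand, and the helper names ComputeAlphas and CumulativeProduct are hypothetical.

```csharp
using System;

// alphas[t] = 1 - betas[t]: the share of signal variance kept at step t.
static double[] ComputeAlphas(double[] betaValues)
{
    var result = new double[betaValues.Length];
    for (int i = 0; i < betaValues.Length; i++)
        result[i] = 1.0 - betaValues[i];
    return result;
}

// alphaCumprod[t] = alphas[0] * alphas[1] * ... * alphas[t].
static double[] CumulativeProduct(double[] alphaValues)
{
    var result = new double[alphaValues.Length];
    double running = 1.0;
    for (int i = 0; i < alphaValues.Length; i++)
    {
        running *= alphaValues[i];
        result[i] = running;
    }
    return result;
}

// Toy three-step schedule with exaggerated betas (real schedules use much smaller values).
double[] betas = { 0.1, 0.2, 0.3 };
double[] alphas = ComputeAlphas(betas);            // about 0.9, 0.8, 0.7
double[] alphaCumprod = CumulativeProduct(alphas); // about 0.9, 0.72, 0.504

Console.WriteLine(string.Join(", ", alphaCumprod));
```

Because every alpha is below 1, the cumulative product can only shrink as t grows, which is why it moves from near 1 toward 0 over a long schedule.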
Betas
Beta values (noise variance) at each training timestep.
protected Vector<T> Betas
Field Value
- Vector<T>
Remarks
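Beta is the variance of the Gaussian noise added at each step: a higher beta adds more noise. These values typically increase from a small start value to a larger end value; for example, the original DDPM setup uses a linear schedule from 0.0001 to 0.02 over 1000 steps.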
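A sketch of two ways an increasing beta schedule is commonly built. The linear schedule follows the description of Betas; the "scaled linear" variant (evenly spaced in sqrt(beta), then squared) is included only as an example of the kind of alternative a beta-schedule setting might select. It is an assumption, not a statement of which schedules SchedulerConfig<T> supports. The function names and the second pair of start/end values are illustrative.

```csharp
using System;

// Linear schedule: betas evenly spaced from betaStart to betaEnd (assumes numSteps >= 2).
static double[] LinearBetas(int numSteps, double betaStart, double betaEnd)
{
    var result = new double[numSteps];
    for (int i = 0; i < numSteps; i++)
        result[i] = betaStart + (betaEnd - betaStart) * i / (numSteps - 1);
    return result;
}

// "Scaled linear" schedule: evenly spaced in sqrt(beta), then squared (assumes numSteps >= 2).
static double[] ScaledLinearBetas(int numSteps, double betaStart, double betaEnd)
{
    double sqrtStart = Math.Sqrt(betaStart);
    double sqrtEnd = Math.Sqrt(betaEnd);
    var result = new double[numSteps];
    for (int i = 0; i < numSteps; i++)
    {
        double root = sqrtStart + (sqrtEnd - sqrtStart) * i / (numSteps - 1);
        result[i] = root * root;
    }
    return result;
}

double[] linear = LinearBetas(1000, 0.0001, 0.02);
double[] scaledLinear = ScaledLinearBetas(1000, 0.00085, 0.012);

// Both schedules increase monotonically from the start value to the end value.
Console.WriteLine($"linear:        {linear[0]} -> {linear[^1]}");
Console.WriteLine($"scaled linear: {scaledLinear[0]} -> {scaledLinear[^1]}");
```

Both schedules share the same endpoints; squaring after linear spacing in sqrt(beta) keeps the betas smaller for longer before they rise toward the end value.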
NumOps
Provides numeric operations for the specific type T.
protected static readonly INumericOperations<T> NumOps
Field Value
- INumericOperations<T>
Remarks
For Beginners: This is a helper that knows how to do math with your specific number type, whether that's float, double, or decimal.
Properties
Config
Gets the configuration options for the scheduler.
public SchedulerConfig<T> Config { get; }
Property Value
- SchedulerConfig<T>
Remarks
The configuration contains the parameters used to create and initialize the scheduler, such as the beta schedule, prediction type, and training timesteps.
Engine
Gets the compute engine for GPU-accelerated vectorized operations.
protected IEngine Engine { get; }
Property Value
- IEngine
Timesteps
Gets the timesteps for the current inference schedule.
public int[] Timesteps { get; }
Property Value
- int[]
Remarks
These are the discrete time indices at which denoising steps will be performed. The array is typically in descending order (from highest noise to lowest).
For Beginners: This is like a list of checkpoints. For example, with 50 inference steps and 1000 training timesteps, this array tells you exactly which 50 of the 1000 training timesteps to use for denoising.
TrainTimesteps
Gets the number of training timesteps this scheduler was configured with.
public int TrainTimesteps { get; }
Property Value
- int
Remarks
This is the total number of timesteps used during training, typically 1000. For inference, the scheduler selects a smaller subset of these timesteps (see SetTimesteps).
Methods
AddNoise(Vector<T>, Vector<T>, int)
Adds noise to a clean sample according to the noise schedule.
public virtual Vector<T> AddNoise(Vector<T> originalSample, Vector<T> noise, int timestep)
Parameters
originalSample (Vector<T>): The clean sample to add noise to.
noise (Vector<T>): The noise to add.
timestep (int): The timestep determining how much noise to add.
Returns
- Vector<T>
The noisy sample.
Remarks
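This implements the forward diffusion process in closed form: x_t = sqrt(alpha_cumprod[t]) * x_0 + sqrt(1 - alpha_cumprod[t]) * noise. When the noise is standard Gaussian, x_t is a sample from q(x_t | x_0) = N(sqrt(alpha_cumprod[t]) * x_0, (1 - alpha_cumprod[t]) * I).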
For Beginners: This is like adding a specific amount of static to a clear image. Higher timesteps add more noise. This is used during training to create noisy samples for the model to learn from.
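The forward process can be sketched element by element. Unlike AddNoise, this hypothetical helper takes the cumulative alpha value directly instead of looking it up from a timestep, and it uses double arrays instead of Vector<T>. The Box-Muller transform is used here only to produce Gaussian noise for the demo.

```csharp
using System;

// x_t = sqrt(alphaCumprodT) * x_0 + sqrt(1 - alphaCumprodT) * noise, applied element by element.
static double[] AddNoise(double[] originalSample, double[] noise, double alphaCumprodT)
{
    if (originalSample.Length != noise.Length)
        throw new ArgumentException("originalSample and noise must have the same length.");

    double signalScale = Math.Sqrt(alphaCumprodT);
    double noiseScale = Math.Sqrt(1.0 - alphaCumprodT);
    var noisy = new double[originalSample.Length];
    for (int i = 0; i < noisy.Length; i++)
        noisy[i] = signalScale * originalSample[i] + noiseScale * noise[i];
    return noisy;
}

// Draw standard Gaussian noise with the Box-Muller transform (fixed seed for repeatability).
var rng = new Random(42);
double[] clean = { 0.5, -0.25, 1.0, 0.0 };
double[] gaussian = new double[clean.Length];
for (int k = 0; k < gaussian.Length; k++)
{
    double u1 = 1.0 - rng.NextDouble(); // in (0, 1], so Log(u1) is finite
    double u2 = rng.NextDouble();
    gaussian[k] = Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Cos(2.0 * Math.PI * u2);
}

double[] earlyStep = AddNoise(clean, gaussian, 0.99); // alpha_cumprod near 1: mostly signal
double[] lateStep = AddNoise(clean, gaussian, 0.01);  // alpha_cumprod near 0: mostly noise
Console.WriteLine(string.Join(", ", earlyStep));
Console.WriteLine(string.Join(", ", lateStep));
```

During training, the same noise vector is what the model is asked to predict, which is why the clean sample and the noise are passed in separately.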
ClipSampleIfNeeded(Vector<T>)
Clips sample values to [-1, 1] if configured.
protected Vector<T> ClipSampleIfNeeded(Vector<T> sample)
Parameters
sample (Vector<T>): The sample to potentially clip.
Returns
- Vector<T>
The clipped sample if ClipSample is true, otherwise the original sample.
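A sketch of the clipping behavior, assuming element-wise clamping to [-1, 1] and using a boolean parameter as a stand-in for the ClipSample setting. Whether the library returns the same instance or a copy when clipping is disabled is not stated here; this sketch returns the input unchanged.

```csharp
using System;

// Clamp every element to [-1, 1] when clipping is enabled; otherwise return the input as-is.
static double[] ClipSampleIfNeeded(double[] sample, bool clipSample)
{
    if (!clipSample)
        return sample;

    var clipped = new double[sample.Length];
    for (int i = 0; i < sample.Length; i++)
        clipped[i] = Math.Clamp(sample[i], -1.0, 1.0);
    return clipped;
}

double[] raw = { -1.7, -0.4, 0.0, 0.9, 2.3 };
Console.WriteLine(string.Join(", ", ClipSampleIfNeeded(raw, clipSample: true)));  // out-of-range values become -1 and 1
Console.WriteLine(string.Join(", ", ClipSampleIfNeeded(raw, clipSample: false))); // values unchanged
```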
GetAlphaCumulativeProduct(int)
Gets the cumulative product of alphas (signal retention) at a given timestep.
public virtual T GetAlphaCumulativeProduct(int timestep)
Parameters
timestep (int): The timestep to query.
Returns
- T
The cumulative alpha value at that timestep.
Remarks
Alpha cumulative product represents how much of the original signal is retained at each timestep. At t = 0, it's close to 1 (mostly signal); at the final training timestep, it's close to 0 (mostly noise).
For Beginners: This tells you "how clear" the image is at each step. At the start (t = 0), the image is nearly clear (value near 1). At the end (the last training timestep), it's nearly pure noise (value near 0).
GetState()
Gets the current scheduler state for checkpointing.
public virtual Dictionary<string, object> GetState()
Returns
- Dictionary<string, object>
A dictionary containing the scheduler's state.
LoadState(Dictionary<string, object>)
Loads scheduler state from a checkpoint.
public virtual void LoadState(Dictionary<string, object> state)
Parameters
state (Dictionary<string, object>): The state dictionary to load from.
SetTimesteps(int)
Sets up the inference timesteps based on the number of steps desired.
public virtual void SetTimesteps(int inferenceSteps)
Parameters
inferenceSteps (int): Number of denoising steps to use during inference.
Remarks
This method calculates which timesteps from the training schedule should be used for the given number of inference steps. Using fewer steps is faster but may reduce quality.
For Beginners: This is like choosing how many steps to take when walking from point A to point B. More steps (50-100) give smoother results; fewer steps (10-20) are faster but may miss details.
Exceptions
- ArgumentOutOfRangeException
Thrown when inferenceSteps is less than 1 or greater than TrainTimesteps.
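One common way to choose inference timesteps is evenly strided ("leading") spacing, sketched below. This is an assumption for illustration only: the spacing SetTimesteps actually uses may differ, and derived schedulers may override it. The helper name ComputeInferenceTimesteps is hypothetical; its range check mirrors the documented ArgumentOutOfRangeException.

```csharp
using System;

// Evenly strided ("leading") spacing, returned in descending order (noisiest timestep first).
static int[] ComputeInferenceTimesteps(int trainTimesteps, int inferenceSteps)
{
    if (inferenceSteps < 1 || inferenceSteps > trainTimesteps)
        throw new ArgumentOutOfRangeException(nameof(inferenceSteps));

    int stride = trainTimesteps / inferenceSteps; // integer division
    var timesteps = new int[inferenceSteps];
    for (int i = 0; i < inferenceSteps; i++)
        timesteps[i] = (inferenceSteps - 1 - i) * stride;
    return timesteps;
}

int[] schedule = ComputeInferenceTimesteps(1000, 50);
Console.WriteLine($"{schedule[0]}, {schedule[1]}, {schedule[2]}, ..., {schedule[^1]}"); // 980, 960, 940, ..., 0
```

With this spacing, 50 steps over 1000 training timesteps visit every 20th timestep; other strategies (for example, spacing that ends exactly at the last training timestep) trade off which end of the schedule gets covered.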
Step(Vector<T>, int, Vector<T>, T, Vector<T>?)
Performs one denoising step using the model output.
public abstract Vector<T> Step(Vector<T> modelOutput, int timestep, Vector<T> sample, T eta, Vector<T>? noise = null)
Parameters
modelOutput (Vector<T>): The model's prediction (typically a noise prediction).
timestep (int): The current timestep in the diffusion process.
sample (Vector<T>): The current noisy sample.
eta (T): Stochasticity parameter: 0 = deterministic (DDIM), 1 = fully stochastic (DDPM). Values between 0 and 1 interpolate between these behaviors.
noise (Vector<T>?): Optional noise for stochastic sampling. If null and eta > 0, zero noise is used (deterministic fallback).
Returns
- Vector<T>
The denoised sample for the previous timestep.
Remarks
This is the core denoising operation. Given the current noisy sample and the model's prediction of what noise was added, it computes a slightly less noisy version.
For Beginners: This is one step of "un-blurring" the image. The model looks at the current noisy image and guesses what noise is there. This method then removes that estimated noise to get a cleaner image.
The eta parameter controls randomness:
- eta=0: Always produces the same output for the same input (deterministic)
- eta=1: Adds randomness, making each generation unique (stochastic)
Exceptions
- ArgumentNullException
Thrown when modelOutput or sample is null.
- ArgumentException
Thrown when modelOutput and sample have different lengths.
- ArgumentOutOfRangeException
Thrown when timestep is negative or greater than TrainTimesteps.
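Step is abstract, so each derived scheduler defines its own update. To illustrate how eta and the noise argument interact, here is a sketch of the DDIM update rule for a noise-predicting model. It is not the library's implementation: the hypothetical DdimStep takes the cumulative alpha values for the current and previous timesteps directly instead of an int timestep, and it uses double arrays instead of Vector<T>.

```csharp
using System;

// One DDIM-style update for a noise-predicting model.
// alphaCumprodT: cumulative alpha at the current timestep (must be < 1).
// alphaCumprodPrev: cumulative alpha at the previous, less noisy timestep.
static double[] DdimStep(double[] modelOutput, double[] sample,
                         double alphaCumprodT, double alphaCumprodPrev,
                         double eta, double[]? noise = null)
{
    if (modelOutput.Length != sample.Length)
        throw new ArgumentException("modelOutput and sample must have the same length.");

    // sigma is 0 when eta = 0 (deterministic DDIM); eta = 1 gives DDPM-like variance.
    double sigma = eta
        * Math.Sqrt((1.0 - alphaCumprodPrev) / (1.0 - alphaCumprodT))
        * Math.Sqrt(1.0 - alphaCumprodT / alphaCumprodPrev);
    // Weight of the predicted noise in the "direction pointing to x_t" term.
    double directionScale = Math.Sqrt(Math.Max(0.0, 1.0 - alphaCumprodPrev - sigma * sigma));

    var previous = new double[sample.Length];
    for (int i = 0; i < sample.Length; i++)
    {
        // Estimate the clean sample x_0 by inverting the forward-process formula.
        double predictedOriginal =
            (sample[i] - Math.Sqrt(1.0 - alphaCumprodT) * modelOutput[i]) / Math.Sqrt(alphaCumprodT);

        previous[i] = Math.Sqrt(alphaCumprodPrev) * predictedOriginal + directionScale * modelOutput[i];

        // Random term; skipped when noise is null (the documented zero-noise fallback).
        if (noise != null)
            previous[i] += sigma * noise[i];
    }
    return previous;
}

double[] current = { 0.8, -0.3 };
double[] predictedNoise = { 0.1, 0.2 };
double[] deterministic = DdimStep(predictedNoise, current, 0.5, 0.8, eta: 0.0);
double[] stochastic = DdimStep(predictedNoise, current, 0.5, 0.8, eta: 1.0, noise: new[] { 0.05, -0.02 });
Console.WriteLine(string.Join(", ", deterministic));
Console.WriteLine(string.Join(", ", stochastic));
```

With eta = 0 the output depends only on modelOutput and sample, so repeated calls give identical results. With eta > 0, part of the predicted-noise weight moves into the random sigma term, which is dropped entirely when no noise vector is supplied.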
ValidateStepParameters(Vector<T>, Vector<T>, int)
Validates common step parameters.
protected void ValidateStepParameters(Vector<T> modelOutput, Vector<T> sample, int timestep)
Parameters
modelOutput (Vector<T>): The model output to validate.
sample (Vector<T>): The sample to validate.
timestep (int): The timestep to validate.
Exceptions
- ArgumentNullException
Thrown when modelOutput or sample is null.
- ArgumentException
Thrown when modelOutput and sample have different lengths.
- ArgumentOutOfRangeException
Thrown when timestep is out of range.