Class NBEATSModel<T>
Namespace: AiDotNet.TimeSeries
Assembly: AiDotNet.dll
Implements the N-BEATS (Neural Basis Expansion Analysis for Time Series) model for forecasting.
public class NBEATSModel<T> : TimeSeriesModelBase<T>, ITimeSeriesModel<T>, IFullModel<T, Matrix<T>, Vector<T>>, IModel<Matrix<T>, Vector<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Matrix<T>, Vector<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Matrix<T>, Vector<T>>>, IGradientComputable<T, Matrix<T>, Vector<T>>, IJitCompilable<T>
Type Parameters
T: The numeric type used for calculations (e.g., float, double).
Remarks
N-BEATS is a deep neural architecture based on backward and forward residual links and a very deep stack of fully-connected layers. The architecture has the following key features:
- Doubly residual stacking: Each block produces a backcast (a reconstruction of its input) and a forecast
- Hierarchical decomposition: Multiple stacks focus on different aspects (trend, seasonality)
- Interpretability: Can use polynomial (trend) and Fourier (seasonality) bases for explainable forecasts
- No manual feature engineering: Learns directly from raw time series data
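To make the interpretable bases concrete, here is a minimal sketch in plain C# (independent of AiDotNet) of how a block could turn a coefficient vector theta into a forecast, assuming the paper's formulation with time normalized as t = h/H over a horizon of H steps. BasisSketch, TrendForecast, and SeasonalityForecast are hypothetical names, not part of the AiDotNet API.

```csharp
using System;

// Hypothetical sketch of N-BEATS basis expansion; not AiDotNet's internal code.
public static class BasisSketch
{
    // Trend: forecast[h] = sum_i theta[i] * t^i, with t = h / horizon.
    public static double[] TrendForecast(double[] theta, int horizon)
    {
        var forecast = new double[horizon];
        for (int h = 0; h < horizon; h++)
        {
            double t = (double)h / horizon;
            double power = 1.0; // t^0
            for (int i = 0; i < theta.Length; i++)
            {
                forecast[h] += theta[i] * power;
                power *= t;
            }
        }
        return forecast;
    }

    // Seasonality: theta holds [a_0..a_{K-1}, b_0..b_{K-1}] and
    // forecast[h] = sum_k a_k * cos(2*pi*k*t) + b_k * sin(2*pi*k*t).
    public static double[] SeasonalityForecast(double[] theta, int horizon)
    {
        if (theta.Length % 2 != 0)
            throw new ArgumentException("theta must hold equal numbers of cosine and sine coefficients.");
        int harmonics = theta.Length / 2;
        var forecast = new double[horizon];
        for (int h = 0; h < horizon; h++)
        {
            double t = (double)h / horizon;
            for (int k = 0; k < harmonics; k++)
            {
                double angle = 2.0 * Math.PI * k * t;
                forecast[h] += theta[k] * Math.Cos(angle) + theta[harmonics + k] * Math.Sin(angle);
            }
        }
        return forecast;
    }
}
```

For example, TrendForecast(new[] { 1.0, 2.0 }, 4) evaluates 1 + 2t at t = 0, 0.25, 0.5, 0.75 and returns { 1, 1.5, 2, 2.5 }. Because the forecast is a fixed function of a few coefficients, those coefficients can be read as trend or seasonal strength.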
The original paper: Oreshkin et al., "N-BEATS: Neural basis expansion analysis for interpretable time series forecasting" (ICLR 2020).
For Beginners: N-BEATS is a deep neural network for time series forecasting that automatically learns patterns from your data. Unlike traditional methods that require you to manually specify trends and seasonality, N-BEATS figures these out on its own.
Key advantages:
- No need for manual feature engineering (the model learns what's important)
- Can capture complex, non-linear patterns
- Provides interpretable components (trend, seasonality) when configured to do so
- Works well for both short-term and long-term forecasting
The model works by stacking many "blocks" together, where each block tries to:
- Understand what patterns are in the input (backcast)
- Predict the future based on those patterns (forecast)
- Pass the unexplained patterns to the next block
This allows the model to decompose complex time series into simpler components.
Constructors
NBEATSModel(NBEATSModelOptions<T>?)
Initializes a new instance of the NBEATSModel class.
public NBEATSModel(NBEATSModelOptions<T>? options = null)
Parameters
options (NBEATSModelOptions<T>): Configuration options for the N-BEATS model. If null, default options are used.
Remarks
For Beginners: This creates a new N-BEATS model with the specified configuration. The options control things like:
- How far back to look (lookback window)
- How far forward to predict (forecast horizon)
- How complex the model should be (number of stacks, blocks, and layer sizes)
- Whether to use interpretable components
If you don't provide options, sensible defaults will be used.
Properties
ParameterCount
Gets the total number of trainable parameters in the model.
public override int ParameterCount { get; }
Property Value
- int
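As a rough worked example of where such a count comes from, the sketch below assumes each block is a chain of fully connected layers, where a layer mapping n inputs to m outputs contributes n*m weights plus m biases. The exact layer layout inside AiDotNet's blocks is not documented here, so ParameterCountSketch and CountDense are hypothetical.

```csharp
// Hypothetical arithmetic helper; the real per-block layout in AiDotNet may differ.
public static class ParameterCountSketch
{
    // layerSizes lists the width of each layer in order, e.g. {lookback, hidden, hidden, outputs}.
    // Each consecutive pair (n -> m) adds n*m weights and m biases.
    public static int CountDense(params int[] layerSizes)
    {
        int total = 0;
        for (int i = 0; i + 1 < layerSizes.Length; i++)
            total += layerSizes[i] * layerSizes[i + 1] + layerSizes[i + 1];
        return total;
    }
}
```

With a lookback of 10, two hidden layers of width 32, and 5 outputs, CountDense(10, 32, 32, 5) gives 352 + 1056 + 165 = 1573 parameters; three identical blocks would total 4719.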
SupportsJitCompilation
Gets whether this model supports JIT compilation.
public override bool SupportsJitCompilation { get; }
Property Value
- bool
Returns
true when the model has been trained and its blocks are initialized. The N-BEATS architecture can be represented as a computation graph with the doubly residual stacking pattern, which enables JIT compilation for optimized inference.
Remarks
For Beginners: JIT (Just-In-Time) compilation converts the model's calculations into optimized native code that runs much faster. N-BEATS can be JIT compiled because its forward pass can be expressed as a series of matrix operations with residual connections.
Methods
CreateInstance()
Creates a new instance of the N-BEATS model.
protected override IFullModel<T, Matrix<T>, Vector<T>> CreateInstance()
Returns
- IFullModel<T, Matrix<T>, Vector<T>>
A new N-BEATS model instance with the same configuration.
Remarks
Creates a deep copy of the model options to ensure the cloned model has an independent options instance.
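The point of a deep copy is that changes to the clone's options cannot leak back into the original model. A minimal sketch of the pattern, using a hypothetical SketchOptions class whose property names are illustrative rather than actual NBEATSModelOptions<T> members:

```csharp
using System.Collections.Generic;

// Hypothetical options type; property names are illustrative, not NBEATSModelOptions<T> members.
public sealed class SketchOptions
{
    public int Lookback { get; set; } = 10;
    public int Horizon { get; set; } = 5;
    public List<int> HiddenLayerSizes { get; set; } = new List<int> { 32, 32 };

    public SketchOptions() { }

    // Copy constructor: value-type properties copy directly, but the list is
    // duplicated so the copy shares no mutable state with the source.
    public SketchOptions(SketchOptions other)
    {
        Lookback = other.Lookback;
        Horizon = other.Horizon;
        HiddenLayerSizes = new List<int>(other.HiddenLayerSizes);
    }
}
```

A shallow copy (for example via MemberwiseClone) would share the HiddenLayerSizes list, so adding a layer to the clone would silently change the original as well.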
DeserializeCore(BinaryReader)
Deserializes model-specific data from the binary reader.
protected override void DeserializeCore(BinaryReader reader)
Parameters
reader (BinaryReader): The binary reader to read from.
ExportComputationGraph(List<ComputationNode<T>>)
Exports the N-BEATS model as a computation graph for JIT compilation.
public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
Parameters
inputNodes (List<ComputationNode<T>>): A list to which input nodes will be added.
Returns
- ComputationNode<T>
The output computation node representing the forecast.
Remarks
The computation graph represents the N-BEATS forward pass:
1. For each block, compute the backcast and forecast from the current residual.
2. Update the residual: residual = residual - backcast.
3. Accumulate the forecast: total_forecast = total_forecast + block_forecast.
4. Return the first element of the aggregated forecast.
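The four steps above can be sketched in plain C# as follows. Blocks are modeled as hypothetical delegates returning a (backcast, forecast) pair rather than AiDotNet's actual block type, and the sketch returns the full aggregated forecast; per step 4, a single-step prediction would read its first element.

```csharp
using System.Collections.Generic;

// Stand-in for a trained block: maps the current residual to (backcast, forecast).
public delegate (double[] Backcast, double[] Forecast) Block(double[] input);

// Sketch of the doubly residual forward pass; not AiDotNet's implementation.
public static class DoublyResidualSketch
{
    public static double[] Forward(double[] lookback, IReadOnlyList<Block> blocks, int horizon)
    {
        var residual = (double[])lookback.Clone(); // do not mutate the caller's input
        var total = new double[horizon];
        foreach (var block in blocks)
        {
            // Step 1: the block sees only what earlier blocks left unexplained.
            var (backcast, forecast) = block(residual);
            // Step 2: subtract what this block explained from the residual.
            for (int i = 0; i < residual.Length; i++) residual[i] -= backcast[i];
            // Step 3: add this block's contribution to the running forecast.
            for (int h = 0; h < horizon; h++) total[h] += forecast[h];
        }
        return total;
    }
}
```

With two blocks that each backcast half of their input and forecast its last value, an input of { 2, 4, 8 } yields block forecasts of 8 and then 4, so the aggregated forecast is { 12, 12 }.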
For Beginners: This converts the entire N-BEATS model into a computation graph that the JIT compiler can optimize. The graph chains all blocks together with their residual connections, allowing the JIT compiler to:
- Fuse operations across blocks
- Optimize memory usage
- Generate fast native code
Expected speedup: roughly 3-5x for inference after JIT compilation, although actual gains depend on model size and hardware.
ForecastHorizon(Vector<T>)
Generates forecasts for multiple future time steps.
public Vector<T> ForecastHorizon(Vector<T> input)
Parameters
input (Vector<T>): The input vector containing the lookback window of historical values.
Returns
- Vector<T>
A vector of forecasted values for all forecast horizon steps.
Remarks
For Beginners: This method predicts multiple future time steps at once. Unlike PredictSingle, which returns only the next value, this method returns all values up to the forecast horizon.
For example, if your forecast horizon is 7, this will predict the next 7 time steps.
GetModelMetadata()
Gets metadata about the N-BEATS model.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
A ModelMetadata object containing information about the model.
GetParameters()
Gets all model parameters as a single vector.
public override Vector<T> GetParameters()
Returns
- Vector<T>
A vector containing all trainable parameters from all blocks.
PredictSingle(Vector<T>)
Predicts a single value based on the provided input vector.
public override T PredictSingle(Vector<T> input)
Parameters
input (Vector<T>): The input vector containing the lookback window of historical values.
Returns
- T
The predicted value for the next time step.
Remarks
For Beginners: This method takes a window of historical values and predicts the next value. It runs the input through all the blocks in the model, each block contributing to the final prediction.
SerializeCore(BinaryWriter)
Serializes model-specific data to the binary writer.
protected override void SerializeCore(BinaryWriter writer)
Parameters
writer (BinaryWriter): The binary writer to write to.
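SerializeCore and DeserializeCore must read fields back in exactly the order they were written. The sketch below shows that round trip with the standard System.IO.BinaryWriter and BinaryReader classes, assuming a hypothetical layout of a parameter count followed by the values; the byte layout that NBEATSModel actually writes is not documented here.

```csharp
using System.IO;

// Hypothetical layout: Int32 count, then each parameter as a double.
public static class SerializationSketch
{
    public static byte[] Save(double[] parameters)
    {
        using var stream = new MemoryStream();
        using var writer = new BinaryWriter(stream);
        writer.Write(parameters.Length);          // 4 bytes
        foreach (var p in parameters) writer.Write(p); // 8 bytes each
        writer.Flush();
        return stream.ToArray();
    }

    public static double[] Load(byte[] data)
    {
        using var reader = new BinaryReader(new MemoryStream(data));
        int count = reader.ReadInt32();           // read in the same order as written
        var parameters = new double[count];
        for (int i = 0; i < count; i++) parameters[i] = reader.ReadDouble();
        return parameters;
    }
}
```

Saving three parameters produces 28 bytes: a 4-byte Int32 count plus three 8-byte doubles.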
SetParameters(Vector<T>)
Sets all model parameters from a single vector.
public override void SetParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): A vector containing all trainable parameters.
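GetParameters and SetParameters imply a fixed ordering that maps every block's weights onto one flat vector and back. A minimal sketch, assuming (hypothetically) that each block's trainable values are stored as a double[]:

```csharp
using System;
using System.Collections.Generic;

// Hypothetical block storage; AiDotNet's real representation may differ.
public static class ParameterVectorSketch
{
    // Concatenate every block's parameters in a fixed block order.
    public static double[] Flatten(IReadOnlyList<double[]> blocks)
    {
        int total = 0;
        foreach (var b in blocks) total += b.Length;
        var flat = new double[total];
        int offset = 0;
        foreach (var b in blocks)
        {
            Array.Copy(b, 0, flat, offset, b.Length);
            offset += b.Length;
        }
        return flat;
    }

    // Write values back in the same order; the total length must match exactly.
    public static void Unflatten(double[] flat, IReadOnlyList<double[]> blocks)
    {
        int expected = 0;
        foreach (var b in blocks) expected += b.Length;
        if (flat.Length != expected)
            throw new ArgumentException($"Expected {expected} parameters but got {flat.Length}.");
        int offset = 0;
        foreach (var b in blocks)
        {
            Array.Copy(flat, offset, b, 0, b.Length);
            offset += b.Length;
        }
    }
}
```

Because setting must reverse getting, both directions walk the blocks in the same order, and the length check guards against a vector produced by a differently configured model.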
TrainCore(Matrix<T>, Vector<T>)
Performs the core training logic for the N-BEATS model.
protected override void TrainCore(Matrix<T> x, Vector<T> y)
Parameters
x (Matrix<T>): The input features matrix, where each row is a historical window.
y (Vector<T>): The target values vector, where each element is the corresponding forecast target.
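Each row of x is a historical window and each element of y is its target. One common way to build such pairs from a single series is a sliding window; the sketch below assumes the target is the value immediately after each window and uses plain arrays in place of Matrix<T> and Vector<T>. WindowSketch and MakeWindows are hypothetical helpers, not AiDotNet APIs.

```csharp
using System;

// Hypothetical helper producing (x, y) shaped like TrainCore's arguments.
public static class WindowSketch
{
    // Row r of x holds series[r .. r + lookback - 1]; y[r] is series[r + lookback].
    public static (double[,] X, double[] Y) MakeWindows(double[] series, int lookback)
    {
        int rows = series.Length - lookback;
        if (lookback <= 0 || rows <= 0)
            throw new ArgumentException("Series must be longer than a positive lookback window.");
        var x = new double[rows, lookback];
        var y = new double[rows];
        for (int r = 0; r < rows; r++)
        {
            for (int c = 0; c < lookback; c++) x[r, c] = series[r + c];
            y[r] = series[r + lookback];
        }
        return (x, y);
    }
}
```

For the series { 1, 2, 3, 4, 5 } with a lookback of 3, this produces the rows { 1, 2, 3 } and { 2, 3, 4 } with targets 4 and 5.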
Remarks
Training uses a simple gradient descent approach with mean squared error loss. The model iterates through the training data for the specified number of epochs, updating parameters to minimize prediction error.
For Beginners: This is where the model actually learns from your data.
The training process:
- The model makes predictions on your training data
- It calculates how far off the predictions are (the error)
- It adjusts its internal parameters to reduce this error
- It repeats this process many times (epochs) until it learns the patterns
Note: This is a simplified training implementation. A production version would include more sophisticated optimization, regularization, and validation.
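The loop described above (predict, measure the error, adjust, repeat for several epochs) is ordinary gradient descent on mean squared error. To keep it short and checkable, the sketch substitutes a linear predictor for the N-BEATS blocks; GradientDescentSketch is hypothetical and does not reproduce NBEATSModel's training code.

```csharp
// Stand-in model: a linear predictor, not an N-BEATS network. It only
// illustrates full-batch gradient descent on mean squared error.
public static class GradientDescentSketch
{
    public static double Mse(double[,] x, double[] y, double[] w, double b)
    {
        double sum = 0;
        for (int r = 0; r < y.Length; r++)
        {
            double err = Predict(x, r, w, b) - y[r];
            sum += err * err;
        }
        return sum / y.Length;
    }

    public static (double[] W, double B) Train(double[,] x, double[] y, int epochs, double learningRate)
    {
        int n = y.Length, features = x.GetLength(1);
        var w = new double[features];
        double b = 0;
        for (int epoch = 0; epoch < epochs; epoch++)
        {
            var gradW = new double[features];
            double gradB = 0;
            for (int r = 0; r < n; r++)
            {
                // d(MSE)/d(prediction) = 2 * error / n
                double g = 2.0 * (Predict(x, r, w, b) - y[r]) / n;
                for (int c = 0; c < features; c++) gradW[c] += g * x[r, c];
                gradB += g;
            }
            // Step against the gradient to reduce the error.
            for (int c = 0; c < features; c++) w[c] -= learningRate * gradW[c];
            b -= learningRate * gradB;
        }
        return (w, b);
    }

    private static double Predict(double[,] x, int row, double[] w, double b)
    {
        double p = b;
        for (int c = 0; c < w.Length; c++) p += w[c] * x[row, c];
        return p;
    }
}
```

On data where y = x + 1, for example x = { 0, 1, 2, 3 } and y = { 1, 2, 3, 4 }, 2000 epochs with a learning rate of 0.05 drive the weight and bias close to 1.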