Class TemporalFusionTransformer<T>
- Namespace
- AiDotNet.TimeSeries
- Assembly
- AiDotNet.dll
Implements the Temporal Fusion Transformer (TFT) for interpretable multi-horizon forecasting.
public class TemporalFusionTransformer<T> : TimeSeriesModelBase<T>, ITimeSeriesModel<T>, IFullModel<T, Matrix<T>, Vector<T>>, IModel<Matrix<T>, Vector<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Matrix<T>, Vector<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Matrix<T>, Vector<T>>>, IGradientComputable<T, Matrix<T>, Vector<T>>, IJitCompilable<T>
Type Parameters
T
The numeric type used for calculations (e.g., float, double).
- Inheritance
-
TemporalFusionTransformer<T>
Remarks
Temporal Fusion Transformer is an attention-based architecture that combines multi-horizon forecasting with interpretable insights into what drives each prediction. Key features include:
- Multi-horizon probabilistic forecasts with quantile predictions
- Variable selection networks for interpretability
- Multi-head self-attention mechanisms for learning temporal relationships
- Handling of static metadata, known future inputs, and unknown past inputs
- Gating mechanisms for skip connections and variable selection
Original paper: Lim et al., "Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting" (2021).
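The gating idea in the feature list can be illustrated with a minimal sketch. It assumes the gate pre-activations and the candidate values have already been computed by learned layers, and it omits the layer normalization used in the paper. The class and method names (`GatingSketch`, `GatedSkip`) are hypothetical and are not part of AiDotNet.

```csharp
using System;

// Illustrative sketch only; not the AiDotNet implementation.
public static class GatingSketch
{
    // Logistic sigmoid, squashing any real number into (0, 1).
    public static double Sigmoid(double z) => 1.0 / (1.0 + Math.Exp(-z));

    // Gated skip connection, applied element-wise:
    //   output[i] = skip[i] + sigmoid(gateLogits[i]) * candidate[i]
    // A gate near 0 suppresses the candidate and passes the skip input through;
    // a gate near 1 adds the full candidate on top of the skip input.
    public static double[] GatedSkip(double[] skip, double[] candidate, double[] gateLogits)
    {
        if (skip.Length != candidate.Length || skip.Length != gateLogits.Length)
            throw new ArgumentException("All inputs must have the same length.");

        var output = new double[skip.Length];
        for (int i = 0; i < skip.Length; i++)
            output[i] = skip[i] + Sigmoid(gateLogits[i]) * candidate[i];
        return output;
    }
}
```

With a strongly negative gate logit the block behaves almost like an identity mapping, which is how gating lets a network skip components it does not need.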
Production-Ready Features:
- Uses Tensor<T> for GPU-accelerated operations via IEngine
- Proper multi-head self-attention with Q, K, V projections
- Full backpropagation through all layers (no numerical differentiation)
- All parameters are trained (not subsets)
- Vectorized operations where possible
For Beginners: TFT is an advanced neural network that excels at forecasting multiple time steps ahead while providing insights into what drives the predictions. It can handle:
- Multiple related time series
- Various types of features (static, known future, unknown past)
- Uncertainty quantification through probabilistic forecasts
The attention mechanism allows the model to "focus" on the most relevant historical periods when making predictions, similar to how a human analyst would examine past trends.
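To make the "focus" idea concrete, here is a minimal single-head sketch of scaled dot-product attention over past time steps. It uses plain arrays rather than the library's Tensor<T>, and the names (`AttentionSketch`, `AttentionWeights`, `Attend`) are hypothetical; the real model applies learned Q, K, V projections and multiple heads before this step.

```csharp
using System;
using System.Linq;

// Illustrative sketch only; not the AiDotNet implementation.
public static class AttentionSketch
{
    // Numerically stable softmax: subtract the maximum before exponentiating.
    public static double[] Softmax(double[] scores)
    {
        double max = scores.Max();
        double[] exps = scores.Select(s => Math.Exp(s - max)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }

    // Weights of one query over the keys (one key per past time step):
    // softmax(q . k_t / sqrt(d)), where d is the vector dimension.
    public static double[] AttentionWeights(double[] query, double[][] keys)
    {
        double scale = Math.Sqrt(query.Length);
        double[] scores = keys
            .Select(k => k.Zip(query, (a, b) => a * b).Sum() / scale)
            .ToArray();
        return Softmax(scores);
    }

    // Context vector: the attention-weighted sum of the value vectors.
    public static double[] Attend(double[] query, double[][] keys, double[][] values)
    {
        double[] weights = AttentionWeights(query, keys);
        var context = new double[values[0].Length];
        for (int t = 0; t < values.Length; t++)
            for (int j = 0; j < context.Length; j++)
                context[j] += weights[t] * values[t][j];
        return context;
    }
}
```

The weights are non-negative and sum to 1, so they can be read as the share of attention each past time step receives.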
Constructors
TemporalFusionTransformer(TemporalFusionTransformerOptions<T>?)
Initializes a new instance of the TemporalFusionTransformer class.
public TemporalFusionTransformer(TemporalFusionTransformerOptions<T>? options = null)
Parameters
options TemporalFusionTransformerOptions<T>
Configuration options for the TFT model.
Properties
ParameterCount
Gets the number of parameters in the model.
public override int ParameterCount { get; }
Property Value
- int
Remarks
This property returns the total count of trainable parameters in the model. It's useful for understanding model complexity and memory requirements.
Methods
CreateInstance()
Creates a new instance of the derived model class.
protected override IFullModel<T, Matrix<T>, Vector<T>> CreateInstance()
Returns
- IFullModel<T, Matrix<T>, Vector<T>>
A new instance of the same model type.
Remarks
This abstract factory method must be implemented by derived classes to create a new instance of their specific type. It's used by Clone and DeepCopy to ensure that the correct derived type is instantiated.
For Beginners: This method creates a new, empty instance of the specific model type. It's used during cloning and deep copying to ensure that the copy is of the same specific type as the original.
For example, if the original model is an ARIMA model, this method would create a new ARIMA model. If it's a TBATS model, it would create a new TBATS model.
DeserializeCore(BinaryReader)
Deserializes model-specific data from the binary reader.
protected override void DeserializeCore(BinaryReader reader)
Parameters
reader BinaryReader
The binary reader to read from.
Remarks
This abstract method must be implemented by each specific model type to load its unique parameters and state.
For Beginners: This method is responsible for loading the specific details that make each type of time series model unique. It is the counterpart to SerializeCore: it must read the data in exactly the same order and format in which SerializeCore wrote it, reconstructing the specialized parts of the model.
This separation allows the base class to handle common deserialization tasks while each model type handles its specialized data.
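The ordering requirement can be shown with a small round trip using BinaryWriter and BinaryReader. The `State` class and its fields are hypothetical stand-ins for model-specific data; this is a sketch of the pattern, not the data layout that TemporalFusionTransformer<T> actually writes.

```csharp
using System;
using System.IO;

// Illustrative sketch only; not the AiDotNet implementation.
public static class SerializationOrderSketch
{
    // Hypothetical model state, for illustration only.
    public sealed class State
    {
        public int HiddenSize;
        public double[] Weights = Array.Empty<double>();
    }

    public static byte[] Write(State state)
    {
        using var stream = new MemoryStream();
        using (var writer = new BinaryWriter(stream))
        {
            writer.Write(state.HiddenSize);       // 1. int
            writer.Write(state.Weights.Length);   // 2. array length, so the reader knows how many values follow
            foreach (double w in state.Weights)   // 3. array contents
                writer.Write(w);
        }
        return stream.ToArray();
    }

    public static State Read(byte[] data)
    {
        using var reader = new BinaryReader(new MemoryStream(data));
        var state = new State();
        state.HiddenSize = reader.ReadInt32();    // 1. same type, same position as in Write
        int count = reader.ReadInt32();           // 2. array length
        state.Weights = new double[count];
        for (int i = 0; i < count; i++)           // 3. array contents
            state.Weights[i] = reader.ReadDouble();
        return state;
    }
}
```

The binary format carries no field names or type tags, so swapping two reads, or reading an int where a double was written, silently produces corrupted values instead of an error.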
ForecastWithQuantiles(Vector<T>)
Forecasts multiple quantiles for the full horizon.
public Dictionary<double, Vector<T>> ForecastWithQuantiles(Vector<T> history)
Parameters
history Vector<T>
Returns
- Dictionary<double, Vector<T>>
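Quantile forecasts are usually trained and evaluated with the quantile (pinball) loss, which is the loss used in the original TFT paper. The sketch below shows that loss on plain arrays; whether AiDotNet computes it in exactly this form is an assumption, and the names (`QuantileLossSketch`, `Pinball`, `MeanPinball`) are hypothetical.

```csharp
using System;

// Illustrative sketch only; not the AiDotNet implementation.
public static class QuantileLossSketch
{
    // Pinball loss for one observation at quantile level q in (0, 1).
    // Under-prediction (actual > predicted) is weighted by q,
    // over-prediction (actual < predicted) by (1 - q).
    public static double Pinball(double actual, double predicted, double q)
    {
        if (q <= 0.0 || q >= 1.0)
            throw new ArgumentOutOfRangeException(nameof(q), "Quantile level must be in (0, 1).");

        double error = actual - predicted;
        return Math.Max(q * error, (q - 1.0) * error);
    }

    // Mean pinball loss across all steps of a forecast horizon.
    public static double MeanPinball(double[] actual, double[] predicted, double q)
    {
        if (actual.Length != predicted.Length)
            throw new ArgumentException("Series must have the same length.");
        if (actual.Length == 0)
            throw new ArgumentException("Series must not be empty.");

        double total = 0.0;
        for (int i = 0; i < actual.Length; i++)
            total += Pinball(actual[i], predicted[i], q);
        return total / actual.Length;
    }
}
```

For example, with q = 0.9 an under-prediction of 2 units costs 1.8 while an over-prediction of 2 units costs about 0.2, which pushes the 0.9-quantile forecast toward the upper end of the likely outcomes. At q = 0.5 the loss is symmetric and equals half the absolute error.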
GetModelMetadata()
Gets metadata about the time series model.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
A ModelMetadata<T> object containing information about the model.
Remarks
This method provides comprehensive metadata about the model, including its type, configuration options, training status, evaluation metrics, and information about which features/lags are most important.
For Beginners: This method provides important information about the model that can help you understand its characteristics and performance.
The metadata includes:
- The type of model (e.g., ARIMA, TBATS, Neural Network)
- Configuration details (e.g., lag order, seasonality period)
- Whether the model has been trained
- Performance metrics from the last evaluation
- Information about which features (time periods) are most influential
This information is useful for documentation, model comparison, and debugging. It's like a complete summary of everything important about the model.
PredictQuantiles(Vector<T>)
Predicts quantiles for all forecast horizons.
public Vector<T> PredictQuantiles(Vector<T> input)
Parameters
input Vector<T>
Returns
- Vector<T>
PredictSingle(Vector<T>)
Predicts a single value (median quantile, first horizon step).
public override T PredictSingle(Vector<T> input)
Parameters
input Vector<T>
Returns
- T
SerializeCore(BinaryWriter)
Serializes model-specific data to the binary writer.
protected override void SerializeCore(BinaryWriter writer)
Parameters
writer BinaryWriter
The binary writer to write to.
Remarks
This abstract method must be implemented by each specific model type to save its unique parameters and state.
For Beginners: This method is responsible for saving the specific details that make each type of time series model unique. Different models have different internal structures and parameters that need to be saved separately from the common elements.
For example:
- An ARIMA model would save its AR, I, and MA coefficients
- A TBATS model would save its level, trend, and seasonal components
- A neural network model would save its weights and biases
This separation allows the base class to handle common serialization tasks while each model type handles its specialized data.
TrainCore(Matrix<T>, Vector<T>)
Performs the core training logic using proper backpropagation.
protected override void TrainCore(Matrix<T> x, Vector<T> y)
Parameters
x Matrix<T>
y Vector<T>