Class M5ModelTree<T>

Namespace
AiDotNet.Regression
Assembly
AiDotNet.dll

Represents an M5 model tree for regression problems, combining a decision tree structure with linear models at the leaves.

public class M5ModelTree<T> : AsyncDecisionTreeRegressionBase<T>, IAsyncTreeBasedModel<T>, ITreeBasedRegression<T>, INonLinearRegression<T>, IRegression<T>, IFullModel<T, Matrix<T>, Vector<T>>, IModel<Matrix<T>, Vector<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Matrix<T>, Vector<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Matrix<T>, Vector<T>>>, IGradientComputable<T, Matrix<T>, Vector<T>>, IJitCompilable<T>

Type Parameters

T

The numeric type used for calculations, typically float or double.

Inheritance
M5ModelTree<T>
Implements
IFullModel<T, Matrix<T>, Vector<T>>
IModel<Matrix<T>, Vector<T>, ModelMetadata<T>>
IParameterizable<T, Matrix<T>, Vector<T>>
ICloneable<IFullModel<T, Matrix<T>, Vector<T>>>
IGradientComputable<T, Matrix<T>, Vector<T>>

Remarks

The M5 model tree is a regression technique that combines decision trees with linear regression. Instead of storing a single value at each leaf node (as standard regression trees do), an M5 model tree fits a linear regression model at each leaf. The tree structure captures global patterns and the leaf models capture local ones, which often yields more accurate predictions than a standard regression tree.

For Beginners: An M5 model tree is like a smart decision-making system for predicting numbers.

Think of it like a flowchart for home price prediction:

  • The tree asks questions about the home (Is it bigger than 2000 sq ft? Is it in neighborhood A?)
  • Based on the answers, you follow different paths down the tree
  • When you reach the end (a leaf), instead of getting a single price value, you get a mini-calculator (linear model)
  • This mini-calculator uses the home's features to make a more precise prediction for that specific group of homes

For example, for small homes in urban areas, the price might depend more on location, while for large homes in suburbs, the number of bathrooms might be more important. The M5 model tree captures these different patterns for different groups of data.
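
The flowchart-plus-mini-calculator idea can be sketched in a few lines. This is a toy illustration only: ToyNode, ToyModelTree, and the coefficients are hypothetical and are not AiDotNet types or values. Features are assumed to be [square feet, bathrooms].

```csharp
using System;

// Hypothetical toy types for illustration; these are not AiDotNet classes.
public sealed class ToyNode
{
    // Internal nodes route on one feature and a threshold.
    public int FeatureIndex;
    public double Threshold;
    public ToyNode Left;   // taken when the feature value is <= Threshold
    public ToyNode Right;  // taken when the feature value is >  Threshold

    // Leaf nodes hold a linear model: Intercept + sum(Weights[i] * x[i]).
    public double Intercept;
    public double[] Weights;

    public bool IsLeaf { get { return Left == null && Right == null; } }
}

public static class ToyModelTree
{
    public static double Predict(ToyNode node, double[] x)
    {
        // Follow the questions down to a leaf.
        while (!node.IsLeaf)
        {
            node = x[node.FeatureIndex] <= node.Threshold ? node.Left : node.Right;
        }

        // Apply the leaf's linear model to the same feature vector.
        double y = node.Intercept;
        for (int i = 0; i < node.Weights.Length; i++)
        {
            y += node.Weights[i] * x[i];
        }
        return y;
    }

    // One split on square footage (feature 0) at 2000, with a different
    // linear model for each group of homes. All numbers are made up.
    public static ToyNode BuildExample()
    {
        var small = new ToyNode { Intercept = 50000, Weights = new[] { 100.0, 5000.0 } };
        var large = new ToyNode { Intercept = 80000, Weights = new[] { 90.0, 20000.0 } };
        return new ToyNode { FeatureIndex = 0, Threshold = 2000, Left = small, Right = large };
    }
}
```

Calling ToyModelTree.Predict(ToyModelTree.BuildExample(), new[] { 1500.0, 2.0 }) routes the home to the "small" leaf, where bathrooms carry less weight than in the "large" leaf.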

Constructors

M5ModelTree(M5ModelTreeOptions?, IRegularization<T, Matrix<T>, Vector<T>>?)

Initializes a new instance of the M5ModelTree<T> class with optional custom options and regularization.

public M5ModelTree(M5ModelTreeOptions? options = null, IRegularization<T, Matrix<T>, Vector<T>>? regularization = null)

Parameters

options M5ModelTreeOptions

Custom options for the M5 model tree algorithm. If null, default options are used.

regularization IRegularization<T, Matrix<T>, Vector<T>>

Regularization method to prevent overfitting. If null, no regularization is applied.

Remarks

This constructor creates a new M5 model tree with the specified options and regularization. If no options are provided, default values are used. Regularization helps prevent overfitting by penalizing complex models.

For Beginners: This creates a new M5 model tree with your chosen settings.

When creating an M5 model tree:

  • You can provide custom settings (options) or use the defaults
  • You can add regularization, which helps prevent the model from memorizing the training data too closely

Regularization is like adding guardrails that prevent the model from becoming too complex or fitting too closely to the training data, which helps it perform better on new data.
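
To make the guardrail idea concrete, here is a minimal sketch of one common form of regularization, ridge (L2), applied to a one-feature linear model. This page does not say how M5ModelTree applies its IRegularization argument, so treat the following as a generic illustration; RidgeSketch and its lambda parameter are hypothetical names.

```csharp
using System;
using System.Linq;

// Hypothetical helper for illustration; not part of AiDotNet.
public static class RidgeSketch
{
    // Fits y ≈ intercept + slope * x while penalizing slope^2 by lambda.
    // The intercept is left unpenalized, as is conventional for ridge.
    public static (double Intercept, double Slope) Fit(double[] x, double[] y, double lambda)
    {
        if (x.Length != y.Length || x.Length == 0)
        {
            throw new ArgumentException("x and y must be non-empty and the same length.");
        }

        double meanX = x.Average();
        double meanY = y.Average();

        double sxy = 0, sxx = 0;
        for (int i = 0; i < x.Length; i++)
        {
            sxy += (x[i] - meanX) * (y[i] - meanY);
            sxx += (x[i] - meanX) * (x[i] - meanX);
        }

        // A larger lambda shrinks the slope toward zero (a simpler model).
        double denominator = sxx + lambda;
        double slope = denominator == 0 ? 0 : sxy / denominator;
        return (meanY - slope * meanX, slope);
    }
}
```

With x = {1, 2, 3} and y = {2, 4, 6}, lambda = 0 recovers the exact slope of 2, while lambda = 2 shrinks the slope to 1 and moves the intercept to 2 to compensate.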

Methods

CalculateFeatureImportancesAsync(int)

Asynchronously calculates the importance of each feature in the model.

protected override Task CalculateFeatureImportancesAsync(int featureCount)

Parameters

featureCount int

The total number of features.

Returns

Task

A task representing the asynchronous operation.

Remarks

This method calculates the importance of each feature in the decision tree by assigning weights to nodes based on their position in the tree. Features used closer to the root receive higher importance scores. The resulting feature importance values are normalized to sum to 1.

For Beginners: This method figures out which features are most important for predictions.

The process:

  • It examines the whole tree structure
  • Features used near the top of the tree (root) are considered more important
  • Features used multiple times throughout the tree gain importance
  • The values are adjusted so they add up to 1 (or 100%)

This helps you understand which factors have the biggest impact on predictions. For example, you might learn that for house prices, location affects the prediction more than the number of bedrooms.
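
The depth-weighted scheme described above can be sketched as follows. The exact weights the library assigns are not given on this page; the 1 / (depth + 1) weight below is an assumption chosen only so that splits nearer the root count more. ImportanceSketch is a hypothetical name.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper for illustration; not part of AiDotNet.
public static class ImportanceSketch
{
    // Each split is described by the feature it tests and its depth (root = 0).
    public static double[] Compute(IEnumerable<(int Feature, int Depth)> splits, int featureCount)
    {
        var importances = new double[featureCount];
        double total = 0;

        foreach (var split in splits)
        {
            // Assumed weighting: shallower splits receive larger weights.
            double weight = 1.0 / (split.Depth + 1);
            importances[split.Feature] += weight; // repeated use accumulates
            total += weight;
        }

        // Normalize so the importances sum to 1 (skipped for a tree with no splits).
        if (total > 0)
        {
            for (int i = 0; i < importances.Length; i++)
            {
                importances[i] /= total;
            }
        }
        return importances;
    }
}
```

For a tree that splits on feature 0 at the root and then on features 1 and 0 at depth 1, this yields 0.75 for feature 0, 0.25 for feature 1, and 0 for any unused feature.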

CreateNewInstance()

Creates a new instance of the M5ModelTree with the same configuration as the current instance.

protected override IFullModel<T, Matrix<T>, Vector<T>> CreateNewInstance()

Returns

IFullModel<T, Matrix<T>, Vector<T>>

A new M5ModelTree instance with the same options and regularization as the current instance.

Remarks

This method implements the abstract method from the base class, allowing the creation of a new model with the same configuration options and regularization settings. This is useful for model cloning, ensemble methods, or cross-validation scenarios where multiple instances of the same model type with identical configurations are needed.

For Beginners: This method creates a copy of the model's blueprint.

When you need multiple versions of the same type of model with identical settings:

  • This method creates a new, empty model with the same configuration
  • It's like making a copy of a recipe before you start cooking
  • The new model has the same settings but has not been trained
  • This is useful for techniques that need multiple models, like cross-validation

For example, if you're testing your model on different subsets of data, you'd want each test to use a model with identical settings.
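
The "copy the recipe, not the meal" pattern looks roughly like this. ToyOptions and ToyModel are hypothetical stand-ins, not AiDotNet types, and in the real class CreateNewInstance() is protected, so it is called by the library rather than by user code.

```csharp
using System;

// Hypothetical stand-ins for illustration; not AiDotNet types.
public sealed class ToyOptions
{
    public int MaxDepth = 5;
}

public sealed class ToyModel
{
    private readonly ToyOptions _options;

    public ToyModel(ToyOptions options)
    {
        _options = options;
    }

    public bool IsTrained { get; private set; }

    public int MaxDepth { get { return _options.MaxDepth; } }

    public void Train()
    {
        // Placeholder for real training; only records that it happened.
        IsTrained = true;
    }

    // Same configuration, but none of the learned state is carried over.
    public ToyModel CreateNewInstance()
    {
        return new ToyModel(_options);
    }
}
```

A cross-validation loop would call CreateNewInstance() once per fold so that every fold starts from an untrained model with identical settings.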

Deserialize(byte[])

Deserializes the M5 model tree from a byte array, including linear models at leaf nodes.

public override void Deserialize(byte[] modelData)

Parameters

modelData byte[]

The byte array containing the serialized model.

GetModelMetadata()

Gets metadata about the trained model.

public override ModelMetadata<T> GetModelMetadata()

Returns

ModelMetadata<T>

A ModelMetadata object containing information about the model.

Remarks

This method creates a metadata object containing information about the trained model, including the model type, hyperparameters, and feature importances. This metadata can be useful for model inspection, comparison, and serialization.

For Beginners: This method provides a summary of the model and its settings.

The metadata includes:

  • The type of model (M5ModelTree)
  • All the settings used to create the model
  • Information about the importance of each feature
  • The type of regularization used (if any)

This is useful for:

  • Documenting how the model was built
  • Comparing different models
  • Sharing model information with others
  • Saving important details along with the model

PredictAsync(Matrix<T>)

Asynchronously generates predictions for new data points using the trained M5 model tree.

public override Task<Vector<T>> PredictAsync(Matrix<T> input)

Parameters

input Matrix<T>

The feature matrix where each row is a sample to predict.

Returns

Task<Vector<T>>

A task that represents the asynchronous operation, containing a vector of predicted values.

Remarks

This method traverses the tree for each input sample, following the decision path until reaching a leaf node. At the leaf, it either uses the stored constant value or applies the linear regression model to generate the prediction. The predictions for multiple samples are processed in parallel for improved performance.

For Beginners: This is where the model makes predictions on new data.

For each data point:

  1. The model follows the decision tree path based on the feature values
  2. When it reaches a leaf node, it either:
    • Returns the average value for that leaf (simple approach)
    • Uses a mini-calculator (linear model) for a more precise prediction
  3. Multiple data points are processed in parallel for speed

For example, when predicting house prices, each house's features guide it through different paths in the tree until reaching the appropriate pricing model for that type of house.
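
The "one row per prediction, many rows in parallel" shape can be sketched with standard .NET primitives. This is not the library's implementation: ParallelPredictSketch is a hypothetical helper, plain arrays stand in for Matrix<T> and Vector<T>, and the per-row predictor is passed in as a delegate.

```csharp
using System;
using System.Threading.Tasks;

// Hypothetical helper for illustration; not part of AiDotNet.
public static class ParallelPredictSketch
{
    public static Task<double[]> PredictAsync(double[][] rows, Func<double[], double> predictOne)
    {
        // Task.Run moves the batch off the calling thread.
        return Task.Run(() =>
        {
            var results = new double[rows.Length];

            // Each iteration writes only its own slot, so no locking is needed.
            Parallel.For(0, rows.Length, i =>
            {
                results[i] = predictOne(rows[i]);
            });

            return results;
        });
    }
}
```

A caller would typically write var predictions = await ParallelPredictSketch.PredictAsync(rows, predictOne); where predictOne walks the tree for a single sample.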

Serialize()

Serializes the M5 model tree to a byte array, including linear models at leaf nodes.

public override byte[] Serialize()

Returns

byte[]

A byte array containing the serialized model, including the linear models at leaf nodes.
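
As a rough picture of what a Serialize()/Deserialize(byte[]) round trip has to preserve, the sketch below writes a small tree, including the leaf coefficients, to a byte array and reads it back. The byte layout is invented for illustration; the format AiDotNet actually uses is not documented on this page, and SNode and TreeBytes are hypothetical names.

```csharp
using System;
using System.IO;

// Hypothetical types for illustration; the layout is not AiDotNet's format.
public sealed class SNode
{
    public int FeatureIndex;
    public double Threshold;
    public SNode Left;
    public SNode Right;
    public double[] Coefficients; // set only on leaves (intercept first)
}

public static class TreeBytes
{
    public static byte[] Serialize(SNode root)
    {
        using (var stream = new MemoryStream())
        using (var writer = new BinaryWriter(stream))
        {
            Write(writer, root);
            writer.Flush();
            return stream.ToArray();
        }
    }

    public static SNode Deserialize(byte[] data)
    {
        using (var reader = new BinaryReader(new MemoryStream(data)))
        {
            return Read(reader);
        }
    }

    // Pre-order traversal: a leaf flag, then either coefficients or a split.
    private static void Write(BinaryWriter writer, SNode node)
    {
        bool isLeaf = node.Coefficients != null;
        writer.Write(isLeaf);
        if (isLeaf)
        {
            writer.Write(node.Coefficients.Length);
            foreach (double c in node.Coefficients) writer.Write(c);
        }
        else
        {
            writer.Write(node.FeatureIndex);
            writer.Write(node.Threshold);
            Write(writer, node.Left);
            Write(writer, node.Right);
        }
    }

    private static SNode Read(BinaryReader reader)
    {
        var node = new SNode();
        if (reader.ReadBoolean())
        {
            node.Coefficients = new double[reader.ReadInt32()];
            for (int i = 0; i < node.Coefficients.Length; i++)
            {
                node.Coefficients[i] = reader.ReadDouble();
            }
        }
        else
        {
            node.FeatureIndex = reader.ReadInt32();
            node.Threshold = reader.ReadDouble();
            node.Left = Read(reader);
            node.Right = Read(reader);
        }
        return node;
    }
}
```

The point of the sketch is that a constant-leaf tree format is not enough for a model tree: the per-leaf coefficients must travel with the structure, which is why this class overrides both methods.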

TrainAsync(Matrix<T>, Vector<T>)

Asynchronously trains the M5 model tree using the provided features and target values.

public override Task TrainAsync(Matrix<T> x, Vector<T> y)

Parameters

x Matrix<T>

The feature matrix where each row is a sample and each column is a feature.

y Vector<T>

The target vector containing the continuous values to predict.

Returns

Task

A task representing the asynchronous training operation.

Remarks

This method builds the M5 model tree structure recursively, finding the best feature splits at each node to minimize prediction error. If enabled, it applies pruning to reduce tree complexity and prevent overfitting. Finally, it calculates feature importances to provide insights into which features are most influential in making predictions.

For Beginners: This is where the model learns from your data.

During training:

  1. The tree is built from top to bottom by finding the best questions to ask about your data
  2. If pruning is enabled, the tree is simplified by removing unnecessary branches
  3. The model calculates which features are most important for predictions

The "Async" in the name means the method returns a Task that you can await, so the calling code does not have to block while training runs. This is especially helpful when training on large datasets.
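
To illustrate what "finding the best question" can mean, the sketch below scores candidate thresholds for one feature by standard deviation reduction (SDR), the split criterion of the original M5 algorithm. Whether this library uses SDR or another error measure is not stated on this page, so read it as an assumption. SplitSketch is a hypothetical helper operating on plain arrays.

```csharp
using System;
using System.Linq;

// Hypothetical helper for illustration; not part of AiDotNet.
public static class SplitSketch
{
    // Population standard deviation; 0 for an empty set.
    private static double StdDev(double[] values)
    {
        if (values.Length == 0) return 0;
        double mean = values.Average();
        return Math.Sqrt(values.Sum(v => (v - mean) * (v - mean)) / values.Length);
    }

    // Returns the threshold on one feature that maximizes
    // SDR = sd(all) - sum over sides of (side size / total size) * sd(side).
    // Threshold is NaN when no split reduces the spread.
    public static (double Threshold, double Reduction) BestSplit(double[] feature, double[] target)
    {
        double parentSd = StdDev(target);
        double bestThreshold = double.NaN;
        double bestReduction = 0;

        foreach (double t in feature.Distinct().OrderBy(v => v))
        {
            double[] left = target.Where((v, i) => feature[i] <= t).ToArray();
            double[] right = target.Where((v, i) => feature[i] > t).ToArray();
            if (left.Length == 0 || right.Length == 0) continue;

            double weightedSd =
                (left.Length * StdDev(left) + right.Length * StdDev(right)) / target.Length;
            double reduction = parentSd - weightedSd;

            if (reduction > bestReduction)
            {
                bestReduction = reduction;
                bestThreshold = t;
            }
        }
        return (bestThreshold, bestReduction);
    }
}
```

A full trainer would repeat this search over every feature at each node, recurse into the two sides, and then fit a linear model in each leaf; pruning would later collapse subtrees that do not earn their complexity.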