
Class RandomForestRegression<T>

Namespace
AiDotNet.Regression
Assembly
AiDotNet.dll

Implements Random Forest Regression, an ensemble learning method that operates by constructing multiple decision trees during training and outputting the average prediction of the individual trees.

public class RandomForestRegression<T> : AsyncDecisionTreeRegressionBase<T>, IAsyncTreeBasedModel<T>, ITreeBasedRegression<T>, INonLinearRegression<T>, IRegression<T>, IFullModel<T, Matrix<T>, Vector<T>>, IModel<Matrix<T>, Vector<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Matrix<T>, Vector<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Matrix<T>, Vector<T>>>, IGradientComputable<T, Matrix<T>, Vector<T>>, IJitCompilable<T>

Type Parameters

T

The numeric data type used for calculations (e.g., float, double).

Inheritance
AsyncDecisionTreeRegressionBase<T>
RandomForestRegression<T>
Implements
IFullModel<T, Matrix<T>, Vector<T>>
IModel<Matrix<T>, Vector<T>, ModelMetadata<T>>
IParameterizable<T, Matrix<T>, Vector<T>>
ICloneable<IFullModel<T, Matrix<T>, Vector<T>>>
IGradientComputable<T, Matrix<T>, Vector<T>>

Remarks

Random Forest Regression combines multiple decision trees to improve prediction accuracy and control overfitting. Each tree is trained on a bootstrap sample of the training data, and at each node, only a random subset of features is considered for splitting. The final prediction is the average of predictions from all trees.

The algorithm's key strengths include robustness to outliers, good performance on high-dimensional data, and the ability to capture non-linear relationships without requiring extensive hyperparameter tuning.

For Beginners: Think of Random Forest as a committee of decision trees, where each tree votes on the prediction. By combining many trees, each trained slightly differently, the model becomes more robust and accurate than any single tree. It's like asking multiple experts for their opinion and taking the average.
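The committee-of-trees idea can be sketched in a few lines of Python. The "trees" below are trivial stand-ins, not the library's actual decision trees; the point is only the averaging of individual predictions:

```python
# Illustrative sketch (not the AiDotNet implementation): a "forest" of
# stand-in trees, each giving a slightly different prediction.
# Averaging their outputs yields the ensemble prediction.

def make_stub_tree(offset):
    """Stand-in for a trained decision tree: predicts x plus a tree-specific offset."""
    return lambda x: x + offset

trees = [make_stub_tree(o) for o in (-0.2, 0.0, 0.1, 0.3)]

def forest_predict(trees, x):
    # The final prediction is the mean of the individual tree predictions.
    return sum(tree(x) for tree in trees) / len(trees)

print(forest_predict(trees, 10.0))  # average of 9.8, 10.0, 10.1, 10.3, i.e. about 10.05
```

Each stub disagrees slightly with the others, yet the average lands close to the true value, which is the intuition behind variance reduction in bagging.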

Constructors

RandomForestRegression(RandomForestRegressionOptions, IRegularization<T, Matrix<T>, Vector<T>>?)

Initializes a new instance of the RandomForestRegression class with the specified options and regularization.

public RandomForestRegression(RandomForestRegressionOptions options, IRegularization<T, Matrix<T>, Vector<T>>? regularization = null)

Parameters

options RandomForestRegressionOptions

Configuration options for the Random Forest regression model.

regularization IRegularization<T, Matrix<T>, Vector<T>>

Regularization method to prevent overfitting. If null, no regularization will be applied.

Remarks

The constructor initializes the model with the provided options and sets up the random number generator.

For Beginners: This constructor sets up the Random Forest model with your specified settings. The options control things like how many trees to build, how deep each tree can be, and how many features to consider at each split. Regularization is an optional technique to prevent the model from becoming too complex and overfitting to the training data.

Properties

MaxDepth

Gets the maximum depth of the trees in the forest.

public override int MaxDepth { get; }

Property Value

int

The maximum depth specified in the options.

NumberOfTrees

Gets the number of trees in the forest.

public override int NumberOfTrees { get; }

Property Value

int

The number of trees specified in the options.

SupportsJitCompilation

Gets whether this Random Forest model supports JIT compilation.

public override bool SupportsJitCompilation { get; }

Property Value

bool

true when soft tree mode is enabled and all trees have been trained; false otherwise.

Remarks

Random Forest supports JIT compilation when soft tree mode is enabled. In soft mode, each tree in the forest uses sigmoid-based soft gating instead of hard if-then splits, making the entire ensemble differentiable.

For Beginners: JIT compilation is available when soft tree mode is enabled.

In soft tree mode:

  • Each tree in the forest uses smooth transitions instead of hard decisions
  • All trees can be exported as a single computation graph
  • The final prediction averages all tree outputs (just like regular Random Forest)

This gives you the benefits of ensemble learning with JIT-compiled speed.
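The difference between a hard if-then split and a sigmoid-gated soft split can be illustrated in Python. This is the assumed semantics of soft tree mode, not AiDotNet's exact formulation:

```python
import math

# Sketch of "soft gating": instead of a hard if-then split, a sigmoid
# blends the left and right child outputs, making the tree
# differentiable everywhere.

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def hard_split(x, threshold, left, right):
    return left if x < threshold else right

def soft_split(x, threshold, left, right, temperature=1.0):
    # Gate approaches 1 when x is far below the threshold and 0 when far
    # above it, so the soft split approaches the hard split as the
    # temperature shrinks.
    gate = sigmoid((threshold - x) / temperature)
    return gate * left + (1.0 - gate) * right

print(hard_split(2.0, 5.0, 1.0, -1.0))  # 1.0
print(soft_split(2.0, 5.0, 1.0, -1.0))  # about 0.905; moves toward 1.0 as temperature shrinks
```

Because `soft_split` is a smooth function of `x`, gradients can flow through it, which is what makes the whole ensemble JIT-compilable as one differentiable graph.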

Methods

CalculateFeatureImportancesAsync(int)

Asynchronously calculates the importance of each feature in the model.

protected override Task CalculateFeatureImportancesAsync(int numFeatures)

Parameters

numFeatures int

The number of features in the input data.

Returns

Task

A task that represents the asynchronous calculation operation.

Remarks

This method calculates feature importances by averaging the importances across all trees in the forest and then normalizing them so they sum to 1.

For Beginners: Feature importance tells you which input variables have the most influence on the predictions. In Random Forests, this is calculated by measuring how much each feature reduces the prediction error when used in the trees. Higher values indicate more important features. The importances are normalized to sum to 1, so you can interpret them as percentages of total importance.
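The aggregation described above can be sketched in Python. The per-tree importance vectors are assumed inputs with illustrative numbers:

```python
# Sketch of the importance aggregation: per-tree importances are averaged
# feature-wise, then normalized so they sum to 1.

per_tree_importances = [
    [0.5, 0.3, 0.2],   # tree 1
    [0.6, 0.2, 0.2],   # tree 2
    [0.4, 0.4, 0.2],   # tree 3
]

num_features = len(per_tree_importances[0])
averaged = [
    sum(tree[f] for tree in per_tree_importances) / len(per_tree_importances)
    for f in range(num_features)
]
total = sum(averaged)
normalized = [v / total for v in averaged]

print(normalized)       # each value is that feature's share of total importance
print(sum(normalized))  # sums to 1, so values read as percentages
```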

CreateNewInstance()

Creates a new instance of the Random Forest regression model with the same options.

protected override IFullModel<T, Matrix<T>, Vector<T>> CreateNewInstance()

Returns

IFullModel<T, Matrix<T>, Vector<T>>

A new instance of the model with the same configuration but no trained parameters.

Remarks

This method creates a new instance of the Random Forest regression model with the same configuration options and regularization method as the current instance, but without copying the trained trees or other learned parameters.

For Beginners: This method creates a fresh copy of the model configuration without any learned parameters.

Think of it like getting a blank forest template with the same settings, but without any of the trained trees. The new model has the same:

  • Number of trees setting
  • Maximum depth setting
  • Minimum samples split setting
  • Maximum features ratio
  • Split criterion (how nodes decide which feature to split on)
  • Regularization method

But it doesn't have any of the actual trained trees that were learned from data.

This is mainly used internally when doing things like cross-validation or creating ensembles of similar models with different training data.

Deserialize(byte[])

Deserializes the model from a byte array.

public override void Deserialize(byte[] data)

Parameters

data byte[]

The byte array containing the serialized model data.

Remarks

This method reconstructs the model's parameters from a serialized byte array, including options, trees, and regularization type.

For Beginners: Deserialization is the opposite of serialization - it takes the saved model data and reconstructs the model's internal state. This allows you to load a previously trained model and use it to make predictions without having to retrain it. It's like loading a saved game to continue where you left off.

Exceptions

InvalidOperationException

Thrown when deserialization fails.

ExportComputationGraph(List<ComputationNode<T>>)

Exports the Random Forest's computation graph for JIT compilation.

public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)

Parameters

inputNodes List<ComputationNode<T>>

List to populate with input computation nodes.

Returns

ComputationNode<T>

The root node of the exported computation graph.

Remarks

When soft tree mode is enabled, this exports the entire Random Forest as a differentiable computation graph. Each tree is exported individually, and their outputs are averaged to produce the final prediction.

The computation graph structure is:

output = (tree1_output + tree2_output + ... + treeN_output) / N
where each tree uses soft split operations.

For Beginners: This exports the entire forest as a computation graph.

Each tree in the forest becomes a soft tree computation graph, and then all tree outputs are averaged together - just like how regular Random Forest predictions work, but compiled into optimized code.


Exceptions

NotSupportedException

Thrown when soft tree mode is not enabled.

InvalidOperationException

Thrown when the forest has not been trained (no trees).

GetModelMetadata()

Gets metadata about the model.

public override ModelMetadata<T> GetModelMetadata()

Returns

ModelMetadata<T>

A ModelMetadata object containing information about the model.

Remarks

This method returns metadata about the model, including its type, number of trees, maximum depth, minimum samples to split, maximum features, feature importances, and regularization type.

For Beginners: Model metadata provides information about the model itself, rather than the predictions it makes. This includes details about how the model is configured (like how many trees it uses and how deep they are) and information about the importance of different features. This can help you understand which input variables are most influential in making predictions.

PredictAsync(Matrix<T>)

Asynchronously makes predictions for the given input data.

public override Task<Vector<T>> PredictAsync(Matrix<T> input)

Parameters

input Matrix<T>

The input features matrix where each row is an example and each column is a feature.

Returns

Task<Vector<T>>

A task that represents the asynchronous prediction operation, containing a vector of predicted values.

Remarks

This method makes predictions by averaging the predictions from all trees in the forest. The steps are:

1. Apply regularization to the input matrix
2. Get predictions from all trees in parallel
3. Average the predictions for each input example
4. Apply regularization to the averaged predictions

For Beginners: After training, this method is used to make predictions on new data. It gets a prediction from each tree in the forest and then averages these predictions to produce the final result. This averaging helps to reduce the variance (randomness) in the predictions, making the model more stable and accurate than any single decision tree.

Serialize()

Serializes the model to a byte array.

public override byte[] Serialize()

Returns

byte[]

A byte array containing the serialized model data.

Remarks

This method serializes the model's parameters, including options, trees, and regularization type, to a JSON format and then converts it to a byte array.

For Beginners: Serialization converts the model's internal state into a format that can be saved to disk or transmitted over a network. This allows you to save a trained model and load it later without having to retrain it. Think of it like saving your progress in a video game.
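The save-then-load round trip can be sketched in Python, using json as a stand-in for the library's byte-array format. The field names here are invented for illustration:

```python
import json

# Conceptual sketch of a serialize/deserialize round trip. The state
# dictionary and its keys are hypothetical, not AiDotNet's actual schema.

model_state = {
    "numberOfTrees": 100,
    "maxDepth": 10,
    "featureImportances": [0.5, 0.3, 0.2],
}

data = json.dumps(model_state).encode("utf-8")  # Serialize(): state -> bytes
restored = json.loads(data.decode("utf-8"))     # Deserialize(): bytes -> state

print(restored == model_state)  # True: the round trip preserves the state
```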

TrainAsync(Matrix<T>, Vector<T>)

Asynchronously trains the Random Forest regression model on the provided data.

public override Task TrainAsync(Matrix<T> x, Vector<T> y)

Parameters

x Matrix<T>

The input features matrix where each row is a training example and each column is a feature.

y Vector<T>

The target values vector corresponding to each training example.

Returns

Task

A task that represents the asynchronous training operation.

Remarks

This method builds multiple decision trees in parallel, each trained on a bootstrap sample of the training data and considering a random subset of features at each split. The steps are:

1. Clear any existing trees
2. Calculate the number of features to consider at each split
3. For each tree:
   a. Generate a bootstrap sample of the training data
   b. Create a new decision tree with the specified options
   c. Train the tree on the bootstrap sample
4. Calculate feature importances by averaging across all trees

For Beginners: Training is the process where the model learns from your data. The algorithm builds multiple decision trees, each on a slightly different version of your data (created by random sampling with replacement). Each tree also considers only a random subset of features at each split, which helps to make the trees more diverse. By building many diverse trees and combining their predictions, the model can capture complex relationships and provide more robust predictions than a single tree.
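The bootstrap-sampling part of training (step 3 above) can be sketched in Python, with a trivial stand-in for tree training:

```python
import random

# Sketch of bootstrap sampling: each tree is trained on rows drawn with
# replacement from the original data. The "training" here is a stand-in
# that just records the sample's mean target.

random.seed(42)

def bootstrap_sample(x, y):
    n = len(x)
    idx = [random.randrange(n) for _ in range(n)]  # sample row indices with replacement
    return [x[i] for i in idx], [y[i] for i in idx]

def train_stub_tree(x_sample, y_sample):
    """Stand-in for decision-tree training: always predict the sample's mean target."""
    mean_y = sum(y_sample) / len(y_sample)
    return lambda row: mean_y

x = [[1.0], [2.0], [3.0], [4.0]]
y = [10.0, 20.0, 30.0, 40.0]

forest = []
for _ in range(5):
    xs, ys = bootstrap_sample(x, y)
    forest.append(train_stub_tree(xs, ys))

prediction = sum(tree(x[0]) for tree in forest) / len(forest)
print(prediction)  # near the overall mean of y (25.0); varies with the bootstrap draws
```

Because each tree sees a different resampled dataset, the trees disagree in useful ways, and averaging their outputs smooths out the noise of any single sample.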