Class RandomForestRegression<T>
- Namespace: AiDotNet.Regression
- Assembly: AiDotNet.dll
Implements Random Forest Regression, an ensemble learning method that operates by constructing multiple decision trees during training and outputting the average prediction of the individual trees.
public class RandomForestRegression<T> : AsyncDecisionTreeRegressionBase<T>, IAsyncTreeBasedModel<T>, ITreeBasedRegression<T>, INonLinearRegression<T>, IRegression<T>, IFullModel<T, Matrix<T>, Vector<T>>, IModel<Matrix<T>, Vector<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Matrix<T>, Vector<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Matrix<T>, Vector<T>>>, IGradientComputable<T, Matrix<T>, Vector<T>>, IJitCompilable<T>
Type Parameters
T: The numeric data type used for calculations (e.g., float, double).
- Inheritance
- AsyncDecisionTreeRegressionBase<T> → RandomForestRegression<T>
- Implements
- IRegression<T>
Remarks
Random Forest Regression combines multiple decision trees to improve prediction accuracy and control overfitting. Each tree is trained on a bootstrap sample of the training data, and at each node, only a random subset of features is considered for splitting. The final prediction is the average of predictions from all trees.
The algorithm's key strengths include robustness to outliers, good performance on high-dimensional data, and the ability to capture non-linear relationships without requiring extensive hyperparameter tuning.
For Beginners: Think of Random Forest as a committee of decision trees, where each tree votes on the prediction. By combining many trees, each trained slightly differently, the model becomes more robust and accurate than any single tree. It's like asking multiple experts for their opinion and taking the average.
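The committee idea can be shown in a few lines. This is a conceptual sketch only, not the AiDotNet API: each tree produces its own estimate, and the forest reports the average.

```csharp
using System;
using System.Linq;

// One estimate per tree for the same input example.
double[] treePredictions = { 10.2, 9.8, 10.5, 9.9, 10.1 };

// The forest's output is the committee average; individual trees
// over- and under-shoot, but the average is more stable.
double forestPrediction = treePredictions.Average();
Console.WriteLine(forestPrediction);
```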
Constructors
RandomForestRegression(RandomForestRegressionOptions, IRegularization<T, Matrix<T>, Vector<T>>?)
Initializes a new instance of the RandomForestRegression class with the specified options and regularization.
public RandomForestRegression(RandomForestRegressionOptions options, IRegularization<T, Matrix<T>, Vector<T>>? regularization = null)
Parameters
options (RandomForestRegressionOptions): Configuration options for the Random Forest regression model.
regularization (IRegularization<T, Matrix<T>, Vector<T>>?): Regularization method to prevent overfitting. If null, no regularization will be applied.
Remarks
The constructor initializes the model with the provided options and sets up the random number generator.
For Beginners: This constructor sets up the Random Forest model with your specified settings. The options control things like how many trees to build, how deep each tree can be, and how many features to consider at each split. Regularization is an optional technique to prevent the model from becoming too complex and overfitting to the training data.
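A typical construction might look like the sketch below. The option property names (NumberOfTrees, MaxDepth) are assumptions for illustration; check RandomForestRegressionOptions for the exact members available in your version of AiDotNet.

```csharp
using AiDotNet.Regression;

// Hypothetical property names shown for illustration.
var options = new RandomForestRegressionOptions
{
    NumberOfTrees = 100, // more trees: more stable, slower to train
    MaxDepth = 10        // cap tree depth to limit overfitting
};

// regularization defaults to null, so none is applied here.
var model = new RandomForestRegression<double>(options);
```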
Properties
MaxDepth
Gets the maximum depth of the trees in the forest.
public override int MaxDepth { get; }
Property Value
- int
The maximum depth specified in the options.
NumberOfTrees
Gets the number of trees in the forest.
public override int NumberOfTrees { get; }
Property Value
- int
The number of trees specified in the options.
SupportsJitCompilation
Gets whether this Random Forest model supports JIT compilation.
public override bool SupportsJitCompilation { get; }
Property Value
- bool
true when soft tree mode is enabled and all trees have been trained; false otherwise.
Remarks
Random Forest supports JIT compilation when soft tree mode is enabled. In soft mode, each tree in the forest uses sigmoid-based soft gating instead of hard if-then splits, making the entire ensemble differentiable.
For Beginners: JIT compilation is available when soft tree mode is enabled.
In soft tree mode:
- Each tree in the forest uses smooth transitions instead of hard decisions
- All trees can be exported as a single computation graph
- The final prediction averages all tree outputs (just like regular Random Forest)
This gives you the benefits of ensemble learning with JIT-compiled speed.
Methods
CalculateFeatureImportancesAsync(int)
Asynchronously calculates the importance of each feature in the model.
protected override Task CalculateFeatureImportancesAsync(int numFeatures)
Parameters
numFeatures (int): The number of features in the input data.
Returns
- Task
A task that represents the asynchronous calculation operation.
Remarks
This method calculates feature importances by averaging the importances across all trees in the forest and then normalizing them so they sum to 1.
For Beginners: Feature importance tells you which input variables have the most influence on the predictions. In Random Forests, this is calculated by measuring how much each feature reduces the prediction error when used in the trees. Higher values indicate more important features. The importances are normalized to sum to 1, so you can interpret them as percentages of total importance.
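The averaging and normalization described above can be sketched with made-up numbers (this mirrors the described behavior; it is not the library's internal code):

```csharp
using System;
using System.Linq;

// Raw importances reported by each tree, one array per tree,
// one entry per feature (values here are invented).
double[][] perTreeImportances =
{
    new[] { 3.0, 1.0 }, // tree 1
    new[] { 5.0, 1.0 }  // tree 2
};
int numFeatures = 2;

// Step 1: average each feature's importance across all trees.
var averaged = new double[numFeatures];
for (int f = 0; f < numFeatures; f++)
    averaged[f] = perTreeImportances.Average(tree => tree[f]);

// Step 2: normalize so the importances sum to 1 and can be
// read as shares of total importance.
double total = averaged.Sum();
var normalized = averaged.Select(v => v / total).ToArray();
```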
CreateNewInstance()
Creates a new instance of the Random Forest regression model with the same options.
protected override IFullModel<T, Matrix<T>, Vector<T>> CreateNewInstance()
Returns
- IFullModel<T, Matrix<T>, Vector<T>>
A new instance of the model with the same configuration but no trained parameters.
Remarks
This method creates a new instance of the Random Forest regression model with the same configuration options and regularization method as the current instance, but without copying the trained trees or other learned parameters.
For Beginners: This method creates a fresh copy of the model configuration without any learned parameters.
Think of it like getting a blank forest template with the same settings, but without any of the trained trees. The new model has the same:
- Number of trees setting
- Maximum depth setting
- Minimum samples split setting
- Maximum features ratio
- Split criterion (how nodes decide which feature to split on)
- Regularization method
But it doesn't have any of the actual trained trees that were learned from data.
This is mainly used internally when doing things like cross-validation or creating ensembles of similar models with different training data.
Deserialize(byte[])
Deserializes the model from a byte array.
public override void Deserialize(byte[] data)
Parameters
data (byte[]): The byte array containing the serialized model data.
Remarks
This method reconstructs the model's parameters from a serialized byte array, including options, trees, and regularization type.
For Beginners: Deserialization is the opposite of serialization - it takes the saved model data and reconstructs the model's internal state. This allows you to load a previously trained model and use it to make predictions without having to retrain it. It's like loading a saved game to continue where you left off.
Exceptions
- InvalidOperationException
Thrown when deserialization fails.
ExportComputationGraph(List<ComputationNode<T>>)
Exports the Random Forest's computation graph for JIT compilation.
public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
Parameters
inputNodes (List<ComputationNode<T>>): List to populate with input computation nodes.
Returns
- ComputationNode<T>
The root node of the exported computation graph.
Remarks
When soft tree mode is enabled, this exports the entire Random Forest as a differentiable computation graph. Each tree is exported individually, and their outputs are averaged to produce the final prediction.
The computation graph structure is:
output = (tree1_output + tree2_output + ... + treeN_output) / N
where each tree uses soft split operations.
For Beginners: This exports the entire forest as a computation graph.
Each tree in the forest becomes a soft tree computation graph, and then all tree outputs are averaged together - just like how regular Random Forest predictions work, but compiled into optimized code.
Exceptions
- NotSupportedException
Thrown when soft tree mode is not enabled.
- InvalidOperationException
Thrown when the forest has not been trained (no trees).
GetModelMetadata()
Gets metadata about the model.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
A ModelMetadata object containing information about the model.
Remarks
This method returns metadata about the model, including its type, number of trees, maximum depth, minimum samples to split, maximum features, feature importances, and regularization type.
For Beginners: Model metadata provides information about the model itself, rather than the predictions it makes. This includes details about how the model is configured (like how many trees it uses and how deep they are) and information about the importance of different features. This can help you understand which input variables are most influential in making predictions.
PredictAsync(Matrix<T>)
Asynchronously makes predictions for the given input data.
public override Task<Vector<T>> PredictAsync(Matrix<T> input)
Parameters
input (Matrix<T>): The input features matrix where each row is an example and each column is a feature.
Returns
- Task<Vector<T>>
A task that represents the asynchronous prediction operation, containing a vector of predicted values.
Remarks
This method makes predictions by averaging the predictions from all trees in the forest. The steps are:
1. Apply regularization to the input matrix
2. Get predictions from all trees in parallel
3. Average the predictions for each input example
4. Apply regularization to the averaged predictions
For Beginners: After training, this method is used to make predictions on new data. It gets a prediction from each tree in the forest and then averages these predictions to produce the final result. This averaging helps to reduce the variance (randomness) in the predictions, making the model more stable and accurate than any single decision tree.
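A typical call sequence is sketched below. Here trainFeatures, trainTargets, and testFeatures are assumed to be Matrix<double>/Vector<double> values built elsewhere (their constructors are documented separately in AiDotNet), and model is a RandomForestRegression<double> instance.

```csharp
// Train on known data, then predict on unseen data.
await model.TrainAsync(trainFeatures, trainTargets);
Vector<double> predictions = await model.PredictAsync(testFeatures);
// predictions[i] is the forest's averaged estimate for row i of testFeatures.
```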
Serialize()
Serializes the model to a byte array.
public override byte[] Serialize()
Returns
- byte[]
A byte array containing the serialized model data.
Remarks
This method serializes the model's parameters, including options, trees, and regularization type, to a JSON format and then converts it to a byte array.
For Beginners: Serialization converts the model's internal state into a format that can be saved to disk or transmitted over a network. This allows you to save a trained model and load it later without having to retrain it. Think of it like saving your progress in a video game.
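Serialize and Deserialize are typically used together as a save/load round trip. In this sketch, model is an already-trained instance, options is the configuration used to build it, and the file name is illustrative.

```csharp
using System.IO;
using AiDotNet.Regression;

// Save: capture the trained model's state as bytes and write to disk.
byte[] saved = model.Serialize();
File.WriteAllBytes("forest.bin", saved);

// Later (or in another process): create an instance with the same
// options, then restore its trained state from the saved bytes.
var restored = new RandomForestRegression<double>(options);
restored.Deserialize(File.ReadAllBytes("forest.bin"));
// restored can now make predictions without retraining.
```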
TrainAsync(Matrix<T>, Vector<T>)
Asynchronously trains the Random Forest regression model on the provided data.
public override Task TrainAsync(Matrix<T> x, Vector<T> y)
Parameters
x (Matrix<T>): The input features matrix where each row is a training example and each column is a feature.
y (Vector<T>): The target values vector corresponding to each training example.
Returns
- Task
A task that represents the asynchronous training operation.
Remarks
This method builds multiple decision trees in parallel, each trained on a bootstrap sample of the training data and considering a random subset of features at each split. The steps are:
1. Clear any existing trees
2. Calculate the number of features to consider at each split
3. For each tree:
   a. Generate a bootstrap sample of the training data
   b. Create a new decision tree with the specified options
   c. Train the tree on the bootstrap sample
4. Calculate feature importances by averaging across all trees
For Beginners: Training is the process where the model learns from your data. The algorithm builds multiple decision trees, each on a slightly different version of your data (created by random sampling with replacement). Each tree also considers only a random subset of features at each split, which helps to make the trees more diverse. By building many diverse trees and combining their predictions, the model can capture complex relationships and provide more robust predictions than a single tree.
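The bootstrap sampling in step 3a can be sketched as follows. This is a standalone illustration of sampling with replacement, not the library's internal code:

```csharp
using System;

var rng = new Random();
int n = 8; // number of training rows

// A bootstrap sample draws n row indices with replacement from the
// n training rows, so some rows appear multiple times and roughly a
// third are left out entirely.
var bootstrapRows = new int[n];
for (int i = 0; i < n; i++)
    bootstrapRows[i] = rng.Next(n);

// Each tree trains on x and y restricted to its own bootstrapRows;
// drawing a fresh sample per tree is what makes the trees diverse.
```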