Class DecisionTreeRegressionBase<T>
- Namespace
- AiDotNet.Regression
- Assembly
- AiDotNet.dll
Provides a base implementation for decision tree regression models that predict continuous values.
public abstract class DecisionTreeRegressionBase<T> : ITreeBasedRegression<T>, INonLinearRegression<T>, IRegression<T>, IFullModel<T, Matrix<T>, Vector<T>>, IModel<Matrix<T>, Vector<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Matrix<T>, Vector<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Matrix<T>, Vector<T>>>, IGradientComputable<T, Matrix<T>, Vector<T>>, IJitCompilable<T>
Type Parameters
T: The numeric type used for calculations, typically float or double.
- Inheritance
- object → DecisionTreeRegressionBase<T>
- Implements
- IRegression<T> (the complete interface list appears in the class signature above)
Remarks
This abstract class implements common functionality for decision tree regression models, providing a framework for building predictive models based on decision trees. It manages the tree structure, handles serialization and deserialization, and defines the interface that concrete implementations must support.
For Beginners: This is a template for creating decision tree models that predict numerical values.
A decision tree works like a flowchart of yes/no questions to make predictions:
- Start at the top (root) of the tree
- At each step, answer a question about your data
- Follow the appropriate path based on your answer
- Continue until you reach an endpoint that provides a prediction
This base class provides the common structure and behaviors that all decision tree models share, while allowing specific implementations to customize how the tree is built and used.
Think of it like a blueprint for building different types of decision trees, where specific implementations can fill in the details according to their requirements.
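To make the blueprint concrete, the sketch below shows the skeleton a subclass must fill in. It is illustrative only: the class name is hypothetical and the actual splitting logic is elided; the member set is taken from the abstract methods documented later on this page.
Example:
// Hypothetical minimal subclass; only the abstract members are shown.
public class SimpleDecisionTreeRegression<T> : DecisionTreeRegressionBase<T>
{
    public SimpleDecisionTreeRegression(DecisionTreeOptions? options = null)
        : base(options, regularization: null) { }

    public override void Train(Matrix<T> x, Vector<T> y)
    {
        // Recursively build the tree from the training data and store it in Root.
    }

    public override Vector<T> Predict(Matrix<T> input)
    {
        // Walk each input row from Root to a leaf and return the leaf values.
        throw new NotImplementedException();
    }

    protected override void CalculateFeatureImportances(int featureCount)
    {
        // Accumulate per-feature split gains, normalize so they sum to 1,
        // and assign the result to FeatureImportances.
    }

    protected override IFullModel<T, Matrix<T>, Vector<T>> CreateNewInstance()
        => new SimpleDecisionTreeRegression<T>(Options);

    public override ModelMetadata<T> GetModelMetadata()
    {
        // Return metadata describing this model type and its configuration.
        throw new NotImplementedException();
    }
}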
Constructors
DecisionTreeRegressionBase(DecisionTreeOptions?, IRegularization<T, Matrix<T>, Vector<T>>?, ILossFunction<T>?)
Initializes a new instance of the DecisionTreeRegressionBase<T> class.
protected DecisionTreeRegressionBase(DecisionTreeOptions? options, IRegularization<T, Matrix<T>, Vector<T>>? regularization, ILossFunction<T>? lossFunction = null)
Parameters
options (DecisionTreeOptions): Optional configuration options for the decision tree algorithm.
regularization (IRegularization<T, Matrix<T>, Vector<T>>): Optional regularization strategy to prevent overfitting.
lossFunction (ILossFunction<T>): Loss function for gradient computation. If null, defaults to Mean Squared Error.
Remarks
This constructor initializes a new base class for decision tree regression with the specified options and regularization strategy. If no options are provided, default values are used. If no regularization is specified, no regularization is applied.
For Beginners: This sets up the foundation for a decision tree model.
When creating a decision tree, you can specify three main things:
- Options: Controls how the tree grows (like its maximum depth or minimum samples needed to split)
- Regularization: Helps prevent the model from becoming too complex and "memorizing" the training data
- Loss Function: Determines how prediction errors are measured (defaults to Mean Squared Error)
If you don't specify these parameters, the model will use reasonable default settings.
This constructor is typically not called directly but is used by specific implementations of decision tree models.
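As an illustration, a concrete tree might be configured as follows. The option property names MaxDepth and MinSamplesSplit are assumptions based on the remarks above, and MyDecisionTreeRegression stands in for any concrete subclass.
Example:
// Hypothetical configuration; property names are assumed from the remarks above.
var options = new DecisionTreeOptions
{
    MaxDepth = 5,          // limit nesting of questions to 5 levels
    MinSamplesSplit = 10   // require at least 10 samples before splitting a node
};
var tree = new MyDecisionTreeRegression<double>(options);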
Fields
NumOps
Provides operations for performing numeric calculations appropriate for the type T.
protected readonly INumericOperations<T> NumOps
Field Value
- INumericOperations<T>
Remarks
This field contains implementations of basic mathematical operations (addition, subtraction, etc.) that work with the specific numeric type T. It allows the decision tree algorithm to perform calculations independently of the specific numeric type being used.
Root
The root node of the decision tree.
protected DecisionTreeNode<T>? Root
Field Value
Remarks
This field stores the root node of the decision tree structure. All predictions start from this node and follow a path through the tree based on the feature values of the input sample.
Properties
DefaultLossFunction
Gets the default loss function used by this model for gradient computation.
public virtual ILossFunction<T> DefaultLossFunction { get; }
Property Value
Remarks
This loss function is used when calling ComputeGradients(Matrix<T>, Vector<T>, ILossFunction<T>?) without explicitly providing a loss function. It represents the model's primary training objective.
For Beginners: The loss function tells the model "what counts as a mistake". For example:
- For regression (predicting numbers): Mean Squared Error measures how far predictions are from actual values
- For classification (predicting categories): Cross Entropy measures how confident the model is in the right category
This property provides a sensible default so you don't have to specify the loss function every time, but you can still override it if needed for special cases.
Distributed Training: In distributed training, all workers use the same loss function to ensure consistent gradient computation. The default loss function is automatically used when workers compute local gradients.
Exceptions
- InvalidOperationException
Thrown if accessed before the model has been configured with a loss function.
Engine
Gets the global execution engine for vector operations.
protected IEngine Engine { get; }
Property Value
- IEngine
FeatureImportances
Gets the importance scores for each feature used in the model.
public Vector<T> FeatureImportances { get; protected set; }
Property Value
- Vector<T>
A vector of values representing the relative importance of each feature, normalized to sum to 1.
Remarks
This property provides access to the calculated importance of each feature in the trained model. Feature importance scores indicate how useful or valuable each feature was in building the decision tree. Higher values indicate features that were more important for making predictions.
For Beginners: This property tells you which input features have the biggest impact on predictions.
Feature importance helps you understand:
- Which factors matter most for your predictions
- Which features might be redundant or irrelevant
- How the model is making its decisions
For example, when predicting house prices:
- Location might have importance 0.7 (very important)
- Square footage might have importance 0.2 (somewhat important)
- Year built might have importance 0.1 (less important)
These values always add up to 1, making it easy to compare the relative importance of different features.
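Example (a sketch assuming a trained model and that Vector<T> exposes a Length property and an indexer):
// Inspect which features drive the predictions of a trained model.
Vector<double> importances = model.FeatureImportances;
for (int i = 0; i < importances.Length; i++)
{
    Console.WriteLine($"Feature {i}: importance {importances[i]:F3}");
}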
FeatureNames
Gets or sets the feature names.
public string[]? FeatureNames { get; set; }
Property Value
- string[]
An array of feature names. If not set, feature indices will be used as names.
MaxDepth
Gets the maximum depth of the decision tree.
public int MaxDepth { get; }
Property Value
- int
The maximum number of levels in the tree, from the root to the deepest leaf.
Remarks
This property returns the maximum depth of the decision tree, which is one of the most important parameters for controlling the complexity of the model. Deeper trees can capture more complex patterns but are more prone to overfitting.
For Beginners: This property tells you how many levels of questions the tree can ask.
Think of MaxDepth as the maximum number of questions that can be asked before making a prediction:
- A tree with MaxDepth = 1 can only ask one question (very simple model)
- A tree with MaxDepth = 10 can ask up to 10 nested questions (more complex model)
Setting an appropriate maximum depth helps prevent the model from becoming too complex:
- Too shallow (small MaxDepth): The model might be too simple to capture important patterns
- Too deep (large MaxDepth): The model might "memorize" the training data instead of learning general patterns, making it perform poorly on new data
NumberOfTrees
Gets the number of trees in this model, which is always 1 for a single decision tree.
public virtual int NumberOfTrees { get; }
Property Value
- int
The number of trees in the model, which is 1 for a standard decision tree implementation.
Remarks
This property returns the number of decision trees used in the model. For standard decision tree implementations, this is always 1. This property exists primarily for compatibility with ensemble methods that may use multiple trees.
For Beginners: This property indicates how many trees make up this model.
A single decision tree model always returns 1 here.
Some more advanced models (like Random Forests or Gradient Boosting) use multiple trees working together to make better predictions. In those cases, this property would return the number of trees in the ensemble.
Options
Gets the configuration options used by the decision tree algorithm.
protected DecisionTreeOptions Options { get; }
Property Value
Remarks
This property provides access to the configuration options that control how the decision tree is built, such as maximum depth, minimum samples required for splitting, and split criteria.
ParameterCount
Gets the number of parameters in the model.
public virtual int ParameterCount { get; }
Property Value
Remarks
This property returns the total count of trainable parameters in the model. It's useful for understanding model complexity and memory requirements.
Regularization
Gets the regularization strategy applied to the model to prevent overfitting.
protected IRegularization<T, Matrix<T>, Vector<T>> Regularization { get; }
Property Value
- IRegularization<T, Matrix<T>, Vector<T>>
Remarks
This property provides access to the regularization strategy used to prevent the model from overfitting to the training data. Regularization helps improve the model's ability to generalize to new, unseen data.
SoftTreeTemperature
Gets or sets the temperature parameter for soft decision tree mode.
public T SoftTreeTemperature { get; set; }
Property Value
- T
The temperature for sigmoid gating. Lower values produce sharper decisions. Default is 1.0.
SupportsJitCompilation
Gets a value indicating whether this model supports JIT (Just-In-Time) compilation.
public virtual bool SupportsJitCompilation { get; }
Property Value
Remarks
When UseSoftTree is enabled, the decision tree can be exported as a differentiable computation graph using soft (sigmoid-based) gating. This enables JIT compilation for optimized inference.
When UseSoftTree is disabled, JIT compilation is not supported because traditional hard decision trees use branching logic that cannot be represented as a static computation graph.
UseSoftTree
Gets or sets whether to use soft (differentiable) tree mode for JIT compilation.
public bool UseSoftTree { get; set; }
Property Value
- bool
true to enable soft tree mode with sigmoid gating for JIT support; false (default) for a traditional hard decision tree.
Remarks
Soft Decision Trees: Instead of hard branching (if-then-else), soft trees use sigmoid gating to compute a smooth probability of going left or right at each node. This makes the tree differentiable and JIT-compilable.
Formula: p_left = σ((threshold - x[feature]) / temperature)
Output: weighted_output = p_left * left_value + (1 - p_left) * right_value
Trade-offs:
- Soft trees are differentiable and JIT-compilable
- Results are smooth approximations of hard decisions
- Lower temperature = sharper (closer to hard) decisions
- Higher temperature = softer (more averaged) decisions
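To make the gating concrete, here is the soft-split arithmetic for a single node written out in plain C# (the Sigmoid helper is local to the example, not a library call):
Example:
double Sigmoid(double z) => 1.0 / (1.0 + Math.Exp(-z));

double feature = 3.0, threshold = 5.0, temperature = 1.0;
double pLeft = Sigmoid((threshold - feature) / temperature);  // ≈ 0.88, mostly "left"

double leftValue = 10.0, rightValue = 20.0;
double output = pLeft * leftValue + (1 - pLeft) * rightValue; // ≈ 11.2

// Lowering the temperature sharpens the decision toward the hard tree:
double pLeftSharp = Sigmoid((threshold - feature) / 0.1);     // ≈ 1.0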
Methods
ApplyGradients(Vector<T>, T)
Applies pre-computed gradients to update the model parameters.
public virtual void ApplyGradients(Vector<T> gradients, T learningRate)
Parameters
gradients (Vector<T>): The gradient vector to apply.
learningRate (T): The learning rate for the update.
Remarks
Updates parameters using: θ = θ - learningRate * gradients
For Beginners: After computing gradients (seeing which direction to move), this method actually moves the model in that direction. The learning rate controls how big of a step to take.
Distributed Training: In DDP/ZeRO-2, this applies the synchronized (averaged) gradients after communication across workers. Each worker applies the same averaged gradients to keep parameters consistent.
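A minimal manual update step combines the two methods documented here (a sketch; the learning rate is arbitrary):
Example:
// One explicit gradient-descent step: compute gradients, then apply them.
Vector<double> gradients = model.ComputeGradients(x, y); // parameters unchanged here
model.ApplyGradients(gradients, learningRate: 0.01);     // θ = θ - 0.01 * gradients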
CalculateFeatureImportances(int)
Calculates the importance scores for all features used in the model.
protected abstract void CalculateFeatureImportances(int featureCount)
Parameters
featureCount (int): The number of features in the model.
Remarks
This abstract method must be implemented by derived classes to calculate the importance of each feature in the trained model. Feature importance indicates how valuable each feature was in building the decision tree.
For Beginners: This method figures out which input features matter most for predictions.
Different decision tree implementations might calculate feature importance in different ways, but the general idea is to measure how much each feature contributes to improving predictions when it's used in decision nodes throughout the tree.
After this method runs, the FeatureImportances property will contain a score for each feature, allowing you to see which features have the biggest impact on your model's predictions.
Clone()
Creates a clone of the decision tree model.
public virtual IFullModel<T, Matrix<T>, Vector<T>> Clone()
Returns
- IFullModel<T, Matrix<T>, Vector<T>>
A new instance of the model with the same parameters and tree structure.
Remarks
This method creates a complete copy of the decision tree model, including the entire tree structure and all learned parameters. Specific implementations should override this method to ensure all implementation-specific properties are properly copied.
For Beginners: This method creates an exact independent copy of your model.
Cloning a model means creating a new model that's exactly the same as the original, including all its learned parameters and settings. However, the clone is independent - changes to one model won't affect the other.
Think of it like photocopying a document - the copy has all the same information, but you can mark up the copy without changing the original.
Note: Specific decision tree algorithms will customize this method to ensure all their unique properties are properly copied.
ComputeGradients(Matrix<T>, Vector<T>, ILossFunction<T>?)
Computes gradients of the loss function with respect to model parameters for the given data, WITHOUT updating the model parameters.
public virtual Vector<T> ComputeGradients(Matrix<T> input, Vector<T> target, ILossFunction<T>? lossFunction = null)
Parameters
input (Matrix<T>): The input data.
target (Vector<T>): The target/expected output.
lossFunction (ILossFunction<T>): The loss function to use for gradient computation. If null, uses the model's default loss function.
Returns
- Vector<T>
A vector containing gradients with respect to all model parameters.
Remarks
This method performs a forward pass, computes the loss, and back-propagates to compute gradients, but does NOT update the model's parameters. The parameters remain unchanged after this call.
Distributed Training: In DDP/ZeRO-2, each worker calls this to compute local gradients on its data batch. These gradients are then synchronized (averaged) across workers before applying updates. This ensures all workers compute the same parameter updates despite having different data.
For Meta-Learning: After adapting a model on a support set, you can use this method to compute gradients on the query set. These gradients become the meta-gradients for updating the meta-parameters.
For Beginners: Think of this as "dry run" training:
- The model sees what direction it should move (the gradients)
- But it doesn't actually move (parameters stay the same)
- You get to decide what to do with this information (average with others, inspect, modify, etc.)
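As a sketch of the DDP pattern described above (the element-wise averaging and the Vector<double> array constructor are assumptions; a real setup would synchronize through a communication backend):
Example:
// Each worker computes gradients on its own batch; nothing is updated yet.
Vector<double> g1 = workerModel1.ComputeGradients(batch1X, batch1Y);
Vector<double> g2 = workerModel2.ComputeGradients(batch2X, batch2Y);

// Average element-wise so every worker applies the same update.
var averaged = new double[g1.Length];
for (int i = 0; i < g1.Length; i++)
    averaged[i] = (g1[i] + g2[i]) / 2.0;

var avg = new Vector<double>(averaged);
workerModel1.ApplyGradients(avg, learningRate: 0.01);
workerModel2.ApplyGradients(avg, learningRate: 0.01);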
Exceptions
- InvalidOperationException
If lossFunction is null and the model has no default loss function.
CreateNewInstance()
Creates a new instance of the decision tree model with the same options.
protected abstract IFullModel<T, Matrix<T>, Vector<T>> CreateNewInstance()
Returns
- IFullModel<T, Matrix<T>, Vector<T>>
A new instance of the model with the same configuration but no trained parameters.
Remarks
This abstract method must be implemented by derived classes to create a new instance of the specific decision tree model type with the same configuration options but without copying the trained parameters.
For Beginners: This method creates a fresh copy of the model configuration without any learned parameters.
Think of it like getting a blank notepad with the same paper quality and size, but without any writing on it yet. The new model has the same:
- Maximum depth setting
- Minimum samples split setting
- Split criterion (how nodes decide which feature to split on)
- Other configuration options
But it doesn't have any of the actual tree structure that was learned from data.
This is mainly used internally when doing things like cross-validation or creating ensembles of similar models with different training data.
DeepCopy()
Creates a deep copy of the decision tree model.
public virtual IFullModel<T, Matrix<T>, Vector<T>> DeepCopy()
Returns
- IFullModel<T, Matrix<T>, Vector<T>>
A new instance of the model with the same parameters and tree structure.
Remarks
This method creates a complete copy of the decision tree model, including all nodes, connections, and learned parameters. The copy is independent of the original model, so modifications to one don't affect the other.
For Beginners: This method creates an exact independent copy of your model.
The copy has the same:
- Tree structure (all nodes and their connections)
- Decision rules (which features to split on and at what values)
- Prediction values at leaf nodes
- Feature importance scores
But it's completely separate from the original model - changes to one won't affect the other.
This is useful when you want to:
- Experiment with modifying a model without affecting the original
- Create multiple similar models to use in different contexts
- Save a "checkpoint" of your model before making changes
Deserialize(byte[])
Loads a previously serialized decision tree model from a byte array.
public virtual void Deserialize(byte[] modelData)
Parameters
modelData (byte[]): The byte array containing the serialized model.
Remarks
This method reconstructs a decision tree model from a byte array that was previously created using the Serialize method. It restores the model's configuration options, feature importances, and tree structure, allowing the model to be used for predictions without retraining.
For Beginners: This method loads a previously saved model from a sequence of bytes.
Deserialization allows you to:
- Load a model that was saved earlier
- Use a model without having to retrain it
- Share models between different applications
When you deserialize a model:
- All settings are restored
- Feature importances are recovered
- The entire tree structure is reconstructed
- The model is ready to make predictions immediately
Example:
// Load from a file
byte[] modelData = File.ReadAllBytes("decisionTree.model");
// Deserialize the model
decisionTree.Deserialize(modelData);
// Now you can use the model for predictions
var predictions = decisionTree.Predict(newFeatures);
ExportComputationGraph(List<ComputationNode<T>>)
Exports the model's computation as a graph of operations.
public virtual ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
Parameters
inputNodes (List<ComputationNode<T>>): The input nodes for the computation graph.
Returns
- ComputationNode<T>
The root node of the exported computation graph.
Remarks
When soft tree mode is enabled, this exports the tree as a differentiable computation graph using SoftSplit(ComputationNode<T>, ComputationNode<T>, ComputationNode<T>, int, T, T?) operations. Each internal node becomes a soft split operation that computes sigmoid-weighted combinations of left and right subtree outputs.
Exceptions
- NotSupportedException
Thrown when UseSoftTree is false.
- InvalidOperationException
Thrown when the tree has not been trained (Root is null).
GetActiveFeatureIndices()
Gets the indices of all features that are used in the decision tree.
public virtual IEnumerable<int> GetActiveFeatureIndices()
Returns
- IEnumerable<int>
An enumerable collection of indices for features used in the tree.
Remarks
This method identifies all features that are used as split criteria in the decision tree. Features that don't appear in any decision node are not considered active.
For Beginners: This method tells you which input features are actually used in the tree.
Decision trees often don't use all available features - they select the most informative ones during training. This method returns the positions (indices) of features that are actually used in decision nodes throughout the tree.
For example, if your dataset has 10 features but the tree only uses features at positions 2, 5, and 7, this method would return [2, 5, 7].
This is useful for:
- Feature selection (identifying which features matter)
- Model simplification (removing unused features)
- Understanding which inputs actually affect the prediction
GetFeatureImportance()
Gets the feature importance scores as a dictionary.
public virtual Dictionary<string, T> GetFeatureImportance()
Returns
- Dictionary<string, T>
A dictionary mapping feature names to their importance scores.
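Example (a sketch; keys come from FeatureNames when set, otherwise feature indices are used as names):
Dictionary<string, double> scores = model.GetFeatureImportance();
foreach (var pair in scores)
{
    Console.WriteLine($"{pair.Key}: {pair.Value:F3}");
}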
GetModelMetadata()
Gets metadata about the decision tree model and its configuration.
public abstract ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
A ModelMetadata object containing information about the model.
Remarks
This abstract method must be implemented by derived classes to provide metadata about the model, including its type and configuration options. This information can be useful for model management, comparison, and documentation purposes.
For Beginners: This method provides information about your decision tree model.
The metadata typically includes:
- The type of model (e.g., Decision Tree, Random Forest)
- Configuration settings (like maximum depth)
- Other relevant information about the model
This information is helpful when:
- Comparing different models
- Documenting your model's configuration
- Troubleshooting model performance
Each implementation of a decision tree will provide its own version of this method, returning the specific metadata relevant to that implementation.
GetParameters()
Gets the model parameters as a vector representation.
public virtual Vector<T> GetParameters()
Returns
- Vector<T>
A vector containing a serialized representation of the decision tree structure.
Remarks
This method provides a vector representation of the decision tree model. Decision trees have a hierarchical structure that doesn't naturally fit into a flat vector format, so this representation is a simplified encoding of the tree structure suitable for certain optimization algorithms or model comparison techniques.
For Beginners: This method converts the tree structure into a flat list of numbers.
Decision trees are complex structures with branches and nodes, which don't naturally fit into a simple list of parameters like linear models do. This method creates a specialized representation of the tree that can be used by certain algorithms or for model comparison.
The exact format of this representation depends on the specific implementation, but generally includes information about:
- Each node's feature index (which feature it splits on)
- Each node's split value (the threshold for the decision)
- Each node's prediction value (for leaf nodes)
- The tree structure (how nodes connect to each other)
This is primarily used by advanced algorithms and not typically needed for regular use.
IsFeatureUsed(int)
Determines whether a specific feature is used in the decision tree.
public virtual bool IsFeatureUsed(int featureIndex)
Parameters
featureIndex (int): The zero-based index of the feature to check.
Returns
- bool
True if the feature is used in at least one decision node; otherwise, false.
Remarks
This method checks whether a specific feature is used as a split criterion in any node of the decision tree.
For Beginners: This method checks if a specific input feature is used in the tree.
You provide the position (index) of a feature, and the method tells you whether that feature is used in any decision node throughout the tree.
For example, if feature #3 is never used to make a decision in the tree, this method would return false because that feature doesn't affect the model's predictions.
This is useful when you want to check a specific feature's importance rather than getting all important features at once.
LoadModel(string)
Loads the model from a file.
public virtual void LoadModel(string filePath)
Parameters
filePath (string): The path to the file containing the saved model.
Remarks
This method provides a convenient way to load a model directly from disk. It combines file I/O operations with deserialization.
For Beginners: This is like clicking "Open" in a document editor. Instead of manually reading from a file and then calling Deserialize(), this method does both steps for you.
Exceptions
- FileNotFoundException
Thrown when the specified file does not exist.
- IOException
Thrown when an I/O error occurs while reading from the file or when the file contains corrupted or invalid model data.
LoadState(Stream)
Loads the model's state from a stream.
public virtual void LoadState(Stream stream)
Parameters
stream (Stream): The stream from which to load the model's state.
Predict(Matrix<T>)
Predicts target values for the provided input features using the trained decision tree model.
public abstract Vector<T> Predict(Matrix<T> input)
Parameters
input (Matrix<T>): A matrix where each row represents a sample to predict and each column represents a feature.
Returns
- Vector<T>
A vector of predicted values corresponding to each input sample.
Remarks
This abstract method must be implemented by derived classes to predict target values for new input data using the trained decision tree model. The specific algorithm for making predictions is defined by the implementation.
For Beginners: This method uses your trained model to make predictions on new data.
You provide:
- input: New data points for which you want predictions
The model will use the decision tree it learned during training to predict values for each row of input data. The way it navigates the tree to make predictions will depend on the specific implementation of the decision tree model.
For example, if you trained the model to predict house prices, you could use this method to predict prices for a new set of houses based on their features.
SaveModel(string)
Saves the model to a file.
public virtual void SaveModel(string filePath)
Parameters
filePath (string): The path where the model should be saved.
Remarks
This method provides a convenient way to save the model directly to disk. It combines serialization with file I/O operations.
For Beginners: This is like clicking "Save As" in a document editor. Instead of manually calling Serialize() and then writing to a file, this method does both steps for you.
Exceptions
- IOException
Thrown when an I/O error occurs while writing to the file.
- UnauthorizedAccessException
Thrown when the caller does not have the required permission to write to the specified file path.
SaveState(Stream)
Saves the model's current state to a stream.
public virtual void SaveState(Stream stream)
Parameters
stream (Stream): The stream to which the model's state is written.
Serialize()
Serializes the decision tree model to a byte array for storage or transmission.
public virtual byte[] Serialize()
Returns
- byte[]
A byte array containing the serialized model.
Remarks
This method converts the decision tree model into a byte array that can be stored in a file, database, or transmitted over a network. The serialized data includes the model's configuration options, feature importances, and the complete tree structure.
For Beginners: This method saves your trained model as a sequence of bytes.
Serialization allows you to:
- Save your model to a file
- Store your model in a database
- Send your model over a network
- Keep your model for later use without having to retrain it
The serialized data includes:
- All the model's settings (like maximum depth)
- The importance of each feature
- The entire tree structure with all its decision rules
Example:
// Serialize the model
byte[] modelData = decisionTree.Serialize();
// Save to a file
File.WriteAllBytes("decisionTree.model", modelData);
SetActiveFeatureIndices(IEnumerable<int>)
Sets the active feature indices for this model.
public virtual void SetActiveFeatureIndices(IEnumerable<int> featureIndices)
Parameters
featureIndices (IEnumerable<int>): The indices of features to activate.
SetParameters(Vector<T>)
Sets the parameters for this model.
public virtual void SetParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): A vector containing the model parameters.
Train(Matrix<T>, Vector<T>)
Trains the decision tree model using the provided input features and target values.
public abstract void Train(Matrix<T> x, Vector<T> y)
Parameters
x (Matrix<T>): A matrix where each row represents a sample and each column represents a feature.
y (Vector<T>): A vector of target values corresponding to each sample in x.
Remarks
This abstract method must be implemented by derived classes to build a decision tree model using the provided training data. The specific algorithm for building the tree is defined by the implementation.
For Beginners: This method teaches the decision tree how to make predictions.
You provide:
- x: Your input data (features) - like house size, number of bedrooms, location, etc.
- y: The values you want to predict - like house prices
Each specific implementation of a decision tree will provide its own version of this method, which defines exactly how the tree learns from your data.
After training, the model will be ready to make predictions on new data.
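Example (a sketch for a hypothetical concrete subclass; the Matrix and Vector constructors shown are assumptions):
// Features per row: [square meters, bedrooms]; targets: prices.
var x = new Matrix<double>(new double[,] { { 120, 3 }, { 80, 2 }, { 200, 4 } });
var y = new Vector<double>(new double[] { 300_000, 180_000, 550_000 });

model.Train(x, y);                   // learn the tree from the data
var predictions = model.Predict(x);  // the trained model is ready to predict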
WithParameters(Vector<T>)
Creates a new instance of the model with the specified parameters.
public virtual IFullModel<T, Matrix<T>, Vector<T>> WithParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): A vector containing a serialized representation of the decision tree structure.
Returns
- IFullModel<T, Matrix<T>, Vector<T>>
A new model instance with the reconstructed tree structure.
Remarks
This method reconstructs a decision tree model from a parameter vector that was previously created using the GetParameters method. Due to the complex nature of tree structures, this reconstruction is approximate and is primarily intended for use with optimization algorithms or model comparison techniques.
For Beginners: This method rebuilds a decision tree from a flat list of numbers.
It takes the specialized vector representation created by GetParameters() and attempts to reconstruct a decision tree from it. This is challenging because decision trees are complex structures that don't easily convert to and from simple lists of numbers.
This method is primarily used by advanced algorithms and not typically needed for regular use. For most purposes, the Serialize and Deserialize methods provide a more reliable way to save and load tree models.
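As a sketch of the round trip (the reconstruction is approximate, so prefer Serialize and Deserialize for faithful persistence):
Example:
// Flatten the trained tree, then rebuild a new model from the vector.
Vector<double> parameters = trainedTree.GetParameters();
IFullModel<double, Matrix<double>, Vector<double>> rebuilt =
    trainedTree.WithParameters(parameters);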