Class AsyncDecisionTreeRegressionBase<T>
Namespace: AiDotNet.Regression
Assembly: AiDotNet.dll
Represents an abstract base class for asynchronous decision tree regression models.
public abstract class AsyncDecisionTreeRegressionBase<T> : IAsyncTreeBasedModel<T>, ITreeBasedRegression<T>, INonLinearRegression<T>, IRegression<T>, IFullModel<T, Matrix<T>, Vector<T>>, IModel<Matrix<T>, Vector<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Matrix<T>, Vector<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Matrix<T>, Vector<T>>>, IGradientComputable<T, Matrix<T>, Vector<T>>, IJitCompilable<T>
Type Parameters
T: The numeric type used for calculations, typically float or double.
Inheritance
- AsyncDecisionTreeRegressionBase<T>
Implements
- IRegression<T>
Remarks
This class provides a foundation for implementing decision tree regression models that can be trained and used for predictions asynchronously. It includes methods for training, prediction, serialization, and deserialization of the model.
For Beginners: A decision tree is a type of machine learning model that makes predictions by following a series of yes/no questions about the input data. It's like a flowchart that helps the computer decide what prediction to make.
For example, if you're trying to predict if it will rain:
- Is the humidity high? If yes, go to next question. If no, predict no rain.
- Are there clouds? If yes, predict rain. If no, predict no rain.
This class provides the basic structure for building these types of models, but with more complex questions and answers based on numerical data.
Constructors
AsyncDecisionTreeRegressionBase(DecisionTreeOptions?, IRegularization<T, Matrix<T>, Vector<T>>?, ILossFunction<T>?)
Initializes a new instance of the AsyncDecisionTreeRegressionBase class.
protected AsyncDecisionTreeRegressionBase(DecisionTreeOptions? options, IRegularization<T, Matrix<T>, Vector<T>>? regularization, ILossFunction<T>? lossFunction = null)
Parameters
options (DecisionTreeOptions?): The options for configuring the decision tree.
regularization (IRegularization<T, Matrix<T>, Vector<T>>?): The regularization method to use.
lossFunction (ILossFunction<T>?): The loss function for gradient computation. If null, defaults to Mean Squared Error.
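Because this constructor is protected, it is only invoked from a derived class. The sketch below shows how a hypothetical subclass (MyAsyncTreeRegression, not part of AiDotNet) might call it and stub out the abstract members documented on this page; a real implementation may have additional requirements.

// Hypothetical subclass; the stubbed bodies are placeholders only.
public class MyAsyncTreeRegression<T> : AsyncDecisionTreeRegressionBase<T>
{
    public MyAsyncTreeRegression(DecisionTreeOptions? options = null)
        : base(options, regularization: null, lossFunction: null) // null loss => Mean Squared Error
    {
    }

    public override Task TrainAsync(Matrix<T> x, Vector<T> y)
        => throw new NotImplementedException();

    public override Task<Vector<T>> PredictAsync(Matrix<T> input)
        => throw new NotImplementedException();

    protected override Task CalculateFeatureImportancesAsync(int featureCount)
        => throw new NotImplementedException();

    protected override IFullModel<T, Matrix<T>, Vector<T>> CreateNewInstance()
        => throw new NotImplementedException();

    public override ModelMetadata<T> GetModelMetadata()
        => throw new NotImplementedException();
}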
Fields
NumOps
Gets the numeric operations for the type T.
protected readonly INumericOperations<T> NumOps
Field Value
- INumericOperations<T>
Root
Gets or sets the root node of the decision tree.
protected DecisionTreeNode<T>? Root
Field Value
- DecisionTreeNode<T>?
Properties
DefaultLossFunction
Gets the default loss function used by this model for gradient computation.
public virtual ILossFunction<T> DefaultLossFunction { get; }
Property Value
- ILossFunction<T>
Remarks
This loss function is used when calling ComputeGradients(Matrix<T>, Vector<T>, ILossFunction<T>?) without explicitly providing a loss function. It represents the model's primary training objective.
For Beginners: The loss function tells the model "what counts as a mistake". For example:
- For regression (predicting numbers): Mean Squared Error measures how far predictions are from actual values.
- For classification (predicting categories): Cross Entropy measures how confident the model is in the right category.
This property provides a sensible default so you don't have to specify the loss function every time, but you can still override it if needed for special cases.
Distributed Training: In distributed training, all workers use the same loss function to ensure consistent gradient computation. The default loss function is automatically used when workers compute local gradients.
Exceptions
- InvalidOperationException
Thrown if accessed before the model has been configured with a loss function.
Engine
Gets the global execution engine for vector operations.
protected IEngine Engine { get; }
Property Value
- IEngine
FeatureImportances
Gets or sets the importance of each feature in making predictions.
public Vector<T> FeatureImportances { get; protected set; }
Property Value
- Vector<T>
Remarks
For Beginners: Feature importance tells you how much each input variable (feature) contributes to the model's predictions. Higher values mean that feature is more important.
For example, if predicting house prices:
- Location might have high importance (big impact on price)
- Wall color might have low importance (small impact on price)
FeatureNames
Gets or sets the feature names.
public string[]? FeatureNames { get; set; }
Property Value
- string[]?
An array of feature names. If not set, feature indices will be used as names.
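For example, to report each importance score next to its name (a sketch that assumes a trained model and that Vector<T> exposes a Length property and an indexer):

model.FeatureNames = new[] { "size", "location", "age" };
for (int i = 0; i < model.FeatureImportances.Length; i++)
{
    // Fall back to the feature index when no name is available.
    string name = model.FeatureNames?[i] ?? i.ToString();
    Console.WriteLine($"{name}: {model.FeatureImportances[i]}");
}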
MaxDepth
Gets the maximum depth of the decision tree.
public virtual int MaxDepth { get; }
Property Value
- int
Remarks
For Beginners: The maximum depth is the longest path from the root of the tree to a leaf. A deeper tree can capture more complex patterns but may also overfit to the training data.
Imagine a game of "20 Questions":
- A shallow tree is like having only 5 questions (less detailed, but quicker).
- A deep tree is like having all 20 questions (more detailed, but might be too specific).
NumberOfTrees
Gets the number of trees in the model. For a single decision tree, this is always 1.
public virtual int NumberOfTrees { get; }
Property Value
- int
Options
Gets the options used to configure the decision tree.
protected DecisionTreeOptions Options { get; }
Property Value
- DecisionTreeOptions
ParameterCount
Gets the number of parameters in the model.
public virtual int ParameterCount { get; }
Property Value
- int
Remarks
This property returns the total count of trainable parameters in the model. It's useful for understanding model complexity and memory requirements.
Random
Random number generator used for tree building and sampling.
protected Random Random { get; }
Property Value
- Random
Regularization
Gets the regularization method used to prevent overfitting.
protected IRegularization<T, Matrix<T>, Vector<T>> Regularization { get; }
Property Value
- IRegularization<T, Matrix<T>, Vector<T>>
Remarks
For Beginners: Regularization is a technique used to prevent the model from becoming too complex and fitting the training data too closely. This helps the model generalize better to new, unseen data.
Think of it like learning to ride a bike:
- Without regularization, you might only learn to ride on one specific path.
- With regularization, you learn general bike-riding skills that work on many different paths.
SoftTreeTemperature
Gets or sets the temperature parameter for soft decision tree mode.
public T SoftTreeTemperature { get; set; }
Property Value
- T
The temperature for sigmoid gating. Lower values produce sharper decisions. Default is 1.0.
Remarks
Only used when UseSoftTree is enabled. Controls the smoothness of the soft split operations:
- Lower temperature (e.g., 0.1) = sharper, more discrete decisions
- Higher temperature (e.g., 10.0) = softer, more blended decisions
SupportsJitCompilation
Gets whether this model currently supports JIT compilation.
public virtual bool SupportsJitCompilation { get; }
Property Value
- bool
true when UseSoftTree is enabled and the tree has been trained; false otherwise.
Remarks
When UseSoftTree is enabled, the decision tree can be exported as a differentiable computation graph using soft (sigmoid-based) gating. This enables JIT compilation for optimized inference.
When UseSoftTree is disabled, JIT compilation is not supported because traditional hard decision trees use branching logic that cannot be represented as a static computation graph.
For Beginners: JIT compilation is available when soft tree mode is enabled.
In soft tree mode, the discrete if-then decisions are replaced with smooth sigmoid functions that can be compiled into an optimized computation graph. This gives you the interpretability of decision trees with the speed of JIT-compiled models.
UseSoftTree
Gets or sets whether to use soft (differentiable) tree mode for JIT compilation support.
public bool UseSoftTree { get; set; }
Property Value
- bool
true to enable soft tree mode; false (default) for traditional hard decision trees.
Remarks
When enabled, the decision tree uses sigmoid-based soft gating instead of hard if-then splits. This makes the tree differentiable and enables JIT compilation support.
Formula at each split: output = g * left + (1 - g) * right, where g = σ((threshold - x[feature]) / temperature) and σ is the sigmoid function.
For Beginners: Soft tree mode allows the decision tree to be JIT compiled for faster inference.
Traditional decision trees make hard yes/no decisions:
- "If feature > 5, go LEFT, otherwise go RIGHT"
Soft trees use smooth transitions instead:
- Near the boundary, the output blends both left and right paths
- This creates a smooth, differentiable function that can be JIT compiled
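A sketch of the intended workflow, assuming model is a trained concrete subclass with T = double (whether soft mode must be enabled before or after training is implementation-specific):

model.UseSoftTree = true;          // switch to sigmoid-based soft gating
model.SoftTreeTemperature = 0.5;   // lower = sharper, more tree-like splits

if (model.SupportsJitCompilation)
{
    var inputNodes = new List<ComputationNode<double>>();
    ComputationNode<double> graphRoot = model.ExportComputationGraph(inputNodes);
    // graphRoot can now be handed to the JIT compiler for optimized inference.
}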
Methods
ApplyGradients(Vector<T>, T)
Applies pre-computed gradients to update the model parameters.
public virtual void ApplyGradients(Vector<T> gradients, T learningRate)
Parameters
gradients (Vector<T>): The gradient vector to apply.
learningRate (T): The learning rate for the update.
Remarks
Updates parameters using: θ = θ - learningRate * gradients
For Beginners: After computing gradients (seeing which direction to move), this method actually moves the model in that direction. The learning rate controls how big of a step to take.
Distributed Training: In DDP/ZeRO-2, this applies the synchronized (averaged) gradients after communication across workers. Each worker applies the same averaged gradients to keep parameters consistent.
CalculateFeatureImportancesAsync(int)
Asynchronously calculates the importance of each feature in the model.
protected abstract Task CalculateFeatureImportancesAsync(int featureCount)
Parameters
featureCount (int): The number of features in the input data.
Returns
- Task
A task representing the asynchronous operation.
Remarks
For Beginners: This method figures out how much each input feature contributes to the model's predictions. It helps you understand which pieces of information are most useful for making accurate predictions.
For example, in predicting house prices:
- It might find that location is very important
- While the house's paint color is less important
This can help you focus on collecting the most relevant data for future predictions.
Clone()
Creates a clone of the decision tree model.
public virtual IFullModel<T, Matrix<T>, Vector<T>> Clone()
Returns
- IFullModel<T, Matrix<T>, Vector<T>>
A new instance of the model with the same parameters and tree structure.
Remarks
This method creates a complete copy of the asynchronous decision tree model, including the entire tree structure and all learned parameters. Specific implementations should override this method to ensure all implementation-specific properties are properly copied.
For Beginners: This method creates an exact independent copy of your model.
Cloning a model means creating a new model that's exactly the same as the original, including all its learned parameters and settings. However, the clone is independent - changes to one model won't affect the other.
Think of it like photocopying a document - the copy has all the same information, but you can mark up the copy without changing the original.
Note: Specific decision tree algorithms will customize this method to ensure all their unique properties are properly copied.
ComputeGradients(Matrix<T>, Vector<T>, ILossFunction<T>?)
Computes gradients of the loss function with respect to model parameters for the given data, WITHOUT updating the model parameters.
public virtual Vector<T> ComputeGradients(Matrix<T> input, Vector<T> target, ILossFunction<T>? lossFunction = null)
Parameters
input (Matrix<T>): The input data.
target (Vector<T>): The target/expected output.
lossFunction (ILossFunction<T>?): The loss function to use for gradient computation. If null, uses the model's default loss function.
Returns
- Vector<T>
A vector containing gradients with respect to all model parameters.
Remarks
This method performs a forward pass, computes the loss, and back-propagates to compute gradients, but does NOT update the model's parameters. The parameters remain unchanged after this call.
Distributed Training: In DDP/ZeRO-2, each worker calls this to compute local gradients on its data batch. These gradients are then synchronized (averaged) across workers before applying updates. This ensures all workers compute the same parameter updates despite having different data.
For Meta-Learning: After adapting a model on a support set, you can use this method to compute gradients on the query set. These gradients become the meta-gradients for updating the meta-parameters.
For Beginners: Think of this as "dry run" training:
- The model sees what direction it should move (the gradients)
- But it doesn't actually move (parameters stay the same)
- You get to decide what to do with this information (average with others, inspect, modify, etc.)
Exceptions
- InvalidOperationException
If lossFunction is null and the model has no default loss function.
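A sketch of the compute-then-apply pattern described above, for T = double; inputs and targets are assumed to be a prepared Matrix<double> and Vector<double>, and the optional third argument is omitted so DefaultLossFunction is used:

// 1. Dry run: compute gradients without changing the model.
Vector<double> grads = model.ComputeGradients(inputs, targets);

// 2. (Distributed training only) average grads across workers here.

// 3. Apply the (possibly averaged) gradients with a chosen learning rate.
model.ApplyGradients(grads, learningRate: 0.01);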
CreateNewInstance()
Creates a new instance of the async decision tree model with the same options.
protected abstract IFullModel<T, Matrix<T>, Vector<T>> CreateNewInstance()
Returns
- IFullModel<T, Matrix<T>, Vector<T>>
A new instance of the model with the same configuration but no trained parameters.
Remarks
This abstract method must be implemented by derived classes to create a new instance of the specific decision tree model type with the same configuration options but without copying the trained parameters.
For Beginners: This method creates a fresh copy of the model configuration without any learned parameters.
Think of it like getting a blank notepad with the same paper quality and size, but without any writing on it yet. The new model has the same:
- Maximum depth setting
- Minimum samples split setting
- Split criterion (how nodes decide which feature to split on)
- Other configuration options
But it doesn't have any of the actual tree structure that was learned from data.
This is mainly used internally when doing things like cross-validation or creating ensembles of similar models with different training data.
DeepCopy()
Creates a deep copy of the decision tree model.
public virtual IFullModel<T, Matrix<T>, Vector<T>> DeepCopy()
Returns
- IFullModel<T, Matrix<T>, Vector<T>>
A new instance of the model with the same parameters and tree structure.
Remarks
This method creates a complete copy of the asynchronous decision tree model, including all nodes, connections, and learned parameters. The copy is independent of the original model, so modifications to one don't affect the other.
For Beginners: This method creates an exact independent copy of your model.
The copy has the same:
- Tree structure (all nodes and their connections)
- Decision rules (which features to split on and at what values)
- Prediction values at leaf nodes
- Feature importance scores
But it's completely separate from the original model - changes to one won't affect the other.
This is useful when you want to:
- Experiment with modifying a model without affecting the original
- Create multiple similar models to use in different contexts
- Save a "checkpoint" of your model before making changes
Deserialize(byte[])
Deserializes the model from a byte array.
public virtual void Deserialize(byte[] modelData)
Parameters
modelData (byte[]): The byte array containing the serialized model data.
Remarks
For Beginners: Deserialization is the process of reconstructing the model from saved data. It's like unpacking the model from the suitcase you packed it in earlier.
This method:
- Reads the saved settings and restores them
- Rebuilds the decision tree structure
- Sets up the feature importances
After calling this method, your model will be ready to use for making predictions, just like it was before you serialized it.
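A byte-array round trip might look like this sketch, where MyAsyncTreeRegression is the hypothetical subclass from the constructor example:

byte[] data = trainedModel.Serialize();
File.WriteAllBytes("tree.bin", data);

// Later, restore into a fresh instance of the same concrete type:
var restored = new MyAsyncTreeRegression<double>();
restored.Deserialize(File.ReadAllBytes("tree.bin"));
// restored is now ready to make predictions without retraining.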
ExportComputationGraph(List<ComputationNode<T>>)
Exports the model's computation graph for JIT compilation.
public virtual ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
Parameters
inputNodes (List<ComputationNode<T>>): The list to populate with input computation nodes.
Returns
- ComputationNode<T>
The root node of the exported computation graph.
Remarks
When soft tree mode is enabled, this exports the tree as a differentiable computation graph using SoftSplit(ComputationNode<T>, ComputationNode<T>, ComputationNode<T>, int, T, T?) operations. Each internal node becomes a soft split operation that computes sigmoid-weighted combinations of left and right subtree outputs.
For Beginners: This method converts the decision tree into a computation graph.
In soft tree mode, each decision node becomes a smooth blend:
- Instead of "go left OR right", it computes "X% left + Y% right"
- The percentages are determined by the sigmoid function
- This creates a smooth, differentiable function that can be JIT compiled
Exceptions
- NotSupportedException
Thrown when UseSoftTree is false.
- InvalidOperationException
Thrown when the tree has not been trained (Root is null).
GetActiveFeatureIndices()
Gets the indices of all features that are used in the decision tree.
public virtual IEnumerable<int> GetActiveFeatureIndices()
Returns
- IEnumerable<int>
An enumerable collection of indices for features used in the tree.
Remarks
This method identifies all features that are used as split criteria in the decision tree. Features that don't appear in any decision node are not considered active.
For Beginners: This method tells you which input features are actually used in the tree.
Decision trees often don't use all available features - they select the most informative ones during training. This method returns the positions (indices) of features that are actually used in decision nodes throughout the tree.
For example, if your dataset has 10 features but the tree only uses features at positions 2, 5, and 7, this method would return [2, 5, 7].
This is useful for:
- Feature selection (identifying which features matter)
- Model simplification (removing unused features)
- Understanding which inputs actually affect the prediction
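For example (a sketch assuming a trained model):

foreach (int index in model.GetActiveFeatureIndices())
    Console.WriteLine($"Feature {index} is used by the tree.");

// IsFeatureUsed (documented below) answers the same question for a single feature:
bool usesFeature3 = model.IsFeatureUsed(3);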
GetFeatureImportance()
Gets the feature importance scores as a dictionary.
public virtual Dictionary<string, T> GetFeatureImportance()
Returns
- Dictionary<string, T>
A dictionary mapping feature names to their importance scores.
GetModelMetadata()
Gets metadata about the trained model.
public abstract ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
A ModelMetadata object containing information about the model.
Remarks
For Beginners: Model metadata is information about the trained model itself, not about the predictions it makes. This can include things like:
- How well the model performs
- How complex the model is
- What settings were used to train it
It's like getting a report card for your model, showing how well it learned and what it learned.
GetParameters()
Gets the model parameters as a vector representation.
public virtual Vector<T> GetParameters()
Returns
- Vector<T>
A vector containing a serialized representation of the decision tree structure.
Remarks
This method provides a vector representation of the asynchronous decision tree model. Decision trees have a hierarchical structure that doesn't naturally fit into a flat vector format, so this representation is a simplified encoding of the tree structure suitable for certain optimization algorithms or model comparison techniques.
For Beginners: This method converts the tree structure into a flat list of numbers.
Decision trees are complex structures with branches and nodes, which don't naturally fit into a simple list of parameters like linear models do. This method creates a specialized representation of the tree that can be used by certain algorithms or for model comparison.
The exact format of this representation depends on the specific implementation, but generally includes information about:
- Each node's feature index (which feature it splits on)
- Each node's split value (the threshold for the decision)
- Each node's prediction value (for leaf nodes)
- The tree structure (how nodes connect to each other)
This is primarily used by advanced algorithms and not typically needed for regular use.
IsFeatureUsed(int)
Determines whether a specific feature is used in the decision tree.
public virtual bool IsFeatureUsed(int featureIndex)
Parameters
featureIndex (int): The zero-based index of the feature to check.
Returns
- bool
True if the feature is used in at least one decision node; otherwise, false.
Remarks
This method checks whether a specific feature is used as a split criterion in any node of the decision tree.
For Beginners: This method checks if a specific input feature is used in the tree.
You provide the position (index) of a feature, and the method tells you whether that feature is used in any decision node throughout the tree.
For example, if feature #3 is never used to make a decision in the tree, this method would return false because that feature doesn't affect the model's predictions.
This is useful when you want to check a specific feature's importance rather than getting all important features at once.
LoadModel(string)
Loads the model from a file.
public virtual void LoadModel(string filePath)
Parameters
filePath (string): The path from which to load the model.
LoadState(Stream)
Loads the model's state from a stream.
public virtual void LoadState(Stream stream)
Parameters
stream (Stream): The stream from which to read the model's state.
Predict(Matrix<T>)
Makes predictions using the trained model.
public Vector<T> Predict(Matrix<T> input)
Parameters
input (Matrix<T>): The input features matrix to make predictions for.
Returns
- Vector<T>
A vector of predictions.
Remarks
For Beginners: This method is a synchronous wrapper around the asynchronous PredictAsync method. It does the same thing as PredictAsync, but it waits for the predictions to be made before moving on.
Use this method when you want to get predictions immediately and wait for them to be ready. It's like asking a question and waiting for the answer before you do anything else.
PredictAsync(Matrix<T>)
Asynchronously makes predictions using the trained model.
public abstract Task<Vector<T>> PredictAsync(Matrix<T> input)
Parameters
input (Matrix<T>): The input features matrix to make predictions for.
Returns
- Task<Vector<T>>
A task representing the asynchronous operation, which resolves to a vector of predictions.
Remarks
For Beginners: Prediction is using the trained model to estimate outcomes for new data. It's like using what the model learned to make educated guesses about new situations.
For example, if you trained a model on house features and prices:
- Input: Features of a new house (size, location, etc.)
- Output: The model's estimate of what that house might cost
SaveModel(string)
Saves the model to a file.
public virtual void SaveModel(string filePath)
Parameters
filePath (string): The path where the model should be saved.
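Together with LoadModel(string), this gives simple file-based persistence, e.g.:

model.SaveModel("house-price-tree.bin");
// ... later, or in another process:
model.LoadModel("house-price-tree.bin");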
SaveState(Stream)
Saves the model's current state to a stream.
public virtual void SaveState(Stream stream)
Parameters
stream (Stream): The stream to write the model's state to.
Serialize()
Serializes the model to a byte array.
public virtual byte[] Serialize()
Returns
- byte[]
A byte array representing the serialized model.
Remarks
For Beginners: Serialization is the process of converting the model into a format that can be easily stored or transmitted. It's like packing up the model into a suitcase so you can take it with you or save it for later.
This method saves:
- The model's settings (like max depth)
- The importance of each feature
- The structure of the decision tree
You can use this to save your trained model and load it later without having to retrain.
SetActiveFeatureIndices(IEnumerable<int>)
Sets the active feature indices for this model.
public virtual void SetActiveFeatureIndices(IEnumerable<int> featureIndices)
Parameters
featureIndices (IEnumerable<int>): The indices of features to activate.
SetParameters(Vector<T>)
Sets the parameters for this model.
public virtual void SetParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): A vector containing the model parameters.
Train(Matrix<T>, Vector<T>)
Trains the decision tree model on the provided data.
public void Train(Matrix<T> x, Vector<T> y)
Parameters
x (Matrix<T>): The input features matrix.
y (Vector<T>): The target values vector.
Remarks
For Beginners: This method is a synchronous wrapper around the asynchronous TrainAsync method. It does the same thing as TrainAsync, but it waits for the training to complete before moving on.
Use this method when you want to train the model and wait for it to finish before doing anything else. It's like waiting for a cake to finish baking before you start decorating it.
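A sketch of the blocking workflow; constructing Matrix<double> and Vector<double> directly from arrays is an assumption, so check the library's actual construction APIs:

// Features: house size (sq ft) and bedroom count; targets: prices.
var x = new Matrix<double>(new double[,] { { 1200, 3 }, { 800, 2 }, { 1500, 4 } });
var y = new Vector<double>(new[] { 250_000.0, 180_000.0, 320_000.0 });

model.Train(x, y);                              // blocks until training completes
Vector<double> predictions = model.Predict(x);  // blocks until predictions are ready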
TrainAsync(Matrix<T>, Vector<T>)
Asynchronously trains the decision tree model on the provided data.
public abstract Task TrainAsync(Matrix<T> x, Vector<T> y)
Parameters
x (Matrix<T>): The input features matrix.
y (Vector<T>): The target values vector.
Returns
- Task
A task representing the asynchronous operation.
Remarks
For Beginners: Training is the process where the model learns from the data you provide. It's like teaching the model to recognize patterns in your data.
- x: This is your input data, like the features of houses (size, location, etc.)
- y: This is what you're trying to predict, like house prices
After training, the model will have learned how to use the features to predict the target values.
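The same workflow in its non-blocking form (same data-construction assumptions as the synchronous example above, called from within an async method):

await model.TrainAsync(x, y);
Vector<double> predictions = await model.PredictAsync(x);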
WithParameters(Vector<T>)
Creates a new instance of the model with the specified parameters.
public virtual IFullModel<T, Matrix<T>, Vector<T>> WithParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): A vector containing a serialized representation of the decision tree structure.
Returns
- IFullModel<T, Matrix<T>, Vector<T>>
A new model instance with the reconstructed tree structure.
Remarks
This method reconstructs a decision tree model from a parameter vector that was previously created using the GetParameters method. Due to the complex nature of tree structures, this reconstruction is approximate and is primarily intended for use with optimization algorithms or model comparison techniques.
For Beginners: This method rebuilds a decision tree from a flat list of numbers.
It takes the specialized vector representation created by GetParameters() and attempts to reconstruct a decision tree from it. This is challenging because decision trees are complex structures that don't easily convert to and from simple lists of numbers.
This method is primarily used by advanced algorithms and not typically needed for regular use. For most purposes, the Serialize and Deserialize methods provide a more reliable way to save and load tree models.
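A sketch of the parameter-vector round trip, keeping in mind that the reconstruction is approximate:

Vector<double> parameters = trainedModel.GetParameters();
IFullModel<double, Matrix<double>, Vector<double>> rebuilt = trainedModel.WithParameters(parameters);
// trainedModel is unchanged; rebuilt holds the reconstructed tree.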