Class DecisionTreeClassifier<T>

Namespace: AiDotNet.Classification.Trees
Assembly: AiDotNet.dll

A decision tree classifier that learns a hierarchy of decision rules from training data.

public class DecisionTreeClassifier<T> : ProbabilisticClassifierBase<T>, ITreeBasedClassifier<T>, IProbabilisticClassifier<T>, IClassifier<T>, IFullModel<T, Matrix<T>, Vector<T>>, IModel<Matrix<T>, Vector<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Matrix<T>, Vector<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Matrix<T>, Vector<T>>>, IGradientComputable<T, Matrix<T>, Vector<T>>, IJitCompilable<T>

Type Parameters

T

The numeric data type used for calculations (e.g., float, double).

Inheritance
ProbabilisticClassifierBase<T> → DecisionTreeClassifier<T>

Implements
ITreeBasedClassifier<T>
IProbabilisticClassifier<T>
IClassifier<T>
IFullModel<T, Matrix<T>, Vector<T>>
IModel<Matrix<T>, Vector<T>, ModelMetadata<T>>
IModelSerializer
ICheckpointableModel
IParameterizable<T, Matrix<T>, Vector<T>>
IFeatureAware
IFeatureImportance<T>
ICloneable<IFullModel<T, Matrix<T>, Vector<T>>>
IGradientComputable<T, Matrix<T>, Vector<T>>
IJitCompilable<T>

Remarks

Decision trees are non-parametric supervised learning algorithms that learn decision rules inferred from data features. They partition the feature space into regions and assign class labels to each region.

For Beginners: Imagine playing a game of "20 Questions" to classify things. The decision tree learns which questions (based on features) best separate the different classes.

Example: Classifying whether to play tennis

  1. Is it raining? -> No: Go to step 2, Yes: Don't play
  2. Is humidity > 75%? -> No: Play!, Yes: Don't play

Each question splits the data based on a feature value, and leaves contain the final decisions.
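
For illustration, a minimal end-to-end sketch (not taken from the library's samples): it assumes a using AiDotNet.Classification.Trees; directive and that the Matrix<double>/Vector<double> training data have already been built elsewhere, since their construction is not covered on this page.

  using AiDotNet.Classification.Trees;

  public static class TennisExample
  {
      // features: one row per observation (e.g., [isRaining, humidity]);
      // labels: the class for each row (e.g., 1 = play, 0 = don't play).
      public static Matrix<double> TrainAndScore(Matrix<double> features, Vector<double> labels)
      {
          var tree = new DecisionTreeClassifier<double>();

          // Learn the decision rules from the labelled examples.
          tree.Train(features, labels);

          // One row of class probabilities per sample; each row sums to 1.0.
          return tree.PredictProbabilities(features);
      }
  }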

Constructors

DecisionTreeClassifier(DecisionTreeClassifierOptions<T>?, IRegularization<T, Matrix<T>, Vector<T>>?)

Initializes a new instance of the DecisionTreeClassifier class.

public DecisionTreeClassifier(DecisionTreeClassifierOptions<T>? options = null, IRegularization<T, Matrix<T>, Vector<T>>? regularization = null)

Parameters

options DecisionTreeClassifierOptions<T>

Configuration options for the decision tree.

regularization IRegularization<T, Matrix<T>, Vector<T>>

Optional regularization strategy.
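
A construction sketch: the parameterless call uses default options, and the object initializer shows how a custom configuration would be passed. The MaxDepth property name below is a hypothetical example; consult DecisionTreeClassifierOptions<T> for its actual members.

  // Default configuration: both arguments may be omitted.
  var defaultTree = new DecisionTreeClassifier<double>();

  // Custom configuration. NOTE: MaxDepth is an illustrative option name,
  // not a confirmed member of DecisionTreeClassifierOptions<T>.
  var options = new DecisionTreeClassifierOptions<double>
  {
      MaxDepth = 5
  };
  var shallowTree = new DecisionTreeClassifier<double>(options);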

Properties

FeatureImportances

Gets the feature importance scores computed during training.

public Vector<T>? FeatureImportances { get; }

Property Value

Vector<T>

A vector of importance scores, one for each feature. Higher values indicate more important features. Returns null if the model has not been trained.

Remarks

Feature importance is typically computed based on how much each feature contributes to reducing impurity (e.g., Gini impurity or entropy) in the tree.

For Beginners: This tells you which features the tree found most useful for making decisions. A high importance score means that feature appears often near the top of the tree and is crucial for classification.
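
A short sketch of reading the scores after training. It assumes Vector<T> exposes an indexer by feature position; that detail is not documented on this page.

  public static void PrintImportances(DecisionTreeClassifier<double> tree, string[] featureNames)
  {
      var importances = tree.FeatureImportances;
      if (importances is null)
      {
          Console.WriteLine("Train the model before reading feature importances.");
          return;
      }

      // Assumes Vector<T> supports [] indexing by feature position.
      for (int i = 0; i < featureNames.Length; i++)
      {
          Console.WriteLine($"{featureNames[i]}: {importances[i]}");
      }
  }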

LeafCount

Gets the number of leaf nodes in the tree.

public int LeafCount { get; }

Property Value

int

The count of terminal nodes (leaves) in the trained tree. Returns 0 if the model has not been trained.

MaxDepth

Gets the maximum depth of the tree.

public int MaxDepth { get; }

Property Value

int

The maximum depth reached during training, or the configured maximum depth.

NodeCount

Gets the number of internal (decision) nodes in the tree.

public int NodeCount { get; }

Property Value

int

The count of non-terminal nodes that make decisions. Returns 0 if the model has not been trained.
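
For example, a trained tree's shape can be summarised from these three properties (tree is assumed to be a trained DecisionTreeClassifier<double>):

  Console.WriteLine($"Leaves: {tree.LeafCount}");
  Console.WriteLine($"Max depth: {tree.MaxDepth}");
  Console.WriteLine($"Decision nodes: {tree.NodeCount}");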

Options

Gets the decision tree specific options.

protected DecisionTreeClassifierOptions<T> Options { get; }

Property Value

DecisionTreeClassifierOptions<T>

Methods

ApplyGradients(Vector<T>, T)

Applies pre-computed gradients to update the model parameters.

public override void ApplyGradients(Vector<T> gradients, T learningRate)

Parameters

gradients Vector<T>

The gradient vector to apply.

learningRate T

The learning rate for the update.

Remarks

Updates parameters using: θ = θ - learningRate * gradients

For Beginners: After computing gradients (seeing which direction to move), this method actually moves the model in that direction. The learning rate controls how big of a step to take.

Distributed Training: In DDP/ZeRO-2, this applies the synchronized (averaged) gradients after communication across workers. Each worker applies the same averaged gradients to keep parameters consistent.
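
A sketch of one manual update step, assuming batchX and batchY are a prepared Matrix<double> and Vector<double> and using ComputeGradients (documented below) to obtain the gradient vector:

  // Compute the direction first, then take the step.
  Vector<double> gradients = tree.ComputeGradients(batchX, batchY);
  double learningRate = 0.01;
  tree.ApplyGradients(gradients, learningRate); // θ = θ - learningRate * gradients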

Clone()

Creates a clone of the classifier model.

public override IFullModel<T, Matrix<T>, Vector<T>> Clone()

Returns

IFullModel<T, Matrix<T>, Vector<T>>

A new instance of the model with the same parameters and options.

ComputeGradients(Matrix<T>, Vector<T>, ILossFunction<T>?)

Computes gradients of the loss function with respect to model parameters for the given data, WITHOUT updating the model parameters.

public override Vector<T> ComputeGradients(Matrix<T> input, Vector<T> target, ILossFunction<T>? lossFunction = null)

Parameters

input Matrix<T>

The input data.

target Vector<T>

The target/expected output.

lossFunction ILossFunction<T>

The loss function to use for gradient computation. If null, uses the model's default loss function.

Returns

Vector<T>

A vector containing gradients with respect to all model parameters.

Remarks

This method performs a forward pass, computes the loss, and back-propagates to compute gradients, but does NOT update the model's parameters. The parameters remain unchanged after this call.

Distributed Training: In DDP/ZeRO-2, each worker calls this to compute local gradients on its data batch. These gradients are then synchronized (averaged) across workers before applying updates. This ensures all workers compute the same parameter updates despite having different data.

For Meta-Learning: After adapting a model on a support set, you can use this method to compute gradients on the query set. These gradients become the meta-gradients for updating the meta-parameters.

For Beginners: Think of this as "dry run" training:

  - The model sees what direction it should move (the gradients)
  - But it doesn't actually move (parameters stay the same)
  - You get to decide what to do with this information (average with others, inspect, modify, etc.)

Exceptions

InvalidOperationException

Thrown when lossFunction is null and the model has no default loss function.
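
A sketch of the "dry run" behaviour, assuming queryX and queryY are a prepared Matrix<double> and Vector<double> and using GetParameters (documented below) to show that nothing changes:

  // Gradients are computed, but the model itself is not modified.
  Vector<double> before = tree.GetParameters();
  Vector<double> gradients = tree.ComputeGradients(queryX, queryY);
  Vector<double> after = tree.GetParameters();

  // 'before' and 'after' hold identical values; nothing moves until
  // ApplyGradients (or SetParameters) is called explicitly.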

CreateNewInstance()

Creates a new instance of the same type as this classifier.

protected override IFullModel<T, Matrix<T>, Vector<T>> CreateNewInstance()

Returns

IFullModel<T, Matrix<T>, Vector<T>>

A new instance of the same classifier type.

Deserialize(byte[])

Deserializes the model from a byte array.

public override void Deserialize(byte[] modelData)

Parameters

modelData byte[]

The byte array containing the serialized model data.

Remarks

This method reconstructs the model's parameters from a serialized byte array.

For Beginners: Deserialization is the opposite of serialization - it takes the saved model data and reconstructs the model's internal state. This allows you to load a previously trained model and use it to make predictions without having to retrain it.

Exceptions

InvalidOperationException

Thrown when deserialization fails.

GetModelMetadata()

Gets metadata about the model.

public override ModelMetadata<T> GetModelMetadata()

Returns

ModelMetadata<T>

A ModelMetadata object containing information about the model.

Remarks

This method returns metadata about the model, including its type, feature count, complexity, description, and additional information specific to classification.

For Beginners: Model metadata provides information about the model itself, rather than the predictions it makes. This includes details about the model's structure (like how many features it uses) and characteristics (like how many classes it can predict). This information can be useful for understanding and comparing different models.

GetModelType()

Returns the model type identifier for this classifier.

protected override ModelType GetModelType()

Returns

ModelType

GetParameters()

Gets all model parameters as a single vector.

public override Vector<T> GetParameters()

Returns

Vector<T>

A vector containing all model parameters.

Remarks

This method returns a vector containing all model parameters for use with optimization algorithms or model comparison.

For Beginners: This method packages all the model's parameters into a single collection. This is useful for optimization algorithms that need to work with all parameters at once.
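
A sketch showing how the parameter vector round-trips through the related members, assuming tree is a trained DecisionTreeClassifier<double>:

  // Package every parameter into a single vector.
  Vector<double> parameters = tree.GetParameters();

  // WithParameters leaves 'tree' untouched and returns a new instance.
  var copy = tree.WithParameters(parameters);

  // SetParameters mutates the existing model instead; the vector length
  // must match, otherwise an ArgumentException is thrown.
  tree.SetParameters(parameters);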

PredictProbabilities(Matrix<T>)

Predicts class probabilities for each sample in the input.

public override Matrix<T> PredictProbabilities(Matrix<T> input)

Parameters

input Matrix<T>

The input features matrix where each row is a sample and each column is a feature.

Returns

Matrix<T>

A matrix where each row corresponds to an input sample and each column corresponds to a class. The values represent the probability of the sample belonging to each class.

Remarks

This override computes class probabilities for each input sample. For a decision tree, these typically reflect the class distribution of the leaf node that each sample reaches. The output matrix has shape [num_samples, num_classes], and each row sums to 1.0.

For Beginners: This method computes the probability of each sample belonging to each class. Each row in the output represents one sample, and each column represents one class. The values in each row sum to 1.0 (100% total probability).
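
A sketch of consuming the output, assuming testX is a prepared Matrix<double>; the [row, column] indexer on Matrix<T> is assumed here and may differ in the actual API.

  Matrix<double> probabilities = tree.PredictProbabilities(testX);

  // Row i holds the class distribution for sample i; each row sums to 1.0.
  double firstSampleClass0 = probabilities[0, 0];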

Serialize()

Serializes the model to a byte array.

public override byte[] Serialize()

Returns

byte[]

A byte array containing the serialized model data.

Remarks

This method serializes the model's parameters to JSON and then converts the result to a byte array.

For Beginners: Serialization converts the model's internal state into a format that can be saved to disk or transmitted over a network. This allows you to save a trained model and load it later without having to retrain it.
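
A save/load round-trip sketch combining Serialize with Deserialize (documented above); the file name is illustrative and a using System.IO; directive is assumed.

  // Save a trained model to disk...
  byte[] modelData = trainedTree.Serialize();
  File.WriteAllBytes("decision-tree.bin", modelData);

  // ...and later restore it into a fresh instance without retraining.
  var restored = new DecisionTreeClassifier<double>();
  restored.Deserialize(File.ReadAllBytes("decision-tree.bin"));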

SetParameters(Vector<T>)

Sets the parameters for this model.

public override void SetParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

A vector containing all model parameters.

Exceptions

ArgumentException

Thrown when the parameters vector has an incorrect length.

Train(Matrix<T>, Vector<T>)

Trains the decision tree on the provided data.

public override void Train(Matrix<T> x, Vector<T> y)

Parameters

x Matrix<T>

The training feature matrix, where each row is a sample and each column is a feature.

y Vector<T>

The class label for each training sample.

WithParameters(Vector<T>)

Creates a new instance of the model with specified parameters.

public override IFullModel<T, Matrix<T>, Vector<T>> WithParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

A vector containing all model parameters.

Returns

IFullModel<T, Matrix<T>, Vector<T>>

A new model instance with the specified parameters.

Exceptions

ArgumentException

Thrown when the parameters vector has an incorrect length.