Class ClassifierBase<T>
- Namespace
- AiDotNet.Classification
- Assembly
- AiDotNet.dll
Provides a base implementation for classification algorithms that predict categorical outcomes.
public abstract class ClassifierBase<T> : IClassifier<T>, IFullModel<T, Matrix<T>, Vector<T>>, IModel<Matrix<T>, Vector<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Matrix<T>, Vector<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Matrix<T>, Vector<T>>>, IGradientComputable<T, Matrix<T>, Vector<T>>, IJitCompilable<T>
Type Parameters
T
The numeric data type used for calculations (e.g., float, double).
- Inheritance
- ClassifierBase<T>
- Implements
- IClassifier<T>
Remarks
This abstract class implements common functionality for classification models, including prediction, serialization/deserialization, and parameter management. Specific classification algorithms inherit from this class and implement its abstract members: Train and Predict, plus GetModelType, CreateNewInstance, GetParameters, SetParameters, WithParameters, ComputeGradients, and ApplyGradients.
The class supports various options like class weighting to handle imbalanced datasets and different classification task types (binary, multi-class, multi-label, ordinal).
For Beginners: Classification is about predicting which category something belongs to. This base class provides the foundation for different classification techniques, handling common operations like making predictions and saving/loading models. Think of it as a template that specific classification algorithms can customize while reusing the shared functionality.
Constructors
ClassifierBase(ClassifierOptions<T>?, IRegularization<T, Matrix<T>, Vector<T>>?, ILossFunction<T>?)
Initializes a new instance of the ClassifierBase class with the specified options, regularization, and loss function.
protected ClassifierBase(ClassifierOptions<T>? options = null, IRegularization<T, Matrix<T>, Vector<T>>? regularization = null, ILossFunction<T>? lossFunction = null)
Parameters
options ClassifierOptions<T>
Configuration options for the classifier model. If null, default options will be used.
regularization IRegularization<T, Matrix<T>, Vector<T>>
Regularization method to prevent overfitting. If null, no regularization will be applied.
lossFunction ILossFunction<T>
Loss function for gradient computation. If null, defaults to Cross Entropy.
Remarks
The constructor initializes the model with either the provided options or default settings.
For Beginners: This constructor sets up the classification model with your specified settings or uses default settings if none are provided. Regularization is an optional technique to prevent the model from becoming too complex and overfitting to the training data. The loss function determines how prediction errors are measured during training.
Properties
ClassLabels
Gets or sets the class labels learned during training.
public Vector<T>? ClassLabels { get; protected set; }
Property Value
- Vector<T>
A vector containing the unique class labels, or null if not yet trained.
DefaultLossFunction
Gets the default loss function used by this model for gradient computation.
public virtual ILossFunction<T> DefaultLossFunction { get; }
Property Value
- ILossFunction<T>
Remarks
This loss function is used when ComputeGradients(Matrix<T>, Vector<T>, ILossFunction<T>?) is called without an explicit loss function. It represents the model's primary training objective.
For Beginners: The loss function tells the model "what counts as a mistake". For example:
- For regression (predicting numbers), Mean Squared Error measures how far predictions are from the actual values.
- For classification (predicting categories), Cross Entropy penalizes the model for assigning a low probability to the correct category.
This property provides a sensible default so you don't have to specify the loss function every time, but you can still override it if needed for special cases.
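As a minimal sketch of what the default Cross Entropy objective measures, the C# snippet below computes the mean categorical cross-entropy from per-class probabilities. It is not AiDotNet's ILossFunction<T> implementation: it uses double instead of the generic T, and CrossEntropySketch and MeanCrossEntropy are hypothetical names chosen for illustration.

```csharp
using System;

// Hypothetical helper (not part of AiDotNet): mean categorical cross-entropy
// computed from per-class probabilities, using double instead of the generic T.
public static class CrossEntropySketch
{
    // probabilities[i][k] = predicted probability that sample i belongs to class k.
    // labels[i] = index of the correct class for sample i.
    public static double MeanCrossEntropy(double[][] probabilities, int[] labels)
    {
        if (probabilities.Length != labels.Length)
            throw new ArgumentException("One label is required per sample.");
        if (labels.Length == 0)
            throw new ArgumentException("At least one sample is required.");

        const double epsilon = 1e-12; // guards against log(0)
        double total = 0.0;
        for (int i = 0; i < labels.Length; i++)
        {
            // Only the probability assigned to the correct class contributes.
            double p = Math.Max(probabilities[i][labels[i]], epsilon);
            total += -Math.Log(p);
        }
        return total / labels.Length;
    }
}
```

The loss approaches 0 as the probability on the correct class approaches 1, and grows quickly when the model is confident in the wrong class.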
Distributed Training: In distributed training, all workers use the same loss function to ensure consistent gradient computation. The default loss function is automatically used when workers compute local gradients.
Exceptions
- InvalidOperationException
Thrown if accessed before the model has been configured with a loss function.
Engine
Gets the global execution engine for vector operations.
protected IEngine Engine { get; }
Property Value
- IEngine
Remarks
This property provides access to the execution engine (CPU or GPU) for performing vectorized operations. The engine is determined by the global AiDotNetEngine configuration and allows automatic fallback from GPU to CPU when GPU is not available.
For Beginners: This gives access to either CPU or GPU processing for faster computations. The system automatically chooses the best available option and falls back to CPU if GPU acceleration is not available.
ExpectedParameterCount
Gets the expected number of parameters for this model.
protected virtual int ExpectedParameterCount { get; }
Property Value
- int
The total number of parameters in the model, used for serialization and gradient computation.
FeatureNames
Gets or sets the feature names.
public string[]? FeatureNames { get; set; }
Property Value
- string[]
An array of feature names. If not set, feature indices will be used as names.
NumClasses
Gets or sets the number of classes in the classification problem.
public int NumClasses { get; protected set; }
Property Value
- int
The number of distinct classes learned during training.
NumFeatures
Gets or sets the number of features expected by this classifier.
public int NumFeatures { get; protected set; }
Property Value
- int
The number of input features the model was trained on.
NumOps
Gets the numeric operations for the specified type T.
protected INumericOperations<T> NumOps { get; }
Property Value
- INumericOperations<T>
An object that provides mathematical operations for type T.
Options
Gets the classifier options.
protected ClassifierOptions<T> Options { get; }
Property Value
- ClassifierOptions<T>
Configuration options for the classifier model.
ParameterCount
Gets the total number of parameters in the model.
public virtual int ParameterCount { get; }
Property Value
- int
Regularization
Gets the regularization method used to prevent overfitting.
protected IRegularization<T, Matrix<T>, Vector<T>> Regularization { get; }
Property Value
- IRegularization<T, Matrix<T>, Vector<T>>
An object that implements regularization for the classifier model.
SupportsJitCompilation
Gets whether this model currently supports JIT compilation.
public virtual bool SupportsJitCompilation { get; }
Property Value
- bool
True if the model can be JIT compiled, false otherwise.
Remarks
Some models may not support JIT compilation because of:
- A dynamic graph structure (one that changes based on the input)
- The lack of a computation graph representation
- Operations not yet supported by the JIT compiler
For Beginners: This tells you whether this specific model can benefit from JIT compilation.
Models return false if they:
- Use layer-based architecture without graph export (e.g., current neural networks)
- Have control flow that changes based on input data
- Use operations the JIT compiler doesn't understand yet
In these cases, the model will still work normally, just without JIT acceleration.
TaskType
Gets or sets the type of classification task.
public ClassificationTaskType TaskType { get; protected set; }
Property Value
- ClassificationTaskType
The classification task type (Binary, MultiClass, MultiLabel, or Ordinal).
Methods
ApplyGradients(Vector<T>, T)
Applies pre-computed gradients to update the model parameters.
public abstract void ApplyGradients(Vector<T> gradients, T learningRate)
Parameters
gradients Vector<T>
The gradient vector to apply.
learningRate T
The learning rate for the update.
Remarks
Updates parameters using: θ = θ - learningRate * gradients
For Beginners: After computing gradients (which direction to move), this method actually moves the model in that direction. The learning rate controls how large a step to take.
Distributed Training: In DDP/ZeRO-2, this applies the synchronized (averaged) gradients after communication across workers. Each worker applies the same averaged gradients to keep parameters consistent.
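The update rule and the DDP averaging step described above can be sketched in C# as follows. This is a minimal illustration, assuming plain double[] arrays in place of Vector<T>; GradientStepSketch, AverageGradients, and the standalone ApplyGradients helper are hypothetical and not AiDotNet code.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical sketch (not AiDotNet code) of theta = theta - learningRate * gradients,
// plus the averaging that DDP-style training performs so that every worker
// applies the same update.
public static class GradientStepSketch
{
    // Element-wise mean of the local gradients computed by each worker.
    public static double[] AverageGradients(IReadOnlyList<double[]> workerGradients)
    {
        if (workerGradients.Count == 0)
            throw new ArgumentException("At least one worker gradient is required.");

        int length = workerGradients[0].Length;
        var average = new double[length];
        foreach (double[] g in workerGradients)
        {
            if (g.Length != length)
                throw new ArgumentException("All gradients must have the same length.");
            for (int j = 0; j < length; j++)
                average[j] += g[j];
        }
        for (int j = 0; j < length; j++)
            average[j] /= workerGradients.Count;
        return average;
    }

    // In-place gradient-descent step on the parameter vector.
    public static void ApplyGradients(double[] parameters, double[] gradients, double learningRate)
    {
        if (parameters.Length != gradients.Length)
            throw new ArgumentException("Gradient length must match the parameter count.");
        for (int j = 0; j < parameters.Length; j++)
            parameters[j] -= learningRate * gradients[j];
    }
}
```

Because every worker receives the same averaged vector and starts from the same parameters, the parameters stay identical across workers after the step.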
Clone()
Creates a clone of the classifier model.
public virtual IFullModel<T, Matrix<T>, Vector<T>> Clone()
Returns
- IFullModel<T, Matrix<T>, Vector<T>>
A new instance of the model with the same parameters and options.
ComputeClassWeights(Vector<T>)
Computes class weights for handling imbalanced datasets.
protected virtual double[] ComputeClassWeights(Vector<T> y)
Parameters
y Vector<T>
The target labels vector.
Returns
- double[]
An array of weights for each class.
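This page does not state which weighting scheme ComputeClassWeights uses. As an assumption for illustration only, the C# sketch below uses the common "balanced" inverse-frequency heuristic, weight_k = n_samples / (n_classes × count_k), with double labels instead of T; ClassWeightSketch is a hypothetical name.

```csharp
using System;
using System.Linq;

// Hypothetical sketch: "balanced" class weights (an assumed scheme, not
// necessarily what AiDotNet uses). Rare classes receive larger weights.
public static class ClassWeightSketch
{
    // Returns one weight per class, ordered like the sorted unique labels.
    public static double[] ComputeBalancedWeights(double[] y)
    {
        if (y.Length == 0)
            throw new ArgumentException("Labels must not be empty.");

        double[] classes = y.Distinct().OrderBy(v => v).ToArray();
        var weights = new double[classes.Length];
        for (int k = 0; k < classes.Length; k++)
        {
            int count = y.Count(v => v == classes[k]);
            weights[k] = (double)y.Length / (classes.Length * count);
        }
        return weights;
    }
}
```

With labels {0, 0, 0, 1}, the minority class 1 gets weight 2 and the majority class 0 gets weight 2/3, so both classes contribute equally in total.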
ComputeGradients(Matrix<T>, Vector<T>, ILossFunction<T>?)
Computes gradients of the loss function with respect to model parameters for the given data, WITHOUT updating the model parameters.
public abstract Vector<T> ComputeGradients(Matrix<T> input, Vector<T> target, ILossFunction<T>? lossFunction = null)
Parameters
input Matrix<T>
The input data.
target Vector<T>
The target/expected output.
lossFunction ILossFunction<T>
The loss function to use for gradient computation. If null, uses the model's default loss function.
Returns
- Vector<T>
A vector containing gradients with respect to all model parameters.
Remarks
This method performs a forward pass, computes the loss, and back-propagates to compute gradients, but does NOT update the model's parameters. The parameters remain unchanged after this call.
Distributed Training: In DDP/ZeRO-2, each worker calls this to compute local gradients on its data batch. These gradients are then synchronized (averaged) across workers before applying updates. This ensures all workers compute the same parameter updates despite having different data.
For Meta-Learning: After adapting a model on a support set, you can use this method to compute gradients on the query set. These gradients become the meta-gradients for updating the meta-parameters.
For Beginners: Think of this as "dry run" training:
- The model sees which direction it should move (the gradients).
- But it doesn't actually move (the parameters stay the same).
- You decide what to do with this information (average it with other workers' gradients, inspect it, modify it, etc.).
Exceptions
- InvalidOperationException
Thrown if lossFunction is null and the model has no default loss function.
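To make the "compute but do not apply" contract concrete, the C# sketch below swaps in a specific model: binary logistic regression without a bias term, trained with mean binary cross-entropy. It is an illustration under those assumptions, not AiDotNet's implementation; LogisticGradientSketch is a hypothetical name and double[] stands in for Matrix<T> and Vector<T>.

```csharp
using System;

// Hypothetical sketch: gradient of the mean binary cross-entropy loss of a
// logistic-regression model (no bias term) with respect to its weights.
// The weights array is only read, never written.
public static class LogisticGradientSketch
{
    static double Sigmoid(double z) => 1.0 / (1.0 + Math.Exp(-z));

    public static double[] ComputeGradients(double[][] x, double[] y, double[] weights)
    {
        int n = x.Length;
        if (n == 0 || y.Length != n)
            throw new ArgumentException("Need one label per row of a non-empty input.");

        var gradient = new double[weights.Length];
        for (int i = 0; i < n; i++)
        {
            // Forward pass: linear score for sample i.
            double z = 0.0;
            for (int j = 0; j < weights.Length; j++)
                z += x[i][j] * weights[j];

            // For sigmoid + cross-entropy, dLoss/dz simplifies to (prediction - target).
            double residual = Sigmoid(z) - y[i];
            for (int j = 0; j < weights.Length; j++)
                gradient[j] += residual * x[i][j] / n;
        }
        return gradient; // caller decides whether to average, inspect, or apply it
    }
}
```

The returned vector is exactly what a DDP worker would send for averaging, or what a meta-learner would use as a query-set gradient.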
CreateNewInstance()
Creates a new instance of the same type as this classifier.
protected abstract IFullModel<T, Matrix<T>, Vector<T>> CreateNewInstance()
Returns
- IFullModel<T, Matrix<T>, Vector<T>>
A new instance of the same classifier type.
DeepCopy()
Creates a deep copy of the classifier model.
public virtual IFullModel<T, Matrix<T>, Vector<T>> DeepCopy()
Returns
- IFullModel<T, Matrix<T>, Vector<T>>
A new instance of the model with the same parameters and options.
Deserialize(byte[])
Deserializes the model from a byte array.
public virtual void Deserialize(byte[] modelData)
Parameters
modelData byte[]
The byte array containing the serialized model data.
Remarks
This method reconstructs the model's parameters from a serialized byte array.
For Beginners: Deserialization is the opposite of serialization - it takes the saved model data and reconstructs the model's internal state. This allows you to load a previously trained model and use it to make predictions without having to retrain it.
Exceptions
- InvalidOperationException
Thrown when deserialization fails.
ExportComputationGraph(List<ComputationNode<T>>)
Exports the model's computation graph for JIT compilation.
public virtual ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
Parameters
inputNodes List<ComputationNode<T>>
List to populate with input computation nodes (parameters).
Returns
- ComputationNode<T>
The output computation node representing the model's prediction.
Remarks
This method should construct a computation graph representing the model's forward pass. The graph should use placeholder input nodes that will be filled with actual data during execution.
For Beginners: This method creates a "recipe" of your model's calculations that the JIT compiler can optimize.
The method should:
- Create placeholder nodes for inputs (features, parameters)
- Build the computation graph using TensorOperations
- Return the final output node
- Add all input nodes to the inputNodes list (in order)
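Example for a simple linear model (y = xW + b), where InputShape, Weights, and Bias are assumed to be members of the derived model:
public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
{
    // Placeholder node for the input features (filled with real data at execution time)
    var x = TensorOperations<T>.Variable(new Tensor<T>(InputShape), "x");
    // Parameter nodes wrapping the model's current weights and bias
    var W = TensorOperations<T>.Variable(Weights, "W");
    var b = TensorOperations<T>.Variable(Bias, "b");

    // Register the input nodes in the order the compiled function expects them
    inputNodes.Add(x);
    inputNodes.Add(W);
    inputNodes.Add(b);

    // Build the graph: y = xW + b (each row of x is one sample)
    var matmul = TensorOperations<T>.MatMul(x, W);
    var output = TensorOperations<T>.Add(matmul, b);
    return output;
}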
The JIT compiler will then:
- Optimize the graph (fuse operations, eliminate dead code)
- Compile it to fast native code
- Cache the compiled version for reuse
ExtractClassLabels(Vector<T>)
Extracts unique class labels from the training data.
protected virtual Vector<T> ExtractClassLabels(Vector<T> y)
Parameters
y Vector<T>
The target labels vector.
Returns
- Vector<T>
A sorted vector of unique class labels.
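A minimal C# sketch of the behavior described here (distinct labels, sorted ascending), assuming double labels in place of Vector<T>. ClassLabelSketch is a hypothetical name, not AiDotNet's implementation.

```csharp
using System.Linq;

// Hypothetical sketch: unique labels in ascending order. The position of each
// label in the result becomes its class index.
public static class ClassLabelSketch
{
    public static double[] ExtractClassLabels(double[] y) =>
        y.Distinct().OrderBy(label => label).ToArray();
}
```

For example, labels {2, 0, 2, 1} yield {0, 1, 2}, so label 2 maps to class index 2.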
GetActiveFeatureIndices()
Gets the indices of features that are actively used in the model.
public virtual IEnumerable<int> GetActiveFeatureIndices()
Returns
- IEnumerable<int>
An enumerable collection of indices for features that contribute to predictions.
GetClassIndexFromLabel(T)
Gets the class index for a given label value.
protected int GetClassIndexFromLabel(T label)
Parameters
label T
The label value to look up.
Returns
- int
The index of the class in ClassLabels, or -1 if not found.
Remarks
This method maps label values to their corresponding class indices using the ClassLabels vector. This is more robust than casting labels directly to indices, because it also handles labels that are not zero-based indices (e.g., 1, 2, 3 or -1, 1).
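The lookup described above can be sketched in C# as a search through the learned labels, returning -1 for an unseen label. This assumes double labels in place of T; LabelIndexSketch is a hypothetical name and not AiDotNet's implementation.

```csharp
// Hypothetical sketch: map a label value to its class index.
// classLabels is assumed to hold the sorted unique labels found during training.
public static class LabelIndexSketch
{
    public static int GetClassIndexFromLabel(double[] classLabels, double label)
    {
        for (int k = 0; k < classLabels.Length; k++)
        {
            if (classLabels[k].Equals(label))
                return k;
        }
        return -1; // label was never seen during training
    }
}
```

With ClassLabels = {-1, 1}, the label 1 maps to index 1 even though a direct cast would also give 1, while -1 maps to index 0 instead of an invalid negative index.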
GetFeatureImportance()
Gets the feature importance scores as a dictionary.
public virtual Dictionary<string, T> GetFeatureImportance()
Returns
- Dictionary<string, T>
A dictionary mapping feature names to their importance scores.
GetModelMetadata()
Gets metadata about the model.
public virtual ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
A ModelMetadata object containing information about the model.
Remarks
This method returns metadata about the model, including its type, feature count, complexity, description, and additional information specific to classification.
For Beginners: Model metadata provides information about the model itself, rather than the predictions it makes. This includes details about the model's structure (like how many features it uses) and characteristics (like how many classes it can predict). This information can be useful for understanding and comparing different models.
GetModelType()
Gets the type of the model.
protected abstract ModelType GetModelType()
Returns
- ModelType
The model type identifier.
Remarks
This abstract method must be implemented by derived classes to specify the model type.
For Beginners: This method simply returns an identifier that indicates what type of classifier this is (e.g., Naive Bayes, Random Forest). It's used internally by the library to keep track of different types of models.
GetParameters()
Gets all model parameters as a single vector.
public abstract Vector<T> GetParameters()
Returns
- Vector<T>
A vector containing all model parameters.
Remarks
This method returns a vector containing all model parameters for use with optimization algorithms or model comparison.
For Beginners: This method packages all the model's parameters into a single collection. This is useful for optimization algorithms that need to work with all parameters at once.
InferTaskType(Vector<T>)
Infers the classification task type from the training labels.
protected virtual ClassificationTaskType InferTaskType(Vector<T> y)
Parameters
y Vector<T>
The target labels vector.
Returns
- ClassificationTaskType
The inferred classification task type.
Remarks
This method examines the unique values in the target vector to determine whether this is a binary or multi-class classification problem.
For Beginners: The model automatically figures out what kind of classification problem you have:
- If there are exactly 2 unique values → Binary classification
- If there are more than 2 unique values → Multi-class classification
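The rule above can be sketched in C# by counting distinct labels. InferredTaskType is a hypothetical stand-in for ClassificationTaskType limited to the two cases a single label vector can distinguish, and rejecting fewer than two distinct labels is an assumption, since this page does not say how that case is handled.

```csharp
using System;
using System.Linq;

// Hypothetical stand-in for ClassificationTaskType (Binary and MultiClass only).
public enum InferredTaskType { Binary, MultiClass }

public static class TaskTypeSketch
{
    public static InferredTaskType InferTaskType(double[] y)
    {
        int distinct = y.Distinct().Count();
        if (distinct < 2) // assumption: a single class cannot be classified
            throw new ArgumentException("Classification needs at least two distinct labels.");
        return distinct == 2 ? InferredTaskType.Binary : InferredTaskType.MultiClass;
    }
}
```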
IsFeatureUsed(int)
Determines whether a specific feature is used in the model.
public virtual bool IsFeatureUsed(int featureIndex)
Parameters
featureIndex int
The zero-based index of the feature to check.
Returns
- bool
True if the feature contributes to predictions; otherwise, false.
Exceptions
- ArgumentOutOfRangeException
Thrown when the feature index is outside the valid range.
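One plausible way to back GetActiveFeatureIndices, SetActiveFeatureIndices, and IsFeatureUsed is a set of active indices plus a range check, sketched below in C#. ActiveFeatureSketch is hypothetical, and treating every feature as active until SetActiveFeatureIndices is called is an assumption, not documented AiDotNet behavior.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical sketch of active-feature bookkeeping for a model with a fixed
// number of input features.
public class ActiveFeatureSketch
{
    private readonly int _numFeatures;
    private HashSet<int> _active;

    public ActiveFeatureSketch(int numFeatures)
    {
        _numFeatures = numFeatures;
        // Assumption: all features are active by default.
        _active = new HashSet<int>(Enumerable.Range(0, numFeatures));
    }

    public void SetActiveFeatureIndices(IEnumerable<int> featureIndices) =>
        _active = new HashSet<int>(featureIndices);

    public IEnumerable<int> GetActiveFeatureIndices() => _active.OrderBy(i => i);

    public bool IsFeatureUsed(int featureIndex)
    {
        // Matches the documented ArgumentOutOfRangeException for invalid indices.
        if (featureIndex < 0 || featureIndex >= _numFeatures)
            throw new ArgumentOutOfRangeException(nameof(featureIndex));
        return _active.Contains(featureIndex);
    }
}
```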
LoadModel(string)
Loads a classifier model from a file.
public virtual void LoadModel(string filePath)
Parameters
filePath string
The path to the file containing the saved model.
LoadState(Stream)
Loads the model's state from a stream.
public virtual void LoadState(Stream stream)
Parameters
stream Stream
The stream to read the model state from.
Predict(Matrix<T>)
Predicts class labels for the given input data.
public abstract Vector<T> Predict(Matrix<T> input)
Parameters
input Matrix<T>
The input features matrix where each row is an example and each column is a feature.
Returns
- Vector<T>
A vector of predicted class indices for each input example.
Remarks
This method calculates predictions for each sample in the input matrix. The returned vector contains class indices (0, 1, 2, ...).
For Beginners: After training, this method is used to make predictions on new data. It returns the predicted class for each input sample as a numeric index. Use ClassLabels to map these indices back to the original label values if needed.
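Mapping the class indices returned by Predict back to the original label values amounts to indexing into ClassLabels, as in the C# sketch below. It assumes double arrays in place of Vector<T>; PredictionMappingSketch and ToOriginalLabels are hypothetical helpers, not AiDotNet API.

```csharp
using System;

// Hypothetical helper: convert predicted class indices into original labels.
public static class PredictionMappingSketch
{
    // predictedIndices[i] is the class index predicted for sample i;
    // classLabels holds the original label values in class-index order.
    public static double[] ToOriginalLabels(double[] predictedIndices, double[] classLabels)
    {
        var labels = new double[predictedIndices.Length];
        for (int i = 0; i < predictedIndices.Length; i++)
        {
            int index = (int)predictedIndices[i];
            if (index < 0 || index >= classLabels.Length)
                throw new ArgumentOutOfRangeException(nameof(predictedIndices), $"Unknown class index {index}.");
            labels[i] = classLabels[index];
        }
        return labels;
    }
}
```

For a model trained on labels -1 and 1, predicted indices {1, 0, 1} become the labels {1, -1, 1}.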
SaveModel(string)
Saves the classifier model to a file.
public virtual void SaveModel(string filePath)
Parameters
filePath string
The path where the model should be saved.
SaveState(Stream)
Saves the model's current state to a stream.
public virtual void SaveState(Stream stream)
Parameters
stream Stream
The stream to write the model state to.
Serialize()
Serializes the model to a byte array.
public virtual byte[] Serialize()
Returns
- byte[]
A byte array containing the serialized model data.
Remarks
This method serializes the model's parameters to a JSON format and then converts it to a byte array.
For Beginners: Serialization converts the model's internal state into a format that can be saved to disk or transmitted over a network. This allows you to save a trained model and load it later without having to retrain it.
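A JSON-to-bytes round trip like the one described for Serialize and Deserialize can be sketched with System.Text.Json. The ModelStateSketch fields are hypothetical and do not reflect AiDotNet's actual serialized format or serializer; the sketch only mirrors the documented behavior of wrapping failures in InvalidOperationException.

```csharp
using System;
using System.Text.Json;

// Hypothetical snapshot of classifier state; the real format may differ.
public sealed class ModelStateSketch
{
    public string ModelType { get; set; } = "";
    public int NumClasses { get; set; }
    public int NumFeatures { get; set; }
    public double[] ClassLabels { get; set; } = Array.Empty<double>();
    public double[] Parameters { get; set; } = Array.Empty<double>();
}

public static class SerializationSketch
{
    // Object -> UTF-8 JSON bytes.
    public static byte[] Serialize(ModelStateSketch state) =>
        JsonSerializer.SerializeToUtf8Bytes(state);

    // UTF-8 JSON bytes -> object; failures surface as InvalidOperationException.
    public static ModelStateSketch Deserialize(byte[] modelData)
    {
        try
        {
            return JsonSerializer.Deserialize<ModelStateSketch>(new ReadOnlySpan<byte>(modelData))
                ?? throw new InvalidOperationException("Serialized model data deserialized to null.");
        }
        catch (JsonException ex)
        {
            throw new InvalidOperationException("Failed to deserialize model data.", ex);
        }
    }
}
```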
SetActiveFeatureIndices(IEnumerable<int>)
Sets the active feature indices for this model.
public virtual void SetActiveFeatureIndices(IEnumerable<int> featureIndices)
Parameters
featureIndices IEnumerable<int>
The indices of features to activate.
SetParameters(Vector<T>)
Sets the parameters for this model.
public abstract void SetParameters(Vector<T> parameters)
Parameters
parameters Vector<T>
A vector containing all model parameters.
Exceptions
- ArgumentException
Thrown when the parameters vector has an incorrect length.
Train(Matrix<T>, Vector<T>)
Trains the classifier on the provided data.
public abstract void Train(Matrix<T> x, Vector<T> y)
Parameters
x Matrix<T>
The input features matrix where each row is a training example and each column is a feature.
y Vector<T>
The target class labels vector corresponding to each training example.
Remarks
This abstract method must be implemented by derived classes to train the classifier. The target vector should contain class indices (0, 1, 2, ...) for each sample.
For Beginners: Training is the process where the model learns from your data. Different classification algorithms implement this method differently, but they all aim to learn how to correctly predict the class labels based on the input features.
WithParameters(Vector<T>)
Creates a new instance of the model with specified parameters.
public abstract IFullModel<T, Matrix<T>, Vector<T>> WithParameters(Vector<T> parameters)
Parameters
parameters Vector<T>
A vector containing all model parameters.
Returns
- IFullModel<T, Matrix<T>, Vector<T>>
A new model instance with the specified parameters.
Exceptions
- ArgumentException
Thrown when the parameters vector has an incorrect length.
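The difference between SetParameters (mutates this model) and WithParameters (returns a new model and leaves this one unchanged), including the ArgumentException on a length mismatch, can be sketched in C# as below. ParameterizedModelSketch is a hypothetical class using double[] in place of Vector<T>, not AiDotNet code.

```csharp
using System;

// Hypothetical sketch of the parameter-vector contract.
public sealed class ParameterizedModelSketch
{
    private readonly double[] _parameters;

    public ParameterizedModelSketch(int parameterCount) => _parameters = new double[parameterCount];

    public int ParameterCount => _parameters.Length;

    // Returns a copy so callers cannot mutate internal state.
    public double[] GetParameters() => (double[])_parameters.Clone();

    public void SetParameters(double[] parameters)
    {
        if (parameters.Length != _parameters.Length)
            throw new ArgumentException(
                $"Expected {_parameters.Length} parameters but got {parameters.Length}.", nameof(parameters));
        Array.Copy(parameters, _parameters, parameters.Length);
    }

    public ParameterizedModelSketch WithParameters(double[] parameters)
    {
        var copy = new ParameterizedModelSketch(ParameterCount);
        copy.SetParameters(parameters); // validates length; this instance is left untouched
        return copy;
    }
}
```

Returning a new instance makes WithParameters convenient for optimizers that evaluate many candidate parameter vectors without disturbing the current model.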