Class LinearClassifierBase<T>
Namespace: AiDotNet.Classification.Linear
Assembly: AiDotNet.dll
Provides a base implementation for linear classifiers.
public abstract class LinearClassifierBase<T> : ProbabilisticClassifierBase<T>, IProbabilisticClassifier<T>, IClassifier<T>, IFullModel<T, Matrix<T>, Vector<T>>, IModel<Matrix<T>, Vector<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Matrix<T>, Vector<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Matrix<T>, Vector<T>>>, IGradientComputable<T, Matrix<T>, Vector<T>>, IJitCompilable<T>
Type Parameters
T: The numeric data type used for calculations (e.g., float, double).
Inheritance: ProbabilisticClassifierBase<T> → LinearClassifierBase<T>
Implements: IProbabilisticClassifier<T>, IClassifier<T>, IFullModel<T, Matrix<T>, Vector<T>>, and the other interfaces listed in the declaration
Remarks
Linear classifiers learn a linear decision function f(x) = w · x + b, where x is the feature vector, w is the weight vector, and b is the bias (intercept).
For Beginners: Linear classifiers are one of the simplest forms of machine learning:
How they work:
- Each feature gets a weight (importance score)
- Multiply each feature by its weight and sum them up
- Add a bias term
- If the result is positive, predict one class; otherwise, predict the other (in the binary case)
The training process adjusts the weights to correctly classify training examples.
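The steps above can be sketched in C# on plain double arrays. This is a minimal illustration, not the AiDotNet implementation: LinearDecisionSketch and its methods are hypothetical names, and the sketch assumes binary labels 0 and 1 with a threshold of zero on the decision value.

```csharp
using System;

// Hypothetical helper, not part of AiDotNet: illustrates f(x) = w · x + b
// and the sign-based binary prediction rule on plain double arrays.
public static class LinearDecisionSketch
{
    // Weighted sum of the features plus the bias term.
    public static double Decision(double[] weights, double bias, double[] x)
    {
        if (weights.Length != x.Length)
            throw new ArgumentException("weights and x must have the same length");

        double sum = bias;
        for (int i = 0; i < x.Length; i++)
            sum += weights[i] * x[i]; // each feature times its weight
        return sum;
    }

    // Positive score -> class 1, otherwise class 0 (assumed binary labels).
    public static int PredictBinary(double[] weights, double bias, double[] x)
        => Decision(weights, bias, x) > 0 ? 1 : 0;
}
```

With weights [2, -1] and bias -0.5, the sample [1, 0.5] scores 1.0 (class 1) and the sample [0, 1] scores -1.5 (class 0).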
Advantages:
- Fast to train and predict
- Work well with many features
- Easy to interpret (a weight's sign and magnitude show how a feature pushes the prediction, provided the features are on comparable scales)
- Often surprisingly effective
Constructors
LinearClassifierBase(LinearClassifierOptions<T>?, IRegularization<T, Matrix<T>, Vector<T>>?)
Initializes a new instance of the LinearClassifierBase class.
protected LinearClassifierBase(LinearClassifierOptions<T>? options = null, IRegularization<T, Matrix<T>, Vector<T>>? regularization = null)
Parameters
options (LinearClassifierOptions<T>): Configuration options for the linear classifier.
regularization (IRegularization<T, Matrix<T>, Vector<T>>): Optional regularization strategy.
Properties
Intercept
The learned intercept (bias) term.
protected T Intercept { get; set; }
Property Value
- T
Options
Gets the options specific to linear classifiers.
protected LinearClassifierOptions<T> Options { get; }
Property Value
- LinearClassifierOptions<T>
Random
Random number generator for shuffling.
protected Random? Random { get; set; }
Property Value
- Random
Weights
The learned weight vector.
protected Vector<T>? Weights { get; set; }
Property Value
- Vector<T>
Methods
ApplyGradients(Vector<T>, T)
Applies pre-computed gradients to update the model parameters.
public override void ApplyGradients(Vector<T> gradients, T learningRate)
Parameters
gradients (Vector<T>): The gradient vector to apply.
learningRate (T): The learning rate for the update.
Remarks
Updates parameters using: θ = θ - learningRate * gradients
For Beginners: After computing gradients (which tell you the direction to move), this method actually moves the model in that direction. The learning rate controls how large a step to take.
Distributed Training: In DDP/ZeRO-2, this applies the synchronized (averaged) gradients after communication across workers. Each worker applies the same averaged gradients to keep parameters consistent.
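A minimal sketch of the update rule on plain double arrays, assuming the parameters and gradients are already flattened to the same length. GradientStepSketch is a hypothetical helper rather than AiDotNet's code, which works on Vector<T>.

```csharp
using System;

// Hypothetical sketch on double arrays (not the AiDotNet implementation):
// the plain gradient-descent step theta = theta - learningRate * gradients.
public static class GradientStepSketch
{
    public static void ApplyGradients(double[] parameters, double[] gradients, double learningRate)
    {
        if (parameters.Length != gradients.Length)
            throw new ArgumentException("gradients must match the parameter count");

        for (int i = 0; i < parameters.Length; i++)
            parameters[i] -= learningRate * gradients[i]; // step against the gradient
    }
}
```

In a data-parallel setting, every worker would call this with the same averaged gradient array, so all copies of the parameters stay identical.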
ApplyL1Gradient(T, T)
Applies the L1 regularization gradient to the weights.
protected void ApplyL1Gradient(T learningRate, T alpha)
Parameters
learningRate (T): The learning rate (step size) for the update.
alpha (T): The L1 regularization strength.
ApplyL2Gradient(T, T)
Applies the L2 regularization gradient to the weights.
protected void ApplyL2Gradient(T learningRate, T alpha)
Parameters
learningRate (T): The learning rate (step size) for the update.
alpha (T): The L2 regularization strength.
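For illustration, the two regularization updates might look like this, assuming the common conventions of an L1 penalty α·Σ|wᵢ| and an L2 penalty (α/2)·Σwᵢ², applied to the weights only. The exact scaling (for example α·w versus 2α·w for L2) and whether AiDotNet excludes the intercept are assumptions; RegularizationStepSketch is hypothetical.

```csharp
using System;

// Hypothetical sketch: regularization gradient steps on the weights only
// (the intercept is left unpenalized, a common convention).
public static class RegularizationStepSketch
{
    // L1 penalty alpha * sum|w_i|; its subgradient is alpha * sign(w_i),
    // taking 0 at w_i == 0.
    public static void ApplyL1Gradient(double[] weights, double learningRate, double alpha)
    {
        for (int i = 0; i < weights.Length; i++)
            weights[i] -= learningRate * alpha * Math.Sign(weights[i]);
    }

    // L2 penalty (alpha / 2) * sum w_i^2; its gradient is alpha * w_i,
    // so each step shrinks every weight toward zero ("weight decay").
    public static void ApplyL2Gradient(double[] weights, double learningRate, double alpha)
    {
        for (int i = 0; i < weights.Length; i++)
            weights[i] -= learningRate * alpha * weights[i];
    }
}
```

A plain subgradient step for L1 can push a small weight past zero instead of setting it exactly to zero; proximal (soft-thresholding) updates are a common alternative when exact sparsity matters.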
ComputeGradients(Matrix<T>, Vector<T>, ILossFunction<T>?)
Computes gradients of the loss function with respect to model parameters for the given data, WITHOUT updating the model parameters.
public override Vector<T> ComputeGradients(Matrix<T> input, Vector<T> target, ILossFunction<T>? lossFunction = null)
Parameters
input (Matrix<T>): The input data.
target (Vector<T>): The target/expected output.
lossFunction (ILossFunction<T>): The loss function to use for gradient computation. If null, the model's default loss function is used.
Returns
- Vector<T>
A vector containing gradients with respect to all model parameters.
Remarks
This method performs a forward pass, computes the loss, and back-propagates to compute gradients, but does NOT update the model's parameters. The parameters remain unchanged after this call.
Distributed Training: In DDP/ZeRO-2, each worker calls this to compute local gradients on its data batch. These gradients are then synchronized (averaged) across workers before applying updates. This ensures all workers compute the same parameter updates despite having different data.
For Meta-Learning: After adapting a model on a support set, you can use this method to compute gradients on the query set. These gradients become the meta-gradients for updating the meta-parameters.
For Beginners: Think of this as "dry run" training:
- The model sees what direction it should move (the gradients)
- But it doesn't actually move (parameters stay the same)
- You get to decide what to do with this information (average with others, inspect, modify, etc.)
Exceptions
- InvalidOperationException
Thrown when lossFunction is null and the model has no default loss function.
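As a sketch of what "compute but do not apply" means for a linear model, the following computes the gradient of the mean binary log loss. It assumes labels in {0, 1}, a sigmoid link, and a parameter layout of [w_1..w_d, b]; the actual default loss function and parameter ordering in AiDotNet may differ, and LogLossGradientSketch is hypothetical.

```csharp
using System;

// Hypothetical sketch: gradient of the mean binary log loss for a linear model,
// assuming labels in {0, 1} and a parameter vector laid out as [w_1..w_d, b].
// The weights and bias are only read, never modified.
public static class LogLossGradientSketch
{
    static double Sigmoid(double z) => 1.0 / (1.0 + Math.Exp(-z));

    public static double[] ComputeGradients(double[,] x, double[] y, double[] weights, double bias)
    {
        int n = x.GetLength(0), d = x.GetLength(1);
        if (y.Length != n || weights.Length != d)
            throw new ArgumentException("dimension mismatch");

        var grad = new double[d + 1]; // last slot holds the bias gradient
        for (int i = 0; i < n; i++)
        {
            double z = bias;
            for (int j = 0; j < d; j++) z += weights[j] * x[i, j];

            double error = Sigmoid(z) - y[i]; // dLoss/dz for log loss
            for (int j = 0; j < d; j++) grad[j] += error * x[i, j] / n;
            grad[d] += error / n;
        }
        return grad;
    }
}
```

Because the weights and bias are only read, calling this leaves the model exactly as it was; the caller decides whether to average, inspect, or apply the result.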
DecisionFunction(Vector<T>)
Computes the decision function value for a single sample.
protected T DecisionFunction(Vector<T> sample)
Parameters
sample (Vector<T>): The feature vector for a single sample.
Returns
- T
The decision function value (w · x + b).
DecisionFunctionBatch(Matrix<T>)
Computes decision function values for all samples.
public Vector<T> DecisionFunctionBatch(Matrix<T> input)
Parameters
input (Matrix<T>): The input features matrix.
Returns
- Vector<T>
A vector of decision function values.
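The batch version applies the same w · x + b computation to every row; in matrix form the result is X·w + b. A minimal sketch on a row-major double[,] array, with DecisionBatchSketch as a hypothetical name:

```csharp
using System;

// Hypothetical sketch: decision values for every row of a row-major feature
// matrix, i.e. scores = X w + b computed one sample (row) at a time.
public static class DecisionBatchSketch
{
    public static double[] DecisionFunctionBatch(double[,] input, double[] weights, double bias)
    {
        int rows = input.GetLength(0), cols = input.GetLength(1);
        if (weights.Length != cols)
            throw new ArgumentException("one weight per feature column is required");

        var scores = new double[rows];
        for (int i = 0; i < rows; i++)
        {
            double sum = bias;
            for (int j = 0; j < cols; j++)
                sum += weights[j] * input[i, j];
            scores[i] = sum;
        }
        return scores;
    }
}
```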
GetModelMetadata()
Gets metadata about the model.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
A ModelMetadata object containing information about the model.
Remarks
This method returns metadata about the model, including its type, feature count, complexity, description, and additional information specific to classification.
For Beginners: Model metadata provides information about the model itself, rather than the predictions it makes. This includes details about the model's structure (like how many features it uses) and characteristics (like how many classes it can predict). This information can be useful for understanding and comparing different models.
GetParameters()
Gets all model parameters as a single vector.
public override Vector<T> GetParameters()
Returns
- Vector<T>
A vector containing all model parameters.
Remarks
This method returns a vector containing all model parameters for use with optimization algorithms or model comparison.
For Beginners: This method packages all the model's parameters into a single collection. This is useful for optimization algorithms that need to work with all parameters at once.
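A hypothetical sketch of packing and unpacking, assuming the weights come first and the intercept is stored in the last slot. This layout is an assumption; the ordering actually used by LinearClassifierBase may differ, and ParameterPackingSketch is not an AiDotNet type.

```csharp
using System;

// Hypothetical sketch: flattening weights and intercept into one vector and back.
// The [w_1..w_d, b] layout is an assumption, not confirmed library behavior.
public static class ParameterPackingSketch
{
    public static double[] GetParameters(double[] weights, double intercept)
    {
        var parameters = new double[weights.Length + 1];
        Array.Copy(weights, parameters, weights.Length);
        parameters[weights.Length] = intercept; // intercept goes last
        return parameters;
    }

    public static (double[] Weights, double Intercept) Unpack(double[] parameters)
    {
        if (parameters.Length < 1)
            throw new ArgumentException("expected at least the intercept");

        var weights = new double[parameters.Length - 1];
        Array.Copy(parameters, weights, weights.Length);
        return (weights, parameters[parameters.Length - 1]);
    }
}
```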
InitializeWeights()
Initializes the weights before training.
protected virtual void InitializeWeights()
Predict(Matrix<T>)
Predicts class labels for the given input data by taking the argmax of probabilities.
public override Vector<T> Predict(Matrix<T> input)
Parameters
input (Matrix<T>): The input features matrix where each row is an example and each column is a feature.
Returns
- Vector<T>
A vector of predicted class indices for each input example.
Remarks
This implementation uses the argmax of the probability distribution to determine the predicted class. For binary classification with a custom decision threshold, you may want to use PredictProbabilities() directly and apply your own threshold.
For Beginners: This method picks the class with the highest probability for each sample.
For example, if the probabilities are [0.1, 0.7, 0.2] for classes [A, B, C], this method returns class B (index 1) because it has the highest probability (0.7).
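The argmax rule can be sketched as follows on a [samples, classes] array (ArgmaxSketch is a hypothetical helper). Ties resolve to the lowest class index here, which may not match the library's tie-breaking.

```csharp
using System;

// Hypothetical sketch: argmax over each row of a [samples, classes] probability matrix.
public static class ArgmaxSketch
{
    public static int[] Predict(double[,] probabilities)
    {
        int rows = probabilities.GetLength(0), classes = probabilities.GetLength(1);
        var labels = new int[rows];
        for (int i = 0; i < rows; i++)
        {
            int best = 0;
            for (int c = 1; c < classes; c++)
                if (probabilities[i, c] > probabilities[i, best])
                    best = c; // strict '>' keeps the lower index on ties
            labels[i] = best;
        }
        return labels;
    }
}
```

For the row [0.1, 0.7, 0.2] this returns index 1, matching the example above.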
PredictLogProbabilities(Matrix<T>)
Predicts log-probabilities for each class.
public override Matrix<T> PredictLogProbabilities(Matrix<T> input)
Parameters
input (Matrix<T>): The input features matrix where each row is a sample and each column is a feature.
Returns
- Matrix<T>
A matrix where each row corresponds to an input sample and each column corresponds to a class. The values are the natural logarithm of the class probabilities.
Remarks
The default implementation computes log(PredictProbabilities(input)). Subclasses that compute log-probabilities directly (like Naive Bayes) should override this method for better numerical stability.
For Beginners: Log-probabilities are probabilities transformed by the natural logarithm. They're useful for numerical stability when working with very small probabilities.
For example:
- Probability 0.9 → Log-probability -0.105
- Probability 0.1 → Log-probability -2.303
- Probability 0.001 → Log-probability -6.908
Log-probabilities are never positive: log(1) = 0, and any probability below 1 gives a negative value. Higher (less negative) values mean higher probability.
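To illustrate why computing log-probabilities directly helps, here is a sketch for the binary case, assuming the class-1 probability is the sigmoid of the decision value z. It uses the identity log σ(z) = min(z, 0) − log(1 + e^(−|z|)); LogProbabilitySketch is a hypothetical name, not part of AiDotNet.

```csharp
using System;

// Hypothetical sketch: binary log-probabilities computed directly from the
// decision value z instead of as log(sigmoid(z)), which underflows for large |z|.
public static class LogProbabilitySketch
{
    // Naive: log(1 / (1 + e^-z)); returns -Infinity once the sigmoid underflows to 0.
    public static double NaiveLogSigmoid(double z) => Math.Log(1.0 / (1.0 + Math.Exp(-z)));

    // Stable identity: log sigmoid(z) = min(z, 0) - log(1 + e^-|z|).
    public static double StableLogSigmoid(double z) =>
        Math.Min(z, 0.0) - Math.Log(1.0 + Math.Exp(-Math.Abs(z)));

    // [log P(class 0), log P(class 1)], using 1 - sigmoid(z) = sigmoid(-z).
    public static double[] LogProbabilities(double z) =>
        new[] { StableLogSigmoid(-z), StableLogSigmoid(z) };
}
```

For z = −800, NaiveLogSigmoid returns negative infinity because the sigmoid underflows to 0, while StableLogSigmoid returns −800.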
PredictProbabilities(Matrix<T>)
Predicts class probabilities for each sample in the input.
public override Matrix<T> PredictProbabilities(Matrix<T> input)
Parameters
input (Matrix<T>): The input features matrix where each row is a sample and each column is a feature.
Returns
- Matrix<T>
A matrix where each row corresponds to an input sample and each column corresponds to a class. The values represent the probability of the sample belonging to each class.
Remarks
Implementations compute class probabilities for every sample. The output matrix must have shape [num_samples, num_classes], and each row must sum to 1.0.
For Beginners: This method computes the probability of each sample belonging to each class. Each row in the output represents one sample, and each column represents one class. The values in each row sum to 1.0 (100% total probability).
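One common way a binary linear classifier satisfies this contract is to map each decision value z to [1 − σ(z), σ(z)], as logistic regression does. The sketch below assumes that mapping; derived classes in AiDotNet may compute probabilities differently, and BinaryProbabilitySketch is hypothetical.

```csharp
using System;

// Hypothetical sketch: turning binary decision values into a [samples, 2]
// probability matrix via the logistic sigmoid (as logistic regression does).
public static class BinaryProbabilitySketch
{
    static double Sigmoid(double z) => 1.0 / (1.0 + Math.Exp(-z));

    public static double[,] PredictProbabilities(double[] decisionValues)
    {
        var probs = new double[decisionValues.Length, 2];
        for (int i = 0; i < decisionValues.Length; i++)
        {
            double p1 = Sigmoid(decisionValues[i]);
            probs[i, 0] = 1.0 - p1; // P(class 0)
            probs[i, 1] = p1;       // P(class 1)
        }
        return probs;
    }
}
```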
SetParameters(Vector<T>)
Sets the parameters for this model.
public override void SetParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): A vector containing all model parameters.
Exceptions
- ArgumentException
Thrown when the parameters vector has an incorrect length.
ShuffleIndices(int)
Shuffles the training data indices.
protected int[] ShuffleIndices(int n)
Parameters
n (int): The number of training samples whose indices are shuffled.
Returns
- int[]
The shuffled indices.
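A minimal sketch of index shuffling using the Fisher-Yates algorithm and System.Random. Whether ShuffleIndices uses this exact algorithm is an assumption, and ShuffleSketch is a hypothetical name.

```csharp
using System;

// Hypothetical sketch: Fisher-Yates shuffle of the indices 0..n-1 using System.Random.
public static class ShuffleSketch
{
    public static int[] ShuffleIndices(int n, Random random)
    {
        var indices = new int[n];
        for (int i = 0; i < n; i++) indices[i] = i;

        // Walk backwards, swapping each slot with a uniformly chosen slot at or before it.
        for (int i = n - 1; i > 0; i--)
        {
            int j = random.Next(i + 1); // 0 <= j <= i
            (indices[i], indices[j]) = (indices[j], indices[i]);
        }
        return indices;
    }
}
```

Passing a Random created with a fixed seed makes the order repeatable across runs on the same runtime, which is useful for reproducible training.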
Sigmoid(T)
Computes the sigmoid function.
protected T Sigmoid(T x)
Parameters
x (T): The input value.
Returns
- T
The sigmoid of x, 1 / (1 + e^(−x)), a value in the range (0, 1).
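A sketch of a numerically careful double-precision sigmoid follows. It branches on the sign of x so Math.Exp is only ever called on a non-positive argument and no intermediate value overflows to infinity. Whether the library's Sigmoid uses this form is an assumption, and SigmoidSketch is hypothetical.

```csharp
using System;

// Hypothetical sketch: a sigmoid that never evaluates Math.Exp on a positive
// argument, so no intermediate value overflows for large |x|.
public static class SigmoidSketch
{
    public static double Sigmoid(double x)
    {
        if (x >= 0)
        {
            double e = Math.Exp(-x); // in (0, 1]
            return 1.0 / (1.0 + e);
        }
        else
        {
            double e = Math.Exp(x);  // in (0, 1)
            return e / (1.0 + e);    // algebraically equal to 1 / (1 + e^-x)
        }
    }
}
```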
WithParameters(Vector<T>)
Creates a new instance of the model with specified parameters.
public override IFullModel<T, Matrix<T>, Vector<T>> WithParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): A vector containing all model parameters.
Returns
- IFullModel<T, Matrix<T>, Vector<T>>
A new model instance with the specified parameters.
Exceptions
- ArgumentException
Thrown when the parameters vector has an incorrect length.
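The contract of WithParameters (a new instance, the original left untouched, the length checked) can be sketched with a small stand-in class. TinyLinearModel is hypothetical, and the [w_1..w_d, b] parameter layout is assumed.

```csharp
using System;

// Hypothetical sketch of WithParameters semantics: copy, validate length, set.
// TinyLinearModel is an illustrative stand-in, not an AiDotNet type.
public sealed class TinyLinearModel
{
    public double[] Weights { get; private set; }
    public double Intercept { get; private set; }

    public TinyLinearModel(double[] weights, double intercept)
    {
        Weights = (double[])weights.Clone();
        Intercept = intercept;
    }

    // Mirrors SetParameters: expects [w_1..w_d, b] and rejects any other length.
    public void SetParameters(double[] parameters)
    {
        if (parameters.Length != Weights.Length + 1)
            throw new ArgumentException($"Expected {Weights.Length + 1} parameters, got {parameters.Length}.");

        var weights = new double[Weights.Length];
        Array.Copy(parameters, weights, weights.Length);
        Weights = weights;
        Intercept = parameters[parameters.Length - 1];
    }

    // Mirrors WithParameters: the original instance is left unchanged.
    public TinyLinearModel WithParameters(double[] parameters)
    {
        var copy = new TinyLinearModel(Weights, Intercept);
        copy.SetParameters(parameters);
        return copy;
    }
}
```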