Class ClassifierChain<T>
Namespace: AiDotNet.Classification.Meta
Assembly: AiDotNet.dll
Classifier Chain for multi-label classification.
public class ClassifierChain<T> : MetaClassifierBase<T>, IProbabilisticClassifier<T>, IMultiLabelClassifier<T>, IClassifier<T>, IFullModel<T, Matrix<T>, Vector<T>>, IModel<Matrix<T>, Vector<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Matrix<T>, Vector<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Matrix<T>, Vector<T>>>, IGradientComputable<T, Matrix<T>, Vector<T>>, IJitCompilable<T>
Type Parameters
T: The numeric data type used for calculations.
Inheritance: MetaClassifierBase<T> → ClassifierChain<T>
Implements: IClassifier<T> (the full list of implemented interfaces appears in the declaration above)
Remarks
Classifier Chain transforms a multi-label problem into a chain of binary classification problems, where each classifier uses the predictions of previous classifiers as additional features.
For Beginners: Classifier Chain captures label dependencies:
For labels A, B, C:
- Classifier 1: Predict A using features X
- Classifier 2: Predict B using features X + prediction of A
- Classifier 3: Predict C using features X + predictions of A and B
Benefits:
- Captures dependencies between labels
- Can outperform independent per-label binary classifiers when labels are correlated
Trade-offs:
- Order of chain matters (can use random order or learned order)
- Error propagation (early mistakes affect later predictions)
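The chain described above can be sketched for a single sample as follows. This is a minimal, library-independent illustration, not AiDotNet's implementation: each link is a hypothetical scoring function (`Func<double[], double>`) returning P(label = 1), plain `double[]` stands in for the feature row, and the sketch feeds hard 0/1 predictions forward with an assumed threshold of 0.5. Whether ClassifierChain<T> passes hard labels or probabilities to later links is not stated here.

```csharp
using System;
using System.Collections.Generic;

// Minimal, library-independent sketch of classifier-chain inference.
// Each link is a hypothetical scoring function that returns P(label = 1)
// for the original features augmented with the earlier links' predictions.
public static class ChainSketch
{
    public static int[] PredictChain(
        double[] features,
        IReadOnlyList<Func<double[], double>> links,
        double threshold = 0.5)
    {
        var augmented = new List<double>(features);
        var labels = new int[links.Count];

        for (int j = 0; j < links.Count; j++)
        {
            // Link j sees the original features plus the labels predicted by links 0..j-1.
            double p = links[j](augmented.ToArray());
            labels[j] = p >= threshold ? 1 : 0;

            // Feed the hard 0/1 prediction forward as an extra feature.
            augmented.Add(labels[j]);
        }

        return labels;
    }
}
```

Because link j reads the outputs of links 0..j-1, a wrong early prediction changes the inputs of every later link, which is the error propagation listed under the trade-offs.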
Constructors
ClassifierChain(Func<IClassifier<T>>, ClassifierChainOptions<T>?, IRegularization<T, Matrix<T>, Vector<T>>?)
Initializes a new instance of the ClassifierChain class.
public ClassifierChain(Func<IClassifier<T>> estimatorFactory, ClassifierChainOptions<T>? options = null, IRegularization<T, Matrix<T>, Vector<T>>? regularization = null)
Parameters
estimatorFactory (Func<IClassifier<T>>): Factory function to create base binary classifiers.
options (ClassifierChainOptions<T>): Configuration options for the classifier. Optional; defaults to null.
regularization (IRegularization<T, Matrix<T>, Vector<T>>): Optional regularization strategy. Defaults to null.
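The constructor takes a factory rather than a single classifier instance, presumably so that each label in the chain gets its own, independently trained estimator. The plain C# sketch below shows that pattern only; `ThresholdClassifier` and `BuildLinks` are hypothetical names, not AiDotNet API.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for a base binary classifier with per-instance state.
public sealed class ThresholdClassifier
{
    public double Threshold { get; set; }
}

public static class FactorySketch
{
    // Calls the factory once per label so that every link in the chain
    // gets its own estimator whose state is not shared with other links.
    public static List<TModel> BuildLinks<TModel>(Func<TModel> factory, int numLabels)
    {
        var links = new List<TModel>(numLabels);
        for (int j = 0; j < numLabels; j++)
        {
            links.Add(factory());
        }
        return links;
    }
}
```

Passing one shared instance instead would make every link train the same object, so fitting label j would overwrite the model learned for label j - 1.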
Properties
NumLabels
Gets the number of labels that can be predicted.
public int NumLabels { get; }
Property Value: int
Options
Gets the chain-specific options.
protected ClassifierChainOptions<T> Options { get; }
Property Value: ClassifierChainOptions<T>
Methods
Clone()
Creates a clone of the classifier model.
public override IFullModel<T, Matrix<T>, Vector<T>> Clone()
Returns
- IFullModel<T, Matrix<T>, Vector<T>>
A new instance of the model with the same parameters and options.
CreateNewInstance()
Creates a new instance of the same type as this classifier.
protected override IFullModel<T, Matrix<T>, Vector<T>> CreateNewInstance()
Returns
- IFullModel<T, Matrix<T>, Vector<T>>
A new instance of the same classifier type.
GetModelType()
Returns the model type identifier for this classifier.
protected override ModelType GetModelType()
Returns
- ModelType
Predict(Matrix<T>)
Predicts class labels for the given input data by taking the argmax of probabilities.
public override Vector<T> Predict(Matrix<T> input)
Parameters
input (Matrix<T>): The input features matrix, where each row is an example and each column is a feature.
Returns
- Vector<T>
A vector of predicted class indices for each input example.
Remarks
This implementation uses the argmax of the probability distribution to determine the predicted class. For binary classification with a custom decision threshold, you may want to use PredictProbabilities() directly and apply your own threshold.
For Beginners: This method picks the class with the highest probability for each sample.
For example, if the probabilities are [0.1, 0.7, 0.2] for classes [A, B, C], this method returns class B (index 1) because it has the highest probability (0.7).
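The row-wise argmax described above can be sketched as follows, using a plain `double[,]` in place of Matrix<T>. In this sketch ties go to the lowest class index; how the library breaks ties is not stated here.

```csharp
// Sketch of row-wise argmax over a [samples x classes] probability matrix.
public static class ArgmaxSketch
{
    public static int[] ArgmaxRows(double[,] probabilities)
    {
        int rows = probabilities.GetLength(0);
        int cols = probabilities.GetLength(1);
        var result = new int[rows];

        for (int i = 0; i < rows; i++)
        {
            int best = 0;
            for (int c = 1; c < cols; c++)
            {
                // Strict ">" keeps the first (lowest-index) class on ties.
                if (probabilities[i, c] > probabilities[i, best]) best = c;
            }
            result[i] = best;
        }

        return result;
    }
}
```

For the row [0.1, 0.7, 0.2] from the example, this returns index 1 (class B).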
PredictLogProbabilities(Matrix<T>)
Predicts log-probabilities for each class.
public override Matrix<T> PredictLogProbabilities(Matrix<T> input)
Parameters
input (Matrix<T>): The input features matrix, where each row is a sample and each column is a feature.
Returns
- Matrix<T>
A matrix where each row corresponds to an input sample and each column corresponds to a class. The values are the natural logarithm of the class probabilities.
Remarks
The default implementation computes log(PredictProbabilities(input)). Subclasses that compute log-probabilities directly (like Naive Bayes) should override this method for better numerical stability.
For Beginners: Log-probabilities are probabilities transformed by the natural logarithm. They're useful for numerical stability when working with very small probabilities.
For example:
- Probability 0.9 → Log-probability -0.105
- Probability 0.1 → Log-probability -2.303
- Probability 0.001 → Log-probability -6.908
Log-probabilities are never positive: probabilities lie between 0 and 1, and log(1) = 0. Higher (less negative) values mean higher probability.
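The default behavior described above, log(PredictProbabilities(input)), amounts to an element-wise natural logarithm. A minimal sketch with `double[,]` standing in for Matrix<T>:

```csharp
using System;

// Element-wise natural log of a probability matrix, mirroring the
// log(PredictProbabilities(input)) default described above.
public static class LogProbSketch
{
    public static double[,] Log(double[,] probabilities)
    {
        int rows = probabilities.GetLength(0), cols = probabilities.GetLength(1);
        var result = new double[rows, cols];

        for (int i = 0; i < rows; i++)
            for (int c = 0; c < cols; c++)
                result[i, c] = Math.Log(probabilities[i, c]);

        return result;
    }
}
```

Applied to [0.9, 0.1, 0.001] it reproduces the values listed above. A probability of exactly 0 yields negative infinity, one reason models that can compute log-probabilities directly are better off doing so.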
PredictMultiLabel(Matrix<T>)
Predicts binary indicators for each label for each sample.
public Matrix<T> PredictMultiLabel(Matrix<T> input)
Parameters
input (Matrix<T>): The input feature matrix.
Returns
- Matrix<T>
A binary matrix where each row is a sample and each column is a label indicator (1=present, 0=absent).
PredictMultiLabelProbabilities(Matrix<T>)
Predicts probabilities for each label for each sample.
public Matrix<T> PredictMultiLabelProbabilities(Matrix<T> input)
Parameters
input (Matrix<T>): The input feature matrix.
Returns
- Matrix<T>
A probability matrix where each row is a sample and each column is the probability of that label.
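Unlike a multi-class probability matrix, each column of a multi-label probability matrix is a separate per-label probability, so a row need not sum to 1. The sketch below shows how such probabilities map to the 1/0 indicators returned by PredictMultiLabel. The 0.5 threshold is an assumption, and this does not claim that PredictMultiLabel is computed by thresholding PredictMultiLabelProbabilities; `double[,]` stands in for Matrix<T>.

```csharp
// Converts a [samples x labels] probability matrix into 1/0 indicators.
// The default threshold of 0.5 is an assumption for illustration.
public static class MultiLabelSketch
{
    public static double[,] ToIndicators(double[,] labelProbabilities, double threshold = 0.5)
    {
        int rows = labelProbabilities.GetLength(0), cols = labelProbabilities.GetLength(1);
        var indicators = new double[rows, cols];

        for (int i = 0; i < rows; i++)
            for (int j = 0; j < cols; j++)
                indicators[i, j] = labelProbabilities[i, j] >= threshold ? 1.0 : 0.0;

        return indicators;
    }
}
```

For example, the row [0.9, 0.8, 0.1] sums to 1.8 and yields indicators [1, 1, 0]: the sample carries two labels at once.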
PredictProbabilities(Matrix<T>)
Predicts class probabilities for each sample in the input.
public override Matrix<T> PredictProbabilities(Matrix<T> input)
Parameters
input (Matrix<T>): The input features matrix, where each row is a sample and each column is a feature.
Returns
- Matrix<T>
A matrix where each row corresponds to an input sample and each column corresponds to a class. The values represent the probability of the sample belonging to each class.
Remarks
In the base classifier this method is abstract, and derived classes such as this one implement it to compute class probabilities. The output matrix has shape [num_samples, num_classes], and each row should sum to 1.0.
For Beginners: This method computes the probability of each sample belonging to each class. Each row in the output represents one sample, and each column represents one class. The values in each row sum to 1.0 (100% total probability).
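The row-sum contract above can be checked on any probability matrix with a small helper like the one below. It is a sketch with `double[,]` in place of Matrix<T>; the helper name and the default tolerance are illustrative choices, and a tolerance is needed because floating-point sums such as 0.1 + 0.7 + 0.2 are not exactly 1.0.

```csharp
using System;

// Checks that every row is non-negative and sums to 1 within a tolerance.
public static class ProbabilityContract
{
    public static bool RowsAreDistributions(double[,] p, double tolerance = 1e-9)
    {
        for (int i = 0; i < p.GetLength(0); i++)
        {
            double sum = 0.0;
            for (int c = 0; c < p.GetLength(1); c++)
            {
                if (p[i, c] < 0.0) return false;
                sum += p[i, c];
            }
            if (Math.Abs(sum - 1.0) > tolerance) return false;
        }
        return true;
    }
}
```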
Train(Matrix<T>, Vector<T>)
Standard single-label training method: converts the single-label target vector to multi-label format before training the chain.
public override void Train(Matrix<T> x, Vector<T> y)
Parameters
x (Matrix<T>): The input features matrix.
y (Vector<T>): The single-label target vector.
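This reference does not say how the single labels are converted. One common conversion, assumed here for illustration, is one-hot encoding: class index k becomes a row with a 1 in column k and 0 elsewhere. The sketch uses `double[]` and `double[,]` in place of Vector<T> and Matrix<T>, and assumes labels are integers 0..numClasses-1 stored as doubles.

```csharp
using System;

// One-hot encodes class indices (stored as doubles, e.g. 0, 1, 2) into a
// [samples x numClasses] 0/1 matrix. Assumes labels lie in 0..numClasses-1.
public static class LabelConversionSketch
{
    public static double[,] ToOneHot(double[] labels, int numClasses)
    {
        var matrix = new double[labels.Length, numClasses];

        for (int i = 0; i < labels.Length; i++)
        {
            int k = (int)Math.Round(labels[i]);
            if (k < 0 || k >= numClasses)
                throw new ArgumentOutOfRangeException(nameof(labels),
                    $"Label {labels[i]} at row {i} is outside 0..{numClasses - 1}.");
            matrix[i, k] = 1.0;
        }

        return matrix;
    }
}
```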
TrainMultiLabel(Matrix<T>, Matrix<T>)
Trains the Classifier Chain on multi-label data.
public void TrainMultiLabel(Matrix<T> x, Matrix<T> yMultiLabel)
Parameters
x (Matrix<T>): The input features matrix.
yMultiLabel (Matrix<T>): The multi-label target matrix (rows = samples, columns = labels).
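In the usual classifier-chain formulation, link j is trained on the original features plus the true values of labels 0..j-1, with column j of the label matrix as its binary target. The sketch below assembles that per-link training data; whether ClassifierChain<T> uses true labels or its own predictions during training is not stated in this reference, and `double[,]` stands in for Matrix<T>.

```csharp
using System;

// Sketch of per-link training data for a classifier chain, using the true
// earlier labels as extra features (the usual formulation).
public static class ChainTrainingSketch
{
    public static double[,] AugmentedFeatures(double[,] x, double[,] yMultiLabel, int link)
    {
        int n = x.GetLength(0), d = x.GetLength(1);
        if (yMultiLabel.GetLength(0) != n)
            throw new ArgumentException("x and yMultiLabel must have the same number of rows.");
        if (link < 0 || link >= yMultiLabel.GetLength(1))
            throw new ArgumentOutOfRangeException(nameof(link));

        var result = new double[n, d + link];
        for (int i = 0; i < n; i++)
        {
            for (int f = 0; f < d; f++) result[i, f] = x[i, f];
            // Append the true values of labels 0..link-1 for this sample.
            for (int k = 0; k < link; k++) result[i, d + k] = yMultiLabel[i, k];
        }
        return result;
    }

    // The binary target for a link is column `link` of the label matrix.
    public static double[] TargetColumn(double[,] yMultiLabel, int link)
    {
        var target = new double[yMultiLabel.GetLength(0)];
        for (int i = 0; i < target.Length; i++) target[i] = yMultiLabel[i, link];
        return target;
    }
}
```

At prediction time the true labels are unknown, so the earlier links' predictions take their place; this train/predict mismatch is the source of the error propagation noted in the remarks.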