Classification Tutorial
{: .no_toc }
Learn to build classification models with AiDotNet.
{: .fs-6 .fw-300 }
Table of contents
{: .no_toc .text-delta }
- TOC
{:toc}
What is Classification?
Classification is a supervised learning task where the goal is to predict discrete class labels from input features. Examples include:
- Email spam detection (spam/not spam)
- Image recognition (cat/dog/bird)
- Disease diagnosis (positive/negative)
Types of Classification
Binary Classification
Exactly two classes. Example: Is this email spam?
Multi-class Classification
Three or more mutually exclusive classes. Example: What species is this flower?
Multi-label Classification
Each sample can belong to multiple classes. Example: Tag this article with relevant topics.
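The three task types differ mainly in how a label is represented: a single 0/1 value for binary, a single class index for multi-class, and one 0/1 flag per class for multi-label. The sketch below is plain C# and independent of AiDotNet; `ToMultiHot` is a hypothetical helper used only to illustrate one common multi-label encoding.

```csharp
using System;

// Multi-label encoding: one 0/1 flag per class, so several can be active at once.
// ToMultiHot is a hypothetical helper, not an AiDotNet API.
static double[] ToMultiHot(int[] activeClasses, int numClasses)
{
    var encoded = new double[numClasses];
    foreach (var c in activeClasses)
    {
        if (c < 0 || c >= numClasses)
            throw new ArgumentOutOfRangeException(nameof(activeClasses));
        encoded[c] = 1.0;
    }
    return encoded;
}

// An article tagged with topics 0 and 3 out of 5 possible topics.
double[] topics = ToMultiHot(new[] { 0, 3 }, numClasses: 5);
Console.WriteLine(string.Join(", ", topics)); // 1, 0, 0, 1, 0
```

Whether a given AiDotNet classifier accepts labels in this shape is not stated in this tutorial, so treat the encoding as an illustration of the concept rather than a required input format.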
Quick Start
using AiDotNet;
using AiDotNet.Classification;
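// Prepare data: one row per sample, one column per feature.
// Only three illustrative rows are shown; 5-fold cross-validation needs
// at least five samples, and in practice many more per class.
var features = new double[][]
{
    new[] { 5.1, 3.5, 1.4, 0.2 }, // Setosa
    new[] { 7.0, 3.2, 4.7, 1.4 }, // Versicolor
    new[] { 6.3, 3.3, 6.0, 2.5 }  // Virginica
};
// Class indices: 0 = Setosa, 1 = Versicolor, 2 = Virginica
var labels = new double[] { 0, 1, 2 };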
// Build and train
var result = await new AiModelBuilder<double, double[], double>()
    .ConfigureModel(new RandomForestClassifier<double>(nEstimators: 100))
    .ConfigurePreprocessing()
    .ConfigureCrossValidation(new KFoldCrossValidator<double>(k: 5))
    .BuildAsync(features, labels);
// Evaluate
Console.WriteLine($"Accuracy: {result.CrossValidationResult?.MeanScore:P2}");
// Predict using result object (facade pattern)
var prediction = result.Predict(new[] { 5.9, 3.0, 5.1, 1.8 });
Console.WriteLine($"Predicted class: {prediction}");
Available Classifiers
Tree-Based Methods
| Classifier | Description | Best For |
|---|---|---|
| RandomForestClassifier | Ensemble of decision trees | General purpose, tabular data |
| GradientBoostingClassifier | Boosted decision trees | High accuracy, Kaggle competitions |
| DecisionTreeClassifier | Single decision tree | Interpretability |
Linear Methods
| Classifier | Description | Best For |
|---|---|---|
| LogisticRegression | Probabilistic linear classifier | Baseline, simple problems |
| SVMClassifier | Support Vector Machine | High-dimensional data |
| RidgeClassifier | L2-regularized linear | Multicollinearity |
Bayesian Methods
| Classifier | Description | Best For |
|---|---|---|
| GaussianNaiveBayes | Continuous features | Fast baseline |
| MultinomialNaiveBayes | Count features | Text classification |
| BernoulliNaiveBayes | Binary features | Document classification |
Distance-Based
| Classifier | Description | Best For |
|---|---|---|
| KNearestNeighbors | Instance-based learning | Small datasets, no training |
Neural Networks
// images and labels are Tensor<float> training data prepared elsewhere.
var result = await new AiModelBuilder<float, Tensor<float>, Tensor<float>>()
    .ConfigureModel(new NeuralNetworkClassifier<float>(
        inputFeatures: 784, // e.g. 28 x 28 images flattened to one vector
        numClasses: 10,
        complexity: NetworkComplexity.Medium))
    .ConfigureOptimizer(new AdamOptimizer<float>(learningRate: 0.001f))
    .ConfigureGpuAcceleration()
    .BuildAsync(images, labels);
Data Preprocessing
Automatic Preprocessing
.ConfigurePreprocessing() // Applies StandardScaler by default
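Standard scaling conventionally means subtracting each feature's mean and dividing by its standard deviation. The following is a minimal sketch of that idea in plain C#, not AiDotNet's implementation: `Standardize` is a hypothetical helper, and the choice of the population standard deviation is an assumption, since the tutorial does not say which variant the library uses.

```csharp
using System;
using System.Linq;

// Standardize one feature column: z = (x - mean) / std.
// Hypothetical helper for illustration; not an AiDotNet API.
static double[] Standardize(double[] column)
{
    double mean = column.Average();
    double variance = column.Select(x => (x - mean) * (x - mean)).Average(); // population variance (assumed)
    double std = Math.Sqrt(variance);
    // A constant column has zero spread; map it to zeros instead of dividing by zero.
    if (std == 0.0) return new double[column.Length];
    return column.Select(x => (x - mean) / std).ToArray();
}

// Sepal lengths from the three Quick Start rows.
double[] scaled = Standardize(new[] { 5.1, 7.0, 6.3 });
Console.WriteLine(string.Join(", ", scaled.Select(v => v.ToString("F3"))));
```

After scaling, the column has mean 0 and unit variance. The mean and standard deviation should be estimated on the training data only and then reused for validation and test data, otherwise information leaks across the split.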
Custom Preprocessing
.ConfigurePreprocessing(new PreprocessingConfig
{
    Scaler = new MinMaxScaler<double>(),
    ImputeStrategy = ImputeStrategy.Mean,
    HandleCategorical = true
})
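To make the two numeric options in this configuration concrete, here is a rough sketch of what mean imputation followed by min-max scaling does to a single feature column. It is plain C# with hypothetical helpers (`ImputeMean`, `MinMaxScale`) and assumes missing values are encoded as `double.NaN`; how AiDotNet represents missing values internally is not covered by this tutorial.

```csharp
using System;
using System.Linq;

// Replace missing values (NaN, by assumption) with the mean of the observed values.
static double[] ImputeMean(double[] column)
{
    var observed = column.Where(x => !double.IsNaN(x)).ToArray();
    if (observed.Length == 0) throw new ArgumentException("Column has no observed values.");
    double mean = observed.Average();
    return column.Select(x => double.IsNaN(x) ? mean : x).ToArray();
}

// Rescale a column to [0, 1]: x' = (x - min) / (max - min).
static double[] MinMaxScale(double[] column)
{
    double min = column.Min();
    double max = column.Max();
    if (max == min) return new double[column.Length]; // constant column maps to zeros
    return column.Select(x => (x - min) / (max - min)).ToArray();
}

double[] raw = { 2.0, double.NaN, 6.0, 4.0 };
// Impute first (NaN becomes 4.0), then scale: 0, 0.5, 1, 0.5
double[] prepared = MinMaxScale(ImputeMean(raw));
Console.WriteLine(string.Join(", ", prepared));
```

The order matters in this sketch: imputing before scaling keeps `NaN` out of the minimum and maximum.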
Cross-Validation
Validate model performance with cross-validation:
.ConfigureCrossValidation(new KFoldCrossValidator<double>(k: 5))
// Or use stratified k-fold instead, which preserves the class distribution in every fold
.ConfigureCrossValidation(new StratifiedKFoldCrossValidator<double>(k: 5))
Access results:
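// Requires: using System.Linq; (for Select)
var cv = result.CrossValidationResult;
if (cv is not null) // guarded because the Quick Start accesses this property with ?.
{
    Console.WriteLine($"Mean: {cv.MeanScore:P2}");
    Console.WriteLine($"Std: {cv.StandardDeviation:P4}");
    Console.WriteLine($"Folds: {string.Join(", ", cv.FoldScores.Select(s => $"{s:P2}"))}");
}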
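The difference between plain and stratified k-fold is easiest to see in how samples are assigned to folds. The sketch below shows one simple way to stratify: shuffle the indices of each class and deal them out round-robin. `StratifiedFolds` is a hypothetical helper written for this illustration and is not taken from AiDotNet.

```csharp
using System;
using System.Linq;

// Assign each sample to one of k folds so that every fold keeps roughly
// the same class proportions as the full dataset.
// Hypothetical helper; not an AiDotNet API.
static int[] StratifiedFolds(double[] labels, int k, int seed = 42)
{
    if (k < 2) throw new ArgumentOutOfRangeException(nameof(k));
    var rng = new Random(seed);
    var foldOf = new int[labels.Length];
    int offset = 0; // carries over between classes to keep fold sizes even
    foreach (var group in Enumerable.Range(0, labels.Length).GroupBy(i => labels[i]))
    {
        // Shuffle the indices of this class, then deal them out round-robin.
        var shuffled = group.OrderBy(_ => rng.Next()).ToArray();
        for (int j = 0; j < shuffled.Length; j++)
            foldOf[shuffled[j]] = (offset + j) % k;
        offset += shuffled.Length;
    }
    return foldOf;
}

// Eight samples of class 0 and two of class 1, split into two folds.
double[] demoLabels = { 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 };
int[] folds = StratifiedFolds(demoLabels, k: 2);
// Each fold receives four samples of class 0 and one of class 1.
Console.WriteLine(string.Join(", ", folds));
```

With plain k-fold on the same data, both minority samples could land in one fold, leaving the other fold with no positives to evaluate on.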
Evaluation Metrics
Confusion Matrix
// Use result.Predict() directly (facade pattern)
var predictions = testSamples.Select(s => result.Predict(s)).ToArray();
var cm = ConfusionMatrix.Compute(predictions, testLabels);
Console.WriteLine($"Accuracy: {cm.Accuracy:P2}");
Console.WriteLine($"Precision: {cm.Precision:P2}");
Console.WriteLine($"Recall: {cm.Recall:P2}");
Console.WriteLine($"F1 Score: {cm.F1Score:P2}");
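For reference, the four metrics printed above all derive from the cells of the confusion matrix. The sketch below computes them by hand for a binary problem; `CountCells` is a hypothetical helper, and how `ConfusionMatrix` averages precision and recall across classes in the multi-class case (for example macro versus micro) is not stated in this tutorial.

```csharp
using System;

// Count the four confusion-matrix cells for a binary problem,
// treating `positive` as the class of interest.
// Hypothetical helper; not an AiDotNet API.
static (int Tp, int Fp, int Fn, int Tn) CountCells(double[] predicted, double[] actual, double positive = 1.0)
{
    if (predicted.Length != actual.Length) throw new ArgumentException("Length mismatch.");
    int tp = 0, fp = 0, fn = 0, tn = 0;
    for (int i = 0; i < actual.Length; i++)
    {
        bool predictedPositive = predicted[i] == positive;
        bool actualPositive = actual[i] == positive;
        if (predictedPositive && actualPositive) tp++;
        else if (predictedPositive) fp++;
        else if (actualPositive) fn++;
        else tn++;
    }
    return (tp, fp, fn, tn);
}

double[] yPred = { 1, 1, 0, 0, 1, 0 };
double[] yTrue = { 1, 0, 0, 1, 1, 0 };
var cells = CountCells(yPred, yTrue);

double accuracy = (double)(cells.Tp + cells.Tn) / yTrue.Length;   // share of correct predictions
double precision = (double)cells.Tp / (cells.Tp + cells.Fp);      // predicted positives that are correct
double recall = (double)cells.Tp / (cells.Tp + cells.Fn);         // actual positives that were found
double f1 = 2 * precision * recall / (precision + recall);        // harmonic mean of the two

Console.WriteLine($"TP={cells.Tp} FP={cells.Fp} FN={cells.Fn} TN={cells.Tn}"); // TP=2 FP=1 FN=1 TN=2
Console.WriteLine($"{accuracy:P2} {precision:P2} {recall:P2} {f1:P2}");
```

Precision and recall become `NaN` in this sketch when their denominators are zero, for example when the model never predicts the positive class.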
ROC Curve (Binary)
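// predictions holds hard class labels here; a ROC curve is more informative when
// computed from predicted scores or probabilities, if the model exposes them.
var rocCurve = ROCCurve.Compute(predictions, testLabels);
Console.WriteLine($"AUC: {rocCurve.AUC:F4}");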
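AUC has a useful interpretation: it is the probability that a randomly chosen positive sample receives a higher score than a randomly chosen negative one. The sketch below computes it directly from that definition. `Auc` is a hypothetical helper, it assumes the positive class is labeled 1, and it compares every positive/negative pair, so it is meant for understanding rather than for large datasets.

```csharp
using System;
using System.Linq;

// AUC as the fraction of positive/negative pairs in which the positive
// sample has the higher score (ties count as one half).
// Hypothetical helper; not an AiDotNet API.
static double Auc(double[] scores, double[] labels)
{
    if (scores.Length != labels.Length) throw new ArgumentException("Length mismatch.");
    var pos = scores.Where((s, i) => labels[i] == 1.0).ToArray();
    var neg = scores.Where((s, i) => labels[i] != 1.0).ToArray();
    if (pos.Length == 0 || neg.Length == 0)
        throw new ArgumentException("Need at least one positive and one negative sample.");
    double wins = 0;
    foreach (var p in pos)
        foreach (var n in neg)
            wins += p > n ? 1.0 : p == n ? 0.5 : 0.0;
    return wins / ((double)pos.Length * neg.Length);
}

double[] demoScores = { 0.9, 0.4, 0.35, 0.8, 0.2 };
double[] demoLabels = { 1, 1, 0, 0, 0 };
// 5 of the 6 positive/negative pairs are ordered correctly.
Console.WriteLine($"AUC: {Auc(demoScores, demoLabels):F4}");
```

An AUC of 0.5 corresponds to random ranking and 1.0 to a perfect ranking of positives above negatives.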
Hyperparameter Tuning
Grid Search
var result = await new AiModelBuilder<double, double[], double>()
    .ConfigureModel(new RandomForestClassifier<double>())
    .ConfigureHyperparameterSearch(new GridSearchConfig
    {
        Parameters = new Dictionary<string, object[]>
        {
            ["nEstimators"] = new object[] { 50, 100, 200 },
            ["maxDepth"] = new object[] { 5, 10, null }
        },
        ScoringMetric = ScoringMetric.Accuracy
    })
    .BuildAsync(features, labels);
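Grid search evaluates every combination of the listed values, so its cost grows multiplicatively with each parameter added. The sketch below enumerates the combinations of a small grid to make that concrete. `ExpandGrid` is a hypothetical helper, not part of AiDotNet, and the grid is a reduced version of the one above.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Enumerate every combination in a parameter grid (Cartesian product).
// Hypothetical helper; not an AiDotNet API.
static List<Dictionary<string, object>> ExpandGrid(Dictionary<string, object[]> grid)
{
    // Start with one empty combination and extend it parameter by parameter.
    var combos = new List<Dictionary<string, object>> { new Dictionary<string, object>() };
    foreach (var entry in grid)
    {
        combos = combos
            .SelectMany(partial => entry.Value.Select(value =>
                new Dictionary<string, object>(partial) { [entry.Key] = value }))
            .ToList();
    }
    return combos;
}

var demoGrid = new Dictionary<string, object[]>
{
    ["nEstimators"] = new object[] { 50, 100, 200 },
    ["maxDepth"] = new object[] { 5, 10 }
};
var candidates = ExpandGrid(demoGrid);
Console.WriteLine(candidates.Count); // 6 (3 values x 2 values)
```

Combined with 5-fold cross-validation, six candidates would mean 30 model fits, which is why the grid is usually kept small or replaced by a budgeted search.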
AutoML
.ConfigureAutoML(new AutoMLConfig<double>
{
    MaxTrials = 50,
    TimeoutMinutes = 30,
    Metric = AutoMLMetric.F1Score
})
Best Practices
- Start Simple: Use Logistic Regression as a baseline
- Check Class Balance: Use stratified sampling for imbalanced data
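- Feature Scaling: Distance-based and gradient-trained models (KNN, SVM, logistic regression, neural networks) benefit from scaled features; tree-based models are largely insensitive to scaling
- Cross-Validate: Prefer cross-validation to a single train/test split, especially on small datasets, and keep a separate held-out test set for the final estimate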
- Monitor Overfitting: Compare train vs validation accuracy
Common Issues
Imbalanced Classes
.ConfigureClassWeights(ClassWeightStrategy.Balanced)
// Or custom weights
.ConfigureClassWeights(new Dictionary<int, double> { [0] = 1.0, [1] = 10.0 })
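A common definition of "balanced" weights, used for example by scikit-learn, is `totalSamples / (numClasses * samplesInClass)`, which makes each class contribute equally to the loss. Whether `ClassWeightStrategy.Balanced` uses exactly this formula is an assumption here. The sketch below, with the hypothetical helper `BalancedWeights`, shows how such weights could be computed and passed as custom weights.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// weight(c) = totalSamples / (numClasses * count(c)).
// Hypothetical helper; the formula is a common heuristic, not confirmed AiDotNet behavior.
static Dictionary<int, double> BalancedWeights(int[] labels)
{
    var counts = labels.GroupBy(l => l).ToDictionary(g => g.Key, g => g.Count());
    return counts.ToDictionary(
        kv => kv.Key,
        kv => (double)labels.Length / (counts.Count * kv.Value));
}

// 90 negatives and 10 positives.
int[] y = Enumerable.Repeat(0, 90).Concat(Enumerable.Repeat(1, 10)).ToArray();
var weights = BalancedWeights(y);
// The minority class ends up with nine times the weight of the majority class.
Console.WriteLine($"{weights[0]:F3}, {weights[1]:F3}");
```

The resulting dictionary has the same `Dictionary<int, double>` shape as the custom weights shown above.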
Overfitting
- Reduce model complexity
- Add regularization
- Increase training data
- Use dropout (for neural networks)
Underfitting
- Increase model complexity
- Add more features
- Reduce regularization