
Class AccuracyBasedAdaptiveStrategy<T>

Namespace
AiDotNet.KnowledgeDistillation.Strategies
Assembly
AiDotNet.dll

Adaptive distillation strategy that adjusts temperature based on student accuracy.

public class AccuracyBasedAdaptiveStrategy<T> : AdaptiveDistillationStrategyBase<T>, IDistillationStrategy<T>, IAdaptiveDistillationStrategy<T>

Type Parameters

T

The numeric type for calculations (e.g., double, float).

Inheritance
AccuracyBasedAdaptiveStrategy<T>

Remarks

For Beginners: This strategy tracks whether the student makes correct predictions and adjusts the temperature accordingly. When the student is correct, it uses a lower temperature to reinforce what was learned. When the student is incorrect, it uses a higher temperature, which provides softer, more exploratory targets.

Intuition:
- Correct prediction → Student learned this well → Lower temperature (reinforce)
- Incorrect prediction → Student struggling → Higher temperature (help learn)

Example:
True label: [0, 1, 0] (class 1)
Student predicts: [0.1, 0.8, 0.1] → Correct! → Low temperature
Student predicts: [0.6, 0.3, 0.1] → Wrong! → High temperature
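To make "softer targets" concrete, the sketch below applies the standard temperature-scaled softmax to a set of logits. It is an illustration only: it uses plain double[] arrays rather than the library's Vector<T>, and the class and method names (TemperatureSoftmaxSketch, Softmax) are hypothetical, not part of AiDotNet.

```csharp
using System;
using System.Linq;

// Hypothetical helper for illustration; not part of AiDotNet.
public static class TemperatureSoftmaxSketch
{
    // Softmax over logits divided by the temperature.
    // A larger temperature flattens the resulting distribution.
    public static double[] Softmax(double[] logits, double temperature)
    {
        if (temperature <= 0)
            throw new ArgumentOutOfRangeException(nameof(temperature));

        double[] scaled = logits.Select(z => z / temperature).ToArray();
        double max = scaled.Max(); // subtract the max for numerical stability
        double[] exps = scaled.Select(z => Math.Exp(z - max)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }
}
```

Calling Softmax on the same logits with temperature 1.0 and then 5.0 should give a lower peak probability in the second case, which is the softening effect the strategy relies on for samples the student gets wrong.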

Best For:
- Supervised learning with labeled data
- When you want to focus more on difficult samples
- Tracking which samples the student struggles with

Requirements: True labels must be passed in ComputeLoss/ComputeGradient calls. Without labels, the strategy falls back to confidence-based adaptation.

Performance Tracking: Uses an exponential moving average (EMA) of correctness:
- 1.0 = consistently correct
- 0.0 = consistently incorrect

Temperature decreases linearly as performance increases.
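The following sketch shows one common form of an EMA update for a per-sample correctness score. The exact formula and starting value used by the library are not stated on this page, so the update rule and the names below (EmaSketch, Update) are assumptions for illustration.

```csharp
using System;

// Hypothetical helper for illustration; not part of AiDotNet.
public static class EmaSketch
{
    // Blends the previous average with the new observation.
    // A rate closer to 1.0 makes the average react faster to new observations.
    public static double Update(double previous, double observation, double rate)
    {
        if (rate < 0.0 || rate > 1.0)
            throw new ArgumentOutOfRangeException(nameof(rate));

        return (1.0 - rate) * previous + rate * observation;
    }
}
```

With a rate of 0.1, a sample whose running score is 0.5 moves to roughly 0.55 after one correct prediction (observation 1.0) and to roughly 0.45 after one incorrect prediction (observation 0.0). Repeated correct predictions push the score toward 1.0.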

Constructors

AccuracyBasedAdaptiveStrategy(double, double, double, double, double)

Initializes a new instance of the AccuracyBasedAdaptiveStrategy class.

public AccuracyBasedAdaptiveStrategy(double baseTemperature = 3, double alpha = 0.3, double minTemperature = 1, double maxTemperature = 5, double adaptationRate = 0.1)

Parameters

baseTemperature double

Base temperature for distillation (default: 3.0).

alpha double

Balance between hard and soft loss (default: 0.3).

minTemperature double

Minimum temperature (for correct predictions, default: 1.0).

maxTemperature double

Maximum temperature (for incorrect predictions, default: 5.0).

adaptationRate double

EMA rate for performance tracking (default: 0.1).

Remarks

For Beginners: This strategy requires true labels during training. Make sure to pass labels to ComputeLoss() and ComputeGradient().

Example:

var strategy = new AccuracyBasedAdaptiveStrategy<double>(
    minTemperature: 1.5,  // For samples student gets right
    maxTemperature: 6.0,  // For samples student gets wrong
    adaptationRate: 0.2   // How fast to adapt (higher = faster)
);

for (int i = 0; i < samples.Length; i++)
{
    var teacherLogits = teacher.GetLogits(samples[i]);
    var studentLogits = student.Predict(samples[i]);

    // IMPORTANT: Pass labels for accuracy tracking
    var loss = strategy.ComputeLoss(studentLogits, teacherLogits, labels[i]);
    strategy.UpdatePerformance(i, studentLogits, labels[i]);
}

Methods

ComputeAdaptiveTemperature(Vector<T>, Vector<T>)

Computes adaptive temperature based on student accuracy.

public override double ComputeAdaptiveTemperature(Vector<T> studentOutput, Vector<T> teacherOutput)

Parameters

studentOutput Vector<T>

Student's output logits.

teacherOutput Vector<T>

Teacher's output logits (not used by the accuracy-based strategy).

Returns

double

Adapted temperature based on historical accuracy.

Remarks

Algorithm:
1. Get historical performance for this sample (0.0 to 1.0)
2. If no history, use current confidence
3. Compute difficulty = 1 - performance
4. Map to temperature: temp = min + difficulty * (max - min)

This creates adaptive behavior:
- High performance (0.8) → Low difficulty (0.2) → Lower temperature
- Low performance (0.3) → High difficulty (0.7) → Higher temperature

Note: This uses historical performance (EMA), not current prediction. Call UpdatePerformance() regularly to keep tracking updated.
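Steps 3 and 4 of the algorithm can be sketched as a small standalone function. This is a minimal illustration, not the library's implementation: the names (TemperatureMappingSketch, MapToTemperature) are hypothetical, and clamping the performance to the 0.0 to 1.0 range is an added assumption.

```csharp
using System;

// Hypothetical helper for illustration; not part of AiDotNet.
public static class TemperatureMappingSketch
{
    // Maps a performance score in [0, 1] to a temperature in [min, max].
    public static double MapToTemperature(
        double performance, double minTemperature, double maxTemperature)
    {
        // Assumption: out-of-range scores are clamped rather than rejected.
        double clamped = Math.Max(0.0, Math.Min(1.0, performance));
        double difficulty = 1.0 - clamped;
        return minTemperature + difficulty * (maxTemperature - minTemperature);
    }
}
```

With the constructor defaults (minTemperature 1, maxTemperature 5), a performance of 0.8 maps to a temperature of about 1.8, and a performance of 0.3 maps to about 3.8. A performance of 1.0 yields the minimum temperature and 0.0 yields the maximum.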

ComputePerformance(Vector<T>, Vector<T>?)

Computes performance based on prediction correctness.

protected override double ComputePerformance(Vector<T> studentOutput, Vector<T>? trueLabel)

Parameters

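studentOutput Vector<T>

Student's output logits.

trueLabel Vector<T>

True label for the sample (for example, a one-hot vector such as [0, 1, 0]); may be null.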

Returns

double

Remarks

Returns 1.0 if the prediction is correct and 0.0 if it is incorrect. This value is tracked with an EMA to estimate the average accuracy per sample.
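A correctness score of this kind is often computed by comparing the index of the largest student output with the index of the largest label entry. The sketch below assumes that interpretation and a one-hot label. It uses plain double[] arrays instead of Vector<T>, and the names (CorrectnessSketch, ArgMax, Score) are hypothetical.

```csharp
using System;

// Hypothetical helper for illustration; not part of AiDotNet.
public static class CorrectnessSketch
{
    // Index of the largest element; the first index wins on ties.
    public static int ArgMax(double[] values)
    {
        if (values.Length == 0)
            throw new ArgumentException("Input must not be empty.", nameof(values));

        int best = 0;
        for (int i = 1; i < values.Length; i++)
        {
            if (values[i] > values[best]) best = i;
        }
        return best;
    }

    // 1.0 when the predicted class matches the labeled class, otherwise 0.0.
    public static double Score(double[] studentOutput, double[] trueLabel)
    {
        return ArgMax(studentOutput) == ArgMax(trueLabel) ? 1.0 : 0.0;
    }
}
```

For the label [0, 1, 0], the output [0.1, 0.8, 0.1] scores 1.0 and the output [0.6, 0.3, 0.1] scores 0.0, matching the example in the class remarks.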