Class AccuracyBasedAdaptiveStrategy<T>
- Namespace: AiDotNet.KnowledgeDistillation.Strategies
- Assembly: AiDotNet.dll
Adaptive distillation strategy that adjusts temperature based on student accuracy.
public class AccuracyBasedAdaptiveStrategy<T> : AdaptiveDistillationStrategyBase<T>, IDistillationStrategy<T>, IAdaptiveDistillationStrategy<T>
Type Parameters
T: The numeric type for calculations (e.g., double, float).
- Inheritance: AdaptiveDistillationStrategyBase<T> → AccuracyBasedAdaptiveStrategy<T>
- Implements: IDistillationStrategy<T>, IAdaptiveDistillationStrategy<T>
Remarks
For Beginners: This strategy tracks whether the student is making correct predictions and adjusts temperature accordingly. When the student is correct, we use lower temperature (reinforce learning). When incorrect, we use higher temperature (provide softer, more exploratory targets).
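To make "softer targets" concrete, the following is a minimal sketch of a temperature-scaled softmax. It is not taken from AiDotNet: the class name TemperatureSoftmaxSketch, the Softmax helper, and the use of plain double[] instead of Vector<T> are all assumptions made for illustration.

```csharp
using System;
using System.Linq;

// Hypothetical helper for illustration only; not part of the AiDotNet API.
public static class TemperatureSoftmaxSketch
{
    // Softmax over logits divided by a temperature. Temperature must be positive.
    public static double[] Softmax(double[] logits, double temperature)
    {
        if (temperature <= 0)
            throw new ArgumentOutOfRangeException(nameof(temperature));

        double[] scaled = logits.Select(z => z / temperature).ToArray();
        double max = scaled.Max(); // subtract the max for numerical stability
        double[] exps = scaled.Select(z => Math.Exp(z - max)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }
}
```

For logits [2.0, 1.0, 0.1], the top probability is about 0.66 at temperature 1 and about 0.40 at temperature 5, so a higher temperature spreads probability mass more evenly across classes.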
Intuition:
- Correct prediction → student learned this well → lower temperature (reinforce)
- Incorrect prediction → student is struggling → higher temperature (help it learn)
Example:
True label: [0, 1, 0] (class 1)
Student predicts: [0.1, 0.8, 0.1] → Correct! → Low temperature
Student predicts: [0.6, 0.3, 0.1] → Wrong! → High temperature
Best For:
- Supervised learning with labeled data
- When you want to focus more on difficult samples
- Tracking which samples the student struggles with
Requirements: Requires true labels to be provided in ComputeLoss/ComputeGradient calls. Without labels, falls back to confidence-based adaptation.
Performance Tracking: Uses an exponential moving average (EMA) of correctness:
- 1.0 = consistently correct
- 0.0 = consistently incorrect
Temperature falls linearly from maxTemperature to minTemperature as tracked performance rises from 0.0 to 1.0.
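The EMA tracking can be sketched as follows. This is an assumption about the update rule, using the common form new = (1 - rate) * old + rate * observation; the actual formula and starting value used by AdaptiveDistillationStrategyBase<T> are not stated on this page. The class name CorrectnessTracker and the initial value of 0.5 are hypothetical.

```csharp
using System;

// Hypothetical tracker for illustration only; not part of the AiDotNet API.
public sealed class CorrectnessTracker
{
    private readonly double _rate;

    // Current smoothed correctness, between 0.0 and 1.0.
    public double Value { get; private set; }

    // initial = 0.5 is an assumed neutral starting point.
    public CorrectnessTracker(double adaptationRate, double initial = 0.5)
    {
        if (adaptationRate < 0.0 || adaptationRate > 1.0)
            throw new ArgumentOutOfRangeException(nameof(adaptationRate));
        _rate = adaptationRate;
        Value = initial;
    }

    // Blend the newest observation (1.0 = correct, 0.0 = incorrect) into the average.
    public double Update(double correctness)
    {
        Value = (1.0 - _rate) * Value + _rate * correctness;
        return Value;
    }
}
```

With adaptationRate 0.1 and a start of 0.5, two consecutive correct predictions move the value to 0.55 and then 0.595. A larger rate reacts faster to recent predictions but is noisier.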
Constructors
AccuracyBasedAdaptiveStrategy(double, double, double, double, double)
Initializes a new instance of the AccuracyBasedAdaptiveStrategy class.
public AccuracyBasedAdaptiveStrategy(double baseTemperature = 3, double alpha = 0.3, double minTemperature = 1, double maxTemperature = 5, double adaptationRate = 0.1)
Parameters
baseTemperature (double): Base temperature for distillation (default: 3.0).
alpha (double): Balance between hard and soft loss (default: 0.3).
minTemperature (double): Minimum temperature (for correct predictions, default: 1.0).
maxTemperature (double): Maximum temperature (for incorrect predictions, default: 5.0).
adaptationRate (double): EMA rate for performance tracking (default: 0.1).
Remarks
For Beginners: This strategy requires true labels during training. Make sure to pass labels to ComputeLoss() and ComputeGradient().
Example:
var strategy = new AccuracyBasedAdaptiveStrategy<double>(
    minTemperature: 1.5,  // For samples student gets right
    maxTemperature: 6.0,  // For samples student gets wrong
    adaptationRate: 0.2   // How fast to adapt (higher = faster)
);
for (int i = 0; i < samples.Length; i++)
{
    var teacherLogits = teacher.GetLogits(samples[i]);
    var studentLogits = student.Predict(samples[i]);
    // IMPORTANT: Pass labels for accuracy tracking
    var loss = strategy.ComputeLoss(studentLogits, teacherLogits, labels[i]);
    strategy.UpdatePerformance(i, studentLogits, labels[i]);
}
Methods
ComputeAdaptiveTemperature(Vector<T>, Vector<T>)
Computes adaptive temperature based on student accuracy.
public override double ComputeAdaptiveTemperature(Vector<T> studentOutput, Vector<T> teacherOutput)
Parameters
studentOutput (Vector<T>): Student's output logits.
teacherOutput (Vector<T>): Teacher's output logits (not used by this accuracy-based strategy).
Returns
- double
Adapted temperature based on historical accuracy.
Remarks
Algorithm:
1. Get historical performance for this sample (0.0 to 1.0)
2. If no history, use current confidence
3. Compute difficulty = 1 - performance
4. Map to temperature: temp = min + difficulty * (max - min)
This creates adaptive behavior:
- High performance (0.8) → Low difficulty (0.2) → Lower temperature
- Low performance (0.3) → High difficulty (0.7) → Higher temperature
Note: This uses historical performance (EMA), not current prediction. Call UpdatePerformance() regularly to keep tracking updated.
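Steps 3 and 4 of the algorithm above can be sketched as a small function. The formula is the one given in the remarks; the class name TemperatureMappingSketch, the method name, and the clamping of performance to [0, 1] are assumptions added for illustration, and the default bounds mirror the constructor defaults.

```csharp
using System;

// Hypothetical helper for illustration only; not part of the AiDotNet API.
public static class TemperatureMappingSketch
{
    // temp = min + (1 - performance) * (max - min)
    public static double MapToTemperature(
        double performance,
        double minTemperature = 1.0,
        double maxTemperature = 5.0)
    {
        // Clamping is an assumed safeguard, not documented behavior.
        double clamped = Math.Max(0.0, Math.Min(1.0, performance));
        double difficulty = 1.0 - clamped;
        return minTemperature + difficulty * (maxTemperature - minTemperature);
    }
}
```

With the default bounds of 1.0 and 5.0, a performance of 0.8 maps to a temperature of about 1.8, and a performance of 0.3 maps to about 3.8.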
ComputePerformance(Vector<T>, Vector<T>?)
Computes performance based on prediction correctness.
protected override double ComputePerformance(Vector<T> studentOutput, Vector<T>? trueLabel)
Parameters
studentOutput (Vector<T>)
trueLabel (Vector<T>)
Returns
- double
Remarks
Returns 1.0 if prediction is correct, 0.0 if incorrect. This is tracked with EMA to get average accuracy per sample.
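One plausible way to decide correctness is to compare the index of the largest student output with the index of the largest label entry. The sketch below assumes that argmax rule and a one-hot label; the page does not state how the library performs the comparison. CorrectnessSketch and its methods are hypothetical names, and double[] stands in for Vector<T>.

```csharp
using System;

// Hypothetical helper for illustration only; not part of the AiDotNet API.
public static class CorrectnessSketch
{
    // Index of the largest element (first one wins on ties).
    public static int ArgMax(double[] values)
    {
        if (values == null || values.Length == 0)
            throw new ArgumentException("values must be non-empty", nameof(values));

        int best = 0;
        for (int i = 1; i < values.Length; i++)
        {
            if (values[i] > values[best]) best = i;
        }
        return best;
    }

    // 1.0 when the predicted class matches the labeled class, otherwise 0.0.
    public static double Correctness(double[] studentOutput, double[] trueLabel)
    {
        return ArgMax(studentOutput) == ArgMax(trueLabel) ? 1.0 : 0.0;
    }
}
```

Using the vectors from the class remarks, a student output of [0.1, 0.8, 0.1] against the label [0, 1, 0] scores 1.0, while [0.6, 0.3, 0.1] scores 0.0.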