Class SelfDistillationTrainer<T>
- Namespace
- AiDotNet.KnowledgeDistillation
- Assembly
- AiDotNet.dll
Implements self-distillation where a model acts as its own teacher to improve calibration and generalization.
public class SelfDistillationTrainer<T> : KnowledgeDistillationTrainerBase<T, Vector<T>, Vector<T>>, IKnowledgeDistillationTrainer<T, Vector<T>, Vector<T>>
Type Parameters
T: The numeric type for calculations (e.g., double, float).
- Inheritance
  - object → KnowledgeDistillationTrainerBase<T, Vector<T>, Vector<T>> → SelfDistillationTrainer<T>
- Implements
  - IKnowledgeDistillationTrainer<T, Vector<T>, Vector<T>>
- Inherited Members
Remarks
For Beginners: Self-distillation is a clever technique where a model learns from itself! Instead of using a separate larger teacher, you train a model normally, then use it as a teacher to train itself again. This often improves:
- **Calibration**: Model confidence matches actual accuracy
- **Generalization**: Better performance on unseen data
- **Robustness**: Less sensitive to noisy labels or adversarial examples
How It Works:
1. Train the model normally on hard labels (standard training)
2. Save the trained model's predictions
3. Retrain the model using its own soft predictions as the teacher
4. Repeat for multiple generations if desired
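The loop below is a minimal, framework-agnostic sketch of those four steps in C#. The trainStep and predict delegates are hypothetical stand-ins for whatever model you train, and the blending of soft teacher targets with hard labels (the job of the distillation strategy) is omitted for brevity.
// Minimal sketch of the self-distillation loop; trainStep and predict are
// hypothetical delegates, not part of the AiDotNet API.
using System;
using System.Collections.Generic;

static class SelfDistillationLoopSketch
{
    public static void Run(
        Action<double[], double[]> trainStep,   // trains the model on one (input, target) pair
        Func<double[], double[]> predict,       // returns the model's current soft predictions
        IList<double[]> inputs,
        IList<double[]> hardLabels,
        int generations)
    {
        // Generation 0: the targets are simply the hard labels.
        IList<double[]> targets = hardLabels;

        for (int generation = 0; generation < generations; generation++)
        {
            // Steps 1 and 3: train against the current targets.
            for (int i = 0; i < inputs.Count; i++)
                trainStep(inputs[i], targets[i]);

            // Step 2: cache this generation's predictions; they become the
            // soft teacher targets for the next generation (step 4).
            var cached = new List<double[]>(inputs.Count);
            foreach (var x in inputs)
                cached.Add(predict(x));
            targets = cached;
        }
    }
}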
Real-world Analogy: Imagine studying for an exam, then teaching the material to yourself as if you were a student. By explaining concepts in your own words, you deepen your understanding and identify gaps in your knowledge. Self-distillation works similarly for neural networks.
Variants:
- **Iterative Self-Distillation**: Multiple rounds of self-teaching
- **Born-Again Networks**: Same architecture, trained from scratch with self as teacher
- **Online Self-Distillation**: Student learns from earlier checkpoints of itself
Benefits:
- No need for a separate teacher model
- Improves calibration without model compression
- Can be combined with data augmentation for better regularization
- Often provides 1-3% accuracy improvement for free
When to Use:
- You want better calibrated predictions
- You have limited model capacity (can't afford a larger teacher)
- You want to improve an existing trained model
- You're training on noisy or imperfect labels
References:
- Furlanello, T., et al. (2018). Born Again Neural Networks. ICML.
- Zhang, L., et al. (2019). Be Your Own Teacher: Improve the Performance of Convolutional Neural Networks via Self-Distillation. ICCV.
Constructors
SelfDistillationTrainer(IDistillationStrategy<T>, int, int?)
Initializes a new instance of the SelfDistillationTrainer class.
public SelfDistillationTrainer(IDistillationStrategy<T> distillationStrategy, int generations = 1, int? seed = null)
Parameters
distillationStrategy (IDistillationStrategy<T>): The strategy for computing distillation loss.
generations (int): Number of self-distillation generations (default 1). More generations can improve performance but take longer to train.
seed (int?): Optional random seed for reproducibility.
Remarks
For Beginners: Generations control how many times the model relearns from itself:
- 1 generation: Train normally (standard training, no self-distillation)
- 2 generations: Train, then retrain using self as teacher (first self-distillation)
- 3 generations: Train → self-teach → self-teach again
- More generations: Diminishing returns, usually not worth it beyond 2-3
Example:
var distillationLoss = new DistillationLoss<double>(temperature: 3.0, alpha: 0.5);
var selfTrainer = new SelfDistillationTrainer<double>(distillationLoss, generations: 2);
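For context, temperature and alpha here play the roles they do in the standard distillation loss (Hinton-style): temperature softens both probability distributions, and alpha blends the soft teacher term with the hard-label cross-entropy. The sketch below computes that standard combination for a single sample; it is illustrative only and is not a description of AiDotNet's DistillationLoss internals.
// Standard single-sample distillation loss (illustrative sketch only; not the
// internals of AiDotNet's DistillationLoss class).
using System;
using System.Linq;

static class DistillationLossSketch
{
    static double[] Softmax(double[] logits, double temperature)
    {
        double max = logits.Max();                                   // for numerical stability
        var exp = logits.Select(z => Math.Exp((z - max) / temperature)).ToArray();
        double sum = exp.Sum();
        return exp.Select(e => e / sum).ToArray();
    }

    public static double Compute(
        double[] studentLogits, double[] teacherLogits, int trueLabel,
        double temperature = 3.0, double alpha = 0.5)
    {
        var pTeacher = Softmax(teacherLogits, temperature);          // softened teacher distribution
        var pStudent = Softmax(studentLogits, temperature);          // softened student distribution

        // Soft term: KL(teacher || student), scaled by T^2 so its gradients
        // stay on the same scale as the hard-label term.
        double kl = 0.0;
        for (int i = 0; i < pTeacher.Length; i++)
            kl += pTeacher[i] * Math.Log(pTeacher[i] / pStudent[i]);

        // Hard term: cross-entropy between the student (T = 1) and the true label.
        double ce = -Math.Log(Softmax(studentLogits, 1.0)[trueLabel]);

        return alpha * temperature * temperature * kl + (1 - alpha) * ce;
    }
}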
Properties
EMADecay
Gets or sets the EMA decay rate (default 0.99). Higher values give more weight to history.
public double EMADecay { get; set; }
Property Value
- double
Exceptions
- ArgumentOutOfRangeException
Thrown when value is not between 0 and 1.
UseEMA
Gets or sets whether to use exponential moving average for teacher predictions.
public bool UseEMA { get; set; }
Property Value
- bool
Remarks
For Beginners: EMA smooths out the teacher's predictions over time, making them more stable and reliable. This can improve training stability.
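The smoothing itself is an exponential moving average of the cached teacher targets. The sketch below shows the per-element update rule under that assumption; the method and array names are hypothetical and this is not the trainer's internal code.
// Per-element EMA update for a cached teacher target (illustrative sketch only).
static void UpdateEmaTarget(double[] emaTarget, double[] currentPrediction, double decay = 0.99)
{
    for (int i = 0; i < emaTarget.Length; i++)
        emaTarget[i] = decay * emaTarget[i] + (1 - decay) * currentPrediction[i];
}
Turning the behavior on only involves the two properties documented here, for example:
selfTrainer.UseEMA = true;
selfTrainer.EMADecay = 0.99;   // must be between 0 and 1 (see the exception above)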
Methods
GetTeacherPredictions(Vector<T>, int)
Gets teacher predictions from the cached predictions dictionary (for self-distillation).
protected override Vector<T> GetTeacherPredictions(Vector<T> input, int index)
Parameters
input (Vector<T>): The input data to look up cached predictions for.
index (int): The index in the training batch (unused - we use input for lookup).
Returns
- Vector<T>
Cached teacher prediction for this input.
Remarks
For Self-Distillation: Instead of calling a separate teacher model, we return predictions that were cached from the previous generation. We use the input itself as the key (via reference equality) to handle shuffled batches correctly.
Generation 0 Handling: When no cached predictions exist (first generation), we use the student's own predictions as the teacher. This makes distillation a no-op for generation 0, effectively training normally. This avoids dimension mismatches since the placeholder teacher has OutputDimension = 0.
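A minimal sketch of that caching scheme is shown below, using plain arrays, a reference-keyed dictionary, and hypothetical member names. It assumes .NET 5+ for ReferenceEqualityComparer and is not the trainer's actual implementation.
// Illustrative sketch of the cached-teacher lookup described above; names and
// layout are hypothetical, not the real fields of SelfDistillationTrainer.
using System;
using System.Collections.Generic;

class CachedTeacherSketch
{
    // Keyed by object reference so shuffled batches still find their cached entry.
    private readonly Dictionary<double[], double[]> _cachedTeacherOutputs =
        new(ReferenceEqualityComparer.Instance);

    public void CacheGeneration(IEnumerable<double[]> inputs, Func<double[], double[]> modelForward)
    {
        // Called once per generation: snapshot the model's current predictions.
        foreach (var x in inputs)
            _cachedTeacherOutputs[x] = modelForward(x);
    }

    public double[] GetTeacherPredictions(double[] input, Func<double[], double[]> studentForward)
    {
        // Generation 0: nothing is cached yet, so fall back to the student's own
        // prediction, which turns the distillation term into a no-op.
        return _cachedTeacherOutputs.TryGetValue(input, out var cached)
            ? cached
            : studentForward(input);
    }
}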
TrainMultipleGenerations(Func<Vector<T>, Vector<T>>, Action<Vector<T>>, Vector<Vector<T>>, Vector<Vector<T>>, int, int, Action<int, T>?)
Performs self-distillation training for the specified number of generations.
public void TrainMultipleGenerations(Func<Vector<T>, Vector<T>> modelForward, Action<Vector<T>> modelBackward, Vector<Vector<T>> trainInputs, Vector<Vector<T>> trainLabels, int epochs, int batchSize = 32, Action<int, T>? onGenerationComplete = null)
Parameters
modelForward (Func<Vector<T>, Vector<T>>): Function to perform the forward pass and get logits.
modelBackward (Action<Vector<T>>): Function to perform the backward pass with gradients.
trainInputs (Vector<Vector<T>>): Training input data.
trainLabels (Vector<Vector<T>>): Training labels.
epochs (int): Number of epochs per generation.
batchSize (int): Batch size for training.
onGenerationComplete (Action<int, T>?): Optional callback invoked after each generation with (generation, avgLoss).
Remarks
For Beginners: This method runs the complete self-distillation process:
1. **Generation 0**: Train the model normally (if starting from scratch)
2. **Generation 1**: Retrain using self as teacher
3. **Generation 2+**: Continue if requested
Each generation:
- Saves the current model's predictions as the "teacher"
- Retrains the model to match both the teacher predictions and the true labels
- Typically sees a 0.5-2% improvement per generation
Training Tips:
- Use temperature 2-4 (lower than standard distillation)
- Set alpha = 0.5 (equal weight to self and labels)
- Train for fewer epochs in later generations (e.g., half as many as the first generation)
- Watch for overfitting in later generations
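For reference, a usage sketch of this method is shown below. Here myModel and its Forward/ApplyGradient methods are hypothetical placeholders for whatever network you are training, and trainInputs/trainLabels are assumed to be prepared Vector<Vector<double>> data; the trainer and loss types come from the examples above.
// Usage sketch; myModel, Forward, and ApplyGradient are hypothetical placeholders.
var distillationLoss = new DistillationLoss<double>(temperature: 3.0, alpha: 0.5);
var selfTrainer = new SelfDistillationTrainer<double>(distillationLoss, generations: 2);

selfTrainer.TrainMultipleGenerations(
    modelForward: input => myModel.Forward(input),               // forward pass returning logits
    modelBackward: gradient => myModel.ApplyGradient(gradient),  // backward pass with the given gradients
    trainInputs: trainInputs,
    trainLabels: trainLabels,
    epochs: 10,
    batchSize: 32,
    onGenerationComplete: (generation, avgLoss) =>
        Console.WriteLine($"Generation {generation}: average loss = {avgLoss}"));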