
Class SelfDistillationTrainer<T>

Namespace: AiDotNet.KnowledgeDistillation
Assembly: AiDotNet.dll

Implements self-distillation where a model acts as its own teacher to improve calibration and generalization.

public class SelfDistillationTrainer<T> : KnowledgeDistillationTrainerBase<T, Vector<T>, Vector<T>>, IKnowledgeDistillationTrainer<T, Vector<T>, Vector<T>>

Type Parameters

T

The numeric type for calculations (e.g., double, float).

Inheritance
KnowledgeDistillationTrainerBase<T, Vector<T>, Vector<T>>
SelfDistillationTrainer<T>
Implements
IKnowledgeDistillationTrainer<T, Vector<T>, Vector<T>>

Remarks

For Beginners: Self-distillation is a clever technique where a model learns from itself! Instead of using a separate larger teacher, you train a model normally, then use it as a teacher to train itself again. This often improves:

- **Calibration**: Model confidence matches actual accuracy
- **Generalization**: Better performance on unseen data
- **Robustness**: Less sensitive to noisy labels or adversarial examples

How It Works:

1. Train the model normally on hard labels (standard training)
2. Save the trained model's predictions
3. Retrain the model using its own soft predictions as the teacher
4. Repeat for multiple generations if desired
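
To make the cycle concrete, the following is a small, self-contained sketch using a toy linear softmax classifier. It is illustrative only and does not reflect AiDotNet's internals: in this library the IDistillationStrategy computes the distillation loss, and a full objective would also apply temperature scaling to the student's side of the soft-target term.

using System;
using System.Linq;

// Toy, self-contained sketch of the self-distillation cycle (not the AiDotNet internals).
var rng = new Random(42);
int features = 4, classes = 3, samples = 32, generations = 2;
double alpha = 0.5, temperature = 3.0, learningRate = 0.1;

// Random toy data: inputs in [0, 1) with random integer class labels.
var inputs = new double[samples][];
var labels = new int[samples];
for (int i = 0; i < samples; i++)
{
    inputs[i] = new double[features];
    for (int f = 0; f < features; f++) inputs[i][f] = rng.NextDouble();
    labels[i] = rng.Next(classes);
}

var weights = new double[features, classes];
double[][]? teacherProbs = null;                       // no cached teacher in generation 0

for (int generation = 0; generation < generations; generation++)
{
    for (int epoch = 0; epoch < 50; epoch++)
        for (int i = 0; i < samples; i++)
        {
            double[] probs = Softmax(Forward(inputs[i]), 1.0);
            for (int c = 0; c < classes; c++)
            {
                // Blend the hard label with the cached soft prediction (if any);
                // the cross-entropy gradient w.r.t. the logits is probs - target.
                double hard = labels[i] == c ? 1.0 : 0.0;
                double target = teacherProbs == null
                    ? hard
                    : alpha * hard + (1 - alpha) * teacherProbs[i][c];
                double grad = probs[c] - target;
                for (int f = 0; f < features; f++)
                    weights[f, c] -= learningRate * grad * inputs[i][f];
            }
        }

    // Cache this generation's softened predictions; they act as the teacher next time.
    teacherProbs = inputs.Select(x => Softmax(Forward(x), temperature)).ToArray();
}

double[] Forward(double[] x)
{
    var logits = new double[classes];
    for (int c = 0; c < classes; c++)
        for (int f = 0; f < features; f++)
            logits[c] += x[f] * weights[f, c];
    return logits;
}

double[] Softmax(double[] logits, double t)
{
    double max = logits.Max();
    var exp = logits.Select(z => Math.Exp((z - max) / t)).ToArray();
    double sum = exp.Sum();
    return exp.Select(e => e / sum).ToArray();
}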

Real-world Analogy: Imagine studying for an exam, then teaching the material to yourself as if you were a student. By explaining concepts in your own words, you deepen your understanding and identify gaps in your knowledge. Self-distillation works similarly for neural networks.

Variants:

- **Iterative Self-Distillation**: Multiple rounds of self-teaching
- **Born-Again Networks**: Same architecture, trained from scratch with self as teacher
- **Online Self-Distillation**: Student learns from earlier checkpoints of itself

Benefits:

- No need for a separate teacher model
- Improves calibration without model compression
- Can be combined with data augmentation for better regularization
- Often provides a 1-3% accuracy improvement for free

When to Use:

- You want better calibrated predictions
- You have limited model capacity (can't afford a larger teacher)
- You want to improve an existing trained model
- You're training on noisy or imperfect labels

References:

- Furlanello, T., et al. (2018). Born Again Neural Networks. ICML.
- Zhang, L., et al. (2019). Be Your Own Teacher: Improve the Performance of Convolutional Neural Networks via Self-Distillation. ICCV.

Constructors

SelfDistillationTrainer(IDistillationStrategy<T>, int, int?)

Initializes a new instance of the SelfDistillationTrainer class.

public SelfDistillationTrainer(IDistillationStrategy<T> distillationStrategy, int generations = 1, int? seed = null)

Parameters

distillationStrategy IDistillationStrategy<T>

The strategy for computing distillation loss.

generations int

Number of self-distillation generations (default 1). More generations can improve performance but take longer to train.

seed int?

Optional random seed for reproducibility.

Remarks

For Beginners: Generations control how many times the model relearns from itself:

- 1 generation: Train normally (standard training, no self-distillation)
- 2 generations: Train, then retrain using self as teacher (first self-distillation)
- 3 generations: Train → self-teach → self-teach again
- More generations: Diminishing returns, usually not worth it beyond 2-3

Example:

// Distillation strategy: temperature softens the targets, alpha balances the
// soft-target loss against the hard-label loss.
var distillationLoss = new DistillationLoss<double>(temperature: 3.0, alpha: 0.5);

// generations: 2 => train once, then retrain once using the model's own predictions as teacher.
var selfTrainer = new SelfDistillationTrainer<double>(distillationLoss, generations: 2);

Properties

EMADecay

Gets or sets the EMA decay rate (default 0.99). Higher values give more weight to history.

public double EMADecay { get; set; }

Property Value

double

Exceptions

ArgumentOutOfRangeException

Thrown when the value is not between 0 and 1.

UseEMA

Gets or sets whether to use exponential moving average for teacher predictions.

public bool UseEMA { get; set; }

Property Value

bool

Remarks

For Beginners: EMA smooths out the teacher's predictions over time, making them more stable and reliable. This can improve training stability.
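
For example, continuing from the constructor example above (the update rule in the comment is the standard exponential-moving-average form, shown here as an assumption; the trainer's exact internal formula is not documented on this page):

// Smooth the cached teacher targets across updates. With a standard EMA update the
// cached prediction would evolve roughly as
//   teacher = EMADecay * teacher + (1 - EMADecay) * currentPrediction
// (assumed form, for illustration). EMADecay must be between 0 and 1.
selfTrainer.UseEMA = true;
selfTrainer.EMADecay = 0.99;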

Methods

GetTeacherPredictions(Vector<T>, int)

Gets teacher predictions from the cached predictions dictionary (for self-distillation).

protected override Vector<T> GetTeacherPredictions(Vector<T> input, int index)

Parameters

input Vector<T>

The input data to look up cached predictions for.

index int

The index in the training batch (unused - we use input for lookup).

Returns

Vector<T>

Cached teacher prediction for this input.

Remarks

For Self-Distillation: Instead of calling a separate teacher model, we return predictions that were cached from the previous generation. We use the input itself as the key (via reference equality) to handle shuffled batches correctly.

Generation 0 Handling: When no cached predictions exist (first generation), we use the student's own predictions as the teacher. This makes distillation a no-op for generation 0, effectively training normally. This avoids dimension mismatches since the placeholder teacher has OutputDimension = 0.
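
The snippet below sketches the reference-equality cache idea in isolation. It is not the library's code: ReferenceEqualityComparer requires .NET 5 or later, and StudentForward is a hypothetical stand-in for the student's forward pass.

using System.Collections.Generic;

// Keys are compared by reference, so the same input object always maps back to the
// prediction cached for it, even after batches are shuffled.
var cache = new Dictionary<object, double[]>(ReferenceEqualityComparer.Instance);

double[] input = { 0.2, 0.8 };
cache[input] = new[] { 0.1, 0.9 };                 // prediction cached in the previous generation

// During the next generation: fall back to the student's own output when nothing is
// cached yet (generation 0), which makes the distillation term a no-op.
double[] teacher = cache.TryGetValue(input, out var cached)
    ? cached
    : StudentForward(input);

static double[] StudentForward(double[] x) => x;   // hypothetical stand-in for the student model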

TrainMultipleGenerations(Func<Vector<T>, Vector<T>>, Action<Vector<T>>, Vector<Vector<T>>, Vector<Vector<T>>, int, int, Action<int, T>?)

Performs self-distillation training for the specified number of generations.

public void TrainMultipleGenerations(Func<Vector<T>, Vector<T>> modelForward, Action<Vector<T>> modelBackward, Vector<Vector<T>> trainInputs, Vector<Vector<T>> trainLabels, int epochs, int batchSize = 32, Action<int, T>? onGenerationComplete = null)

Parameters

modelForward Func<Vector<T>, Vector<T>>

Function to perform forward pass and get logits.

modelBackward Action<Vector<T>>

Function to perform backward pass with gradients.

trainInputs Vector<Vector<T>>

Training input data.

trainLabels Vector<Vector<T>>

Training labels.

epochs int

Number of epochs per generation.

batchSize int

Batch size for training.

onGenerationComplete Action<int, T>?

Optional callback invoked after each generation with (generation, avgLoss).

Remarks

For Beginners: This method runs the complete self-distillation process:

1. **Generation 0**: Train the model normally (if starting from scratch)
2. **Generation 1**: Retrain using self as teacher
3. **Generation 2+**: Continue if requested

Each generation:

- Saves the current model predictions as the "teacher"
- Retrains the model to match both the teacher predictions and the true labels
- Typically sees 0.5-2% improvement per generation

Training Tips:

- Use temperature 2-4 (lower than standard distillation)
- Set alpha = 0.5 (equal weight to self and labels)
- Train for fewer epochs in later generations (half of the first)
- Watch for overfitting in later generations
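
A usage sketch follows; model.Forward, model.Backward, trainInputs, and trainLabels are placeholders you would supply and are not part of this API.

var distillationLoss = new DistillationLoss<double>(temperature: 3.0, alpha: 0.5);
var selfTrainer = new SelfDistillationTrainer<double>(distillationLoss, generations: 2);

selfTrainer.TrainMultipleGenerations(
    modelForward: input => model.Forward(input),       // returns logits for one input
    modelBackward: grads => model.Backward(grads),     // applies the provided gradients
    trainInputs: trainInputs,                          // Vector<Vector<double>> of inputs
    trainLabels: trainLabels,                          // Vector<Vector<double>> of labels (one vector per sample)
    epochs: 20,                                        // epochs per generation
    batchSize: 32,
    onGenerationComplete: (generation, avgLoss) =>
        Console.WriteLine($"Generation {generation}: average loss = {avgLoss}"));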