Class KnowledgeDistillationTrainer<T>

Namespace
AiDotNet.KnowledgeDistillation
Assembly
AiDotNet.dll

Standard knowledge distillation trainer that uses a fixed teacher model to train a student.

public class KnowledgeDistillationTrainer<T> : KnowledgeDistillationTrainerBase<T, Vector<T>, Vector<T>>, IKnowledgeDistillationTrainer<T, Vector<T>, Vector<T>>

Type Parameters

T

The numeric type for calculations (e.g., double, float).

Inheritance
KnowledgeDistillationTrainerBase<T, Vector<T>, Vector<T>>
KnowledgeDistillationTrainer<T>
Implements
IKnowledgeDistillationTrainer<T, Vector<T>, Vector<T>>
Remarks

For Beginners: This is the standard implementation of knowledge distillation. It takes a large, accurate teacher model and uses it to train a smaller, faster student model.

The training process works as follows (see the sketch after this list):

1. For each input, get predictions from both the teacher and the student.
2. Compute the distillation loss (how different are their predictions?).
3. Update the student's parameters to minimize this loss.
4. Repeat until the student learns to mimic the teacher.
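
To make those four steps concrete, here is a minimal, self-contained sketch of the same loop in plain C#. None of the names below are AiDotNet types: the "teacher" is a fixed function, the "student" is a tiny linear model, and the update is a hand-rolled gradient step on temperature-softened predictions (in the real trainer, the loss computation is delegated to the configured IDistillationStrategy<T>).

using System;
using System.Linq;

// Toy stand-in for the distillation loop described above. Not AiDotNet code.
class DistillationLoopSketch
{
    static double[] Softmax(double[] logits, double temperature)
    {
        double[] exps = logits.Select(l => Math.Exp(l / temperature)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }

    static void Main()
    {
        double[][] inputs = { new[] { 1.0, 0.0 }, new[] { 0.0, 1.0 } };

        // Frozen "teacher": a fixed function that never changes during training.
        Func<double[], double[]> teacherLogits =
            features => new[] { 3.0 * features[0], 3.0 * features[1] };

        double[,] w = new double[2, 2];               // student weights (2 classes x 2 features)
        double temperature = 3.0, learningRate = 0.1;

        for (int epoch = 0; epoch < 500; epoch++)     // step 4: repeat
        {
            foreach (double[] x in inputs)
            {
                // Step 1: temperature-softened predictions from both models.
                double[] teacherSoft = Softmax(teacherLogits(x), temperature);
                double[] studentLogits =
                {
                    w[0, 0] * x[0] + w[0, 1] * x[1],
                    w[1, 0] * x[0] + w[1, 1] * x[1]
                };
                double[] studentSoft = Softmax(studentLogits, temperature);

                // Steps 2-3: the gradient of the soft cross-entropy w.r.t. the student's
                // logits is (studentSoft - teacherSoft), up to a constant 1/T factor,
                // so nudge the student's weights toward the teacher's distribution.
                for (int c = 0; c < 2; c++)
                    for (int f = 0; f < 2; f++)
                        w[c, f] -= learningRate * (studentSoft[c] - teacherSoft[c]) * x[f];
            }
        }

        // For input { 1, 0 } the teacher's logit gap is 3.0; the student's gap approaches it.
        double studentGap = w[0, 0] - w[1, 0];
        Console.WriteLine($"student logit gap after distillation: {studentGap:F2}");
    }
}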

Real-world Analogy: Think of this as an apprenticeship program. The master (teacher) demonstrates how to solve problems, and the apprentice (student) learns by trying to replicate the master's approach. The apprentice doesn't just learn the final answers, but also the reasoning process.

Benefits of Knowledge Distillation:

- **Model Compression**: Deploy a 10x smaller model with >90% of the original accuracy
- **Faster Inference**: Smaller models run much faster on edge devices
- **Ensemble Distillation**: Combine knowledge from multiple teachers into one student
- **Transfer Learning**: Transfer knowledge across different architectures

Success Stories:

- DistilBERT: 40% smaller than BERT, 97% of its performance, 60% faster
- MobileNet: distilled from ResNet, 10x fewer parameters, deployable on phones
- TinyBERT: 7.5x smaller than BERT, suitable for edge deployment

Constructors

KnowledgeDistillationTrainer(ITeacherModel<Vector<T>, Vector<T>>, IDistillationStrategy<T>, DistillationCheckpointConfig?, bool, double, int, int?)

Initializes a new instance of the KnowledgeDistillationTrainer class.

public KnowledgeDistillationTrainer(ITeacherModel<Vector<T>, Vector<T>> teacher, IDistillationStrategy<T> distillationStrategy, DistillationCheckpointConfig? checkpointConfig = null, bool useEarlyStopping = false, double earlyStoppingMinDelta = 0.001, int earlyStoppingPatience = 5, int? seed = null)

Parameters

teacher ITeacherModel<Vector<T>, Vector<T>>

The teacher model to learn from.

distillationStrategy IDistillationStrategy<T>

The strategy for computing distillation loss.

checkpointConfig DistillationCheckpointConfig?

Optional checkpoint configuration for automatic model saving during training.

useEarlyStopping bool

Whether to stop training early when the loss stops improving.

earlyStoppingMinDelta double

The minimum improvement in loss that counts as progress for early stopping.

earlyStoppingPatience int

The number of epochs without sufficient improvement to tolerate before stopping.

seed int?

Optional random seed for reproducibility.

Remarks

For Beginners: Create a trainer by providing:

1. A trained teacher model (already performing well on your task)
2. A distillation strategy (defines how to transfer knowledge)
3. An optional checkpoint configuration (for automatic model saving)

Example:

var teacher = new TeacherModelWrapper<double>(...);
var distillationLoss = new DistillationLoss<double>(temperature: 3.0, alpha: 0.3);
var trainer = new KnowledgeDistillationTrainer<double>(teacher, distillationLoss);
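
For context, temperature and alpha usually parameterize the classic distillation objective of Hinton et al. (2015): the temperature softens both probability distributions, and alpha balances the hard-label term against the soft-target term. Exactly which of the two terms alpha weights inside DistillationLoss<T> is an assumption here, not something this page specifies:

L = \alpha \cdot \mathrm{CE}\bigl(y,\ \mathrm{softmax}(z_s)\bigr) + (1 - \alpha) \cdot T^{2} \cdot \mathrm{KL}\bigl(\mathrm{softmax}(z_t / T)\ \|\ \mathrm{softmax}(z_s / T)\bigr)

where z_t and z_s are the teacher's and student's logits, y is the ground-truth label, and T is the temperature (3.0 in the example above).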

Example with automatic checkpointing:

var checkpointConfig = new DistillationCheckpointConfig
{
    SaveEveryEpochs = 5,
    KeepBestN = 3
};
var trainer = new KnowledgeDistillationTrainer<double>(
    teacher,
    distillationLoss,
    checkpointConfig: checkpointConfig
);

Methods

GetTeacherPredictions(Vector<T>, int)

Gets teacher predictions by calling the teacher model's GetLogits method.

protected override Vector<T> GetTeacherPredictions(Vector<T> input, int index)

Parameters

input Vector<T>

The input data.

index int

The index in the training batch (unused for standard distillation).

Returns

Vector<T>

Teacher's logit predictions.

Remarks

For Beginners: In standard distillation, we simply ask the teacher model for its predictions on each input. The teacher is frozen and doesn't change during training.
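
As a rough sketch, the override amounts to forwarding the input to the frozen teacher and returning its logits. The Teacher member below is an assumed name used for illustration; the actual member exposed by the base class may differ.

// Conceptual sketch -- "Teacher" is an assumed member name, not confirmed API.
protected override Vector<T> GetTeacherPredictions(Vector<T> input, int index)
{
    // index is ignored: standard distillation queries the same frozen teacher for every sample.
    return Teacher.GetLogits(input);
}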