Class KnowledgeDistillationTrainer<T>
Namespace: AiDotNet.KnowledgeDistillation
Assembly: AiDotNet.dll
Standard knowledge distillation trainer that uses a fixed teacher model to train a student.
public class KnowledgeDistillationTrainer<T> : KnowledgeDistillationTrainerBase<T, Vector<T>, Vector<T>>, IKnowledgeDistillationTrainer<T, Vector<T>, Vector<T>>
Type Parameters
T: The numeric type for calculations (e.g., double, float).
Inheritance: KnowledgeDistillationTrainerBase<T, Vector<T>, Vector<T>> → KnowledgeDistillationTrainer<T>
Implements: IKnowledgeDistillationTrainer<T, Vector<T>, Vector<T>>
Remarks
For Beginners: This is the standard implementation of knowledge distillation. It takes a large, accurate teacher model and uses it to train a smaller, faster student model.
The training process works as follows (see the sketch after this list):
1. For each input, get predictions from both the teacher and the student.
2. Compute the distillation loss (how different are their predictions?).
3. Update the student's parameters to minimize this loss.
4. Repeat until the student learns to mimic the teacher.
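The loop below is a conceptual sketch of these four steps, not the library's actual implementation; the teacher, student, and distillationStrategy members and the method names Predict, ComputeLoss, and UpdateParameters are illustrative assumptions.
// Conceptual sketch of the distillation loop; member and method names are illustrative.
for (int epoch = 0; epoch < maxEpochs; epoch++)
{
    foreach (var (input, hardLabel) in trainingData)
    {
        var teacherLogits = teacher.GetLogits(input);   // step 1: frozen teacher's prediction
        var studentLogits = student.Predict(input);     // step 1: current student's prediction
        // Step 2: how far apart are the two predictions (plus any hard-label term)?
        var loss = distillationStrategy.ComputeLoss(studentLogits, teacherLogits, hardLabel);
        // Step 3: nudge the student's parameters to reduce that loss.
        student.UpdateParameters(loss);
    }
    // Step 4: repeat for the next epoch until the student mimics the teacher well enough.
}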
Real-world Analogy: Think of this as an apprenticeship program. The master (teacher) demonstrates how to solve problems, and the apprentice (student) learns by trying to replicate the master's approach. The apprentice doesn't just learn the final answers, but also the reasoning process.
Benefits of Knowledge Distillation:
- **Model Compression**: Deploy a model roughly 10x smaller that retains more than 90% of the original accuracy.
- **Faster Inference**: Smaller models run much faster on edge devices.
- **Ensemble Distillation**: Combine knowledge from multiple teachers into one student.
- **Transfer Learning**: Transfer knowledge across different architectures.
Success Stories:
- DistilBERT: 40% smaller than BERT, retains 97% of its performance, and runs 60% faster.
- MobileNet: Distilled from ResNet with roughly 10x fewer parameters, deployable on phones.
- TinyBERT: 7.5x smaller than BERT, suitable for edge deployment.
Constructors
KnowledgeDistillationTrainer(ITeacherModel<Vector<T>, Vector<T>>, IDistillationStrategy<T>, DistillationCheckpointConfig?, bool, double, int, int?)
Initializes a new instance of the KnowledgeDistillationTrainer class.
public KnowledgeDistillationTrainer(ITeacherModel<Vector<T>, Vector<T>> teacher, IDistillationStrategy<T> distillationStrategy, DistillationCheckpointConfig? checkpointConfig = null, bool useEarlyStopping = false, double earlyStoppingMinDelta = 0.001, int earlyStoppingPatience = 5, int? seed = null)
Parameters
teacher (ITeacherModel<Vector<T>, Vector<T>>): The teacher model to learn from.
distillationStrategy (IDistillationStrategy<T>): The strategy for computing distillation loss.
checkpointConfig (DistillationCheckpointConfig?): Optional checkpoint configuration for automatic model saving during training.
useEarlyStopping (bool): Whether to enable early stopping when the loss stops improving.
earlyStoppingMinDelta (double): Minimum improvement in loss that counts as progress for early stopping.
earlyStoppingPatience (int): Number of epochs without sufficient improvement tolerated before training stops.
seed (int?): Optional random seed for reproducibility.
Remarks
For Beginners: Create a trainer by providing:
1. A trained teacher model (already performing well on your task).
2. A distillation strategy (defines how knowledge is transferred).
3. An optional checkpoint configuration (for automatic model saving).
Example:
var teacher = new TeacherModelWrapper<double>(...);
var distillationLoss = new DistillationLoss<double>(temperature: 3.0, alpha: 0.3);
var trainer = new KnowledgeDistillationTrainer<double>(teacher, distillationLoss);
Example with automatic checkpointing:
var checkpointConfig = new DistillationCheckpointConfig
{
SaveEveryEpochs = 5,
KeepBestN = 3
};
var trainer = new KnowledgeDistillationTrainer<double>(
teacher,
distillationLoss,
checkpointConfig: checkpointConfig
);
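Example with early stopping (a sketch built from the constructor parameters above, reusing the teacher and distillationLoss variables from the first example; the exact loss metric monitored is not specified here):
var trainer = new KnowledgeDistillationTrainer<double>(
    teacher,
    distillationLoss,
    useEarlyStopping: true,          // stop when the loss stops improving
    earlyStoppingMinDelta: 0.001,    // smallest improvement that still counts as progress
    earlyStoppingPatience: 5,        // epochs without improvement tolerated before stopping
    seed: 42                         // fixed seed for reproducible training runs
);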
Methods
GetTeacherPredictions(Vector<T>, int)
Gets teacher predictions by calling the teacher model's GetLogits method.
protected override Vector<T> GetTeacherPredictions(Vector<T> input, int index)
Parameters
input (Vector<T>): The input data.
index (int): The index in the training batch (unused for standard distillation).
Returns
Vector<T>: The teacher's logit predictions.
Remarks
For Beginners: In standard distillation, we simply ask the teacher model for its predictions on each input. The teacher is frozen and doesn't change during training.
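In code, this amounts to delegating to the frozen teacher's GetLogits call; the fragment below is a minimal sketch of that idea, and the Teacher member name is an assumption for illustration rather than the library's actual field.
// Minimal sketch of the idea behind this override; the Teacher member name is illustrative.
protected override Vector<T> GetTeacherPredictions(Vector<T> input, int index)
{
    // index is ignored: the same frozen teacher answers every query.
    return Teacher.GetLogits(input);
}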