Interface IKnowledgeDistillationTrainer<T, TInput, TOutput>
Namespace: AiDotNet.Interfaces
Assembly: AiDotNet.dll
Defines the contract for knowledge distillation trainers that train student models using knowledge transferred from teacher models.
public interface IKnowledgeDistillationTrainer<T, TInput, TOutput>
Type Parameters
- T: The numeric type for calculations (e.g., double, float).
- TInput: The input data type (e.g., Matrix<T> for tabular data, Tensor<T> for images).
- TOutput: The output data type (e.g., Matrix<T> for batch outputs, Tensor<T> for structured outputs).
Remarks
For Beginners: A knowledge distillation trainer orchestrates the process of transferring knowledge from a large, accurate teacher model to a smaller, faster student model.
Why an interface?
- **Flexibility**: Multiple trainer implementations (standard, self-distillation, multi-teacher, etc.)
- **Testability**: Easy to mock for unit testing
- **Extensibility**: New training strategies can be added without breaking existing code
- **Dependency Injection**: Can be injected into other components
Common Implementations:
- **Standard Trainer**: Single teacher → single student
- **Self-Distillation Trainer**: Model teaches itself (improves calibration)
- **Multi-Teacher Trainer**: Multiple teachers → one student (ensemble distillation)
- **Online Trainer**: Teacher updates during student training
- **Mutual Learning Trainer**: Multiple students learn from each other
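As an illustration of the dependency-injection point above, here is a minimal consumption sketch. The DistillationPipeline class is hypothetical and not part of AiDotNet; it only shows that any implementation of this interface (standard, self-distillation, multi-teacher, or a mock in unit tests) can be swapped in without changing the consuming code.

```csharp
using AiDotNet.Interfaces;

// Hypothetical consumer: depends on the trainer abstraction, not a concrete class.
public class DistillationPipeline<T, TInput, TOutput>
{
    private readonly IKnowledgeDistillationTrainer<T, TInput, TOutput> _trainer;

    // Any IKnowledgeDistillationTrainer implementation can be injected here.
    public DistillationPipeline(IKnowledgeDistillationTrainer<T, TInput, TOutput> trainer)
        => _trainer = trainer;
}
```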
Properties
DistillationStrategy
Gets the distillation strategy used for computing loss and gradients.
IDistillationStrategy<T> DistillationStrategy { get; }
Property Value
- IDistillationStrategy<T>
Teacher
Gets the teacher model used for distillation.
ITeacherModel<TInput, TOutput> Teacher { get; }
Property Value
- ITeacherModel<TInput, TOutput>
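As an illustration, a small sketch of reading these two configuration properties; the DescribeTrainer helper is hypothetical, but it uses only members defined on this interface.

```csharp
using System;
using AiDotNet.Interfaces;

static class TrainerInfoExample
{
    // Logs which strategy and teacher a trainer was configured with;
    // both properties are read-only on the interface.
    public static void DescribeTrainer<T, TInput, TOutput>(
        IKnowledgeDistillationTrainer<T, TInput, TOutput> trainer)
    {
        Console.WriteLine($"Strategy: {trainer.DistillationStrategy.GetType().Name}");
        Console.WriteLine($"Teacher:  {trainer.Teacher.GetType().Name}");
    }
}
```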
Methods
Evaluate(Func<TInput, TOutput>, Vector<TInput>, Vector<TOutput>)
Evaluates the student model's accuracy on test data.
double Evaluate(Func<TInput, TOutput> studentForward, Vector<TInput> testInputs, Vector<TOutput> testLabels)
Parameters
- studentForward (Func<TInput, TOutput>): Function to perform forward pass on student model.
- testInputs (Vector<TInput>): Test input data.
- testLabels (Vector<TOutput>): Test labels (one-hot encoded).
Returns
- double
Classification accuracy as a percentage (0-100).
Remarks
For Beginners: This measures how well the student has learned. Returns a percentage between 0 and 100, where 100.0 means perfect accuracy.
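As an illustration, a minimal sketch of calling Evaluate. The ReportAccuracy helper and the AiDotNet.LinearAlgebra using are assumptions (adjust the namespace to wherever Vector<T> lives in your build); only the Evaluate call itself comes from this interface.

```csharp
using System;
using AiDotNet.Interfaces;
using AiDotNet.LinearAlgebra; // assumed namespace for Vector<T>

static class EvaluationExample
{
    // Reports the student's accuracy on held-out data;
    // Evaluate returns a percentage in the 0-100 range.
    public static void ReportAccuracy<T, TInput, TOutput>(
        IKnowledgeDistillationTrainer<T, TInput, TOutput> trainer,
        Func<TInput, TOutput> studentForward,
        Vector<TInput> testInputs,
        Vector<TOutput> testLabels)
    {
        double accuracy = trainer.Evaluate(studentForward, testInputs, testLabels);
        Console.WriteLine($"Student accuracy: {accuracy:F2}%");
    }
}
```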
Train(Func<TInput, TOutput>, Action<TOutput>, Vector<TInput>, Vector<TOutput>?, int, int, Action<int, T>?)
Trains the student model for multiple epochs.
void Train(Func<TInput, TOutput> studentForward, Action<TOutput> studentBackward, Vector<TInput> trainInputs, Vector<TOutput>? trainLabels, int epochs, int batchSize = 32, Action<int, T>? onEpochComplete = null)
Parameters
- studentForward (Func<TInput, TOutput>): Function to perform the forward pass on the student model.
- studentBackward (Action<TOutput>): Function to perform the backward pass on the student model.
- trainInputs (Vector<TInput>): Training input data.
- trainLabels (Vector<TOutput>?): Training labels (optional). If null, uses only teacher supervision.
- epochs (int): Number of epochs to train.
- batchSize (int): Batch size for training (default 32).
- onEpochComplete (Action<int, T>?): Optional callback invoked after each epoch with (epoch, avgLoss).
Remarks
For Beginners: An epoch is one complete pass through the training data. This method runs the full training loop for the specified number of epochs.
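As an illustration, a sketch of a full training run with an epoch callback. The RunDistillation helper and the AiDotNet.LinearAlgebra using are assumptions; only the Train call itself comes from this interface.

```csharp
using System;
using AiDotNet.Interfaces;
using AiDotNet.LinearAlgebra; // assumed namespace for Vector<T>

static class TrainingExample
{
    // Runs the full training loop; passing null for trainLabels switches to
    // teacher-only supervision, per the parameter documentation above.
    public static void RunDistillation<T, TInput, TOutput>(
        IKnowledgeDistillationTrainer<T, TInput, TOutput> trainer,
        Func<TInput, TOutput> studentForward,
        Action<TOutput> studentBackward,
        Vector<TInput> trainInputs,
        Vector<TOutput>? trainLabels)
    {
        trainer.Train(
            studentForward,
            studentBackward,
            trainInputs,
            trainLabels,
            epochs: 10,
            batchSize: 64,
            onEpochComplete: (epoch, avgLoss) =>
                Console.WriteLine($"Epoch {epoch}: average loss = {avgLoss}"));
    }
}
```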
TrainBatch(Func<TInput, TOutput>, Action<TOutput>, Vector<TInput>, Vector<TOutput>?)
Trains the student model on a single batch of data.
T TrainBatch(Func<TInput, TOutput> studentForward, Action<TOutput> studentBackward, Vector<TInput> inputs, Vector<TOutput>? trueLabels = null)
Parameters
- studentForward (Func<TInput, TOutput>): Function to perform forward pass on student model.
- studentBackward (Action<TOutput>): Function to perform backward pass on student model (takes gradient).
- inputs (Vector<TInput>): Input batch.
- trueLabels (Vector<TOutput>?): Ground truth labels (optional). If null, uses only teacher supervision.
Returns
- T
Average loss for the batch.
Remarks
For Beginners: This method processes one batch of training data:
1. Get predictions from both teacher and student
2. Compute distillation loss
3. Update student weights via backpropagation
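As an illustration, a sketch of a custom epoch loop built on TrainBatch, for example to interleave logging or early stopping between batches. The TrainOneEpoch helper, the batch enumeration, and the AiDotNet.LinearAlgebra using are assumptions; only the TrainBatch call itself comes from this interface.

```csharp
using System;
using System.Collections.Generic;
using AiDotNet.Interfaces;
using AiDotNet.LinearAlgebra; // assumed namespace for Vector<T>

static class BatchTrainingExample
{
    // Runs one epoch by iterating over pre-built batches supplied by the caller.
    public static void TrainOneEpoch<T, TInput, TOutput>(
        IKnowledgeDistillationTrainer<T, TInput, TOutput> trainer,
        Func<TInput, TOutput> studentForward,
        Action<TOutput> studentBackward,
        IEnumerable<(Vector<TInput> Inputs, Vector<TOutput>? Labels)> batches)
    {
        int batchIndex = 0;
        foreach (var (inputs, labels) in batches)
        {
            // Teacher forward pass, student forward pass, distillation loss,
            // and the student backward pass all happen inside TrainBatch.
            T batchLoss = trainer.TrainBatch(studentForward, studentBackward, inputs, labels);
            Console.WriteLine($"Batch {batchIndex++}: loss = {batchLoss}");
        }
    }
}
```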