Class KnowledgeDistillationTrainerBase<T, TInput, TOutput>
- Namespace
- AiDotNet.KnowledgeDistillation
- Assembly
- AiDotNet.dll
Abstract base class for all knowledge distillation trainers. Provides common functionality for training loops, data shuffling, validation, and evaluation.
public abstract class KnowledgeDistillationTrainerBase<T, TInput, TOutput> : IKnowledgeDistillationTrainer<T, TInput, TOutput>
Type Parameters
T: The numeric type for calculations (e.g., double, float).
TInput: The input data type.
TOutput: The output data type.
- Inheritance
- KnowledgeDistillationTrainerBase<T, TInput, TOutput>
- Implements
- IKnowledgeDistillationTrainer<T, TInput, TOutput>
Remarks
For Beginners: This base class implements the common training workflow shared by all distillation trainers. Specific trainer types (standard, self-distillation, online, etc.) inherit from this and customize only the parts that differ.
Design Pattern: Template Method Pattern - the base class defines the training algorithm structure, and derived classes fill in specific steps.
Common Functionality Provided:
- Data shuffling using the Fisher-Yates algorithm (O(n) efficiency)
- Epoch and batch management
- Validation after each epoch
- Progress callbacks
- Evaluation metrics (accuracy, loss)
- Teacher and strategy property management
Derived Classes Override:
- GetTeacherPredictions(): How to obtain teacher outputs for training
- OnEpochStart(): Custom logic before each epoch
- OnEpochEnd(): Custom logic after each epoch
- OnTrainingStart(): Custom logic before training begins
- OnTrainingEnd(): Custom logic after training completes
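The Template Method structure described above can be sketched with a toy class. This is a minimal illustration under stated assumptions, not the library's implementation: ToyTrainerBase, LoggingTrainer, and the Log list are hypothetical names, and the real base class works with student callbacks and data vectors that are left out here.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for the real base class: it owns the order of the
// steps, and derived classes customize only the hooks.
public abstract class ToyTrainerBase
{
    public List<string> Log { get; } = new List<string>();

    // The "template method": the sequence of steps is fixed here.
    public void Train(int epochs)
    {
        OnTrainingStart();
        for (int epoch = 0; epoch < epochs; epoch++)
        {
            OnEpochStart(epoch);
            Log.Add($"teacher:{GetTeacherPrediction(epoch)}");
            OnEpochEnd(epoch);
        }
        OnTrainingEnd();
    }

    // Required step: each derived trainer decides where teacher outputs come from.
    protected abstract int GetTeacherPrediction(int index);

    // Optional hooks with default behavior.
    protected virtual void OnTrainingStart() => Log.Add("start");
    protected virtual void OnEpochStart(int epoch) => Log.Add($"epoch-start:{epoch}");
    protected virtual void OnEpochEnd(int epoch) => Log.Add($"epoch-end:{epoch}");
    protected virtual void OnTrainingEnd() => Log.Add("end");
}

public sealed class LoggingTrainer : ToyTrainerBase
{
    // Toy "teacher": returns a value derived from the index.
    protected override int GetTeacherPrediction(int index) => index * 2;

    protected override void OnEpochEnd(int epoch)
    {
        base.OnEpochEnd(epoch); // keep the base behavior, then add custom logic
        Log.Add($"custom:{epoch}");
    }
}
```

Calling new LoggingTrainer().Train(1) records the steps in the order fixed by the base class, with the custom entry appended after the base OnEpochEnd entry.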
Constructors
KnowledgeDistillationTrainerBase(ITeacherModel<TInput, TOutput>, IDistillationStrategy<T>, DistillationCheckpointConfig?, bool, double, int, int?)
Initializes a new instance of the KnowledgeDistillationTrainerBase class.
protected KnowledgeDistillationTrainerBase(ITeacherModel<TInput, TOutput> teacher, IDistillationStrategy<T> distillationStrategy, DistillationCheckpointConfig? checkpointConfig = null, bool useEarlyStopping = false, double earlyStoppingMinDelta = 0.001, int earlyStoppingPatience = 10, int? seed = null)
Parameters
teacher (ITeacherModel<TInput, TOutput>): The teacher model.
distillationStrategy (IDistillationStrategy<T>): The distillation strategy.
checkpointConfig (DistillationCheckpointConfig): Optional checkpoint configuration for automatic model saving during training.
useEarlyStopping (bool): Enable early stopping based on validation loss.
earlyStoppingMinDelta (double): Minimum improvement required to count as progress.
earlyStoppingPatience (int): Number of epochs without improvement before stopping.
seed (int?): Optional random seed for reproducibility.
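The interaction between earlyStoppingMinDelta and earlyStoppingPatience can be sketched as a small tracker. This is an assumption about the common patience scheme, not a description of the trainer's internals: EarlyStoppingTracker and ShouldStop are hypothetical names, and the sketch assumes that lower validation loss is better.

```csharp
using System;

// Hypothetical helper (not an AiDotNet type): stops after `patience`
// consecutive epochs that fail to improve the best loss by more than `minDelta`.
public sealed class EarlyStoppingTracker
{
    private readonly double _minDelta;
    private readonly int _patience;
    private double _bestLoss = double.PositiveInfinity;
    private int _epochsWithoutImprovement;

    public EarlyStoppingTracker(double minDelta = 0.001, int patience = 10)
    {
        _minDelta = minDelta;
        _patience = patience;
    }

    // Call once per epoch with the validation loss; returns true when training should stop.
    public bool ShouldStop(double validationLoss)
    {
        if (_bestLoss - validationLoss > _minDelta)
        {
            // Improvement is large enough: record it and reset the counter.
            _bestLoss = validationLoss;
            _epochsWithoutImprovement = 0;
        }
        else
        {
            _epochsWithoutImprovement++;
        }
        return _epochsWithoutImprovement >= _patience;
    }
}
```

With minDelta = 0.01 and patience = 2, the losses 1.0, 0.995, 0.999 trigger a stop on the third epoch, because neither of the last two values beats 1.0 by more than 0.01.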
Remarks
For Beginners: The teacher and strategy are the core components:
- Teacher: Provides the "expert" knowledge to transfer
- Strategy: Defines how to measure and optimize the knowledge transfer
Automatic Checkpointing: To enable automatic checkpointing, pass a DistillationCheckpointConfig instance. If null (default), no automatic checkpointing occurs. When enabled, the trainer will automatically:
- Save checkpoints based on your configuration (e.g., every 5 epochs)
- Keep only the best N checkpoints to save disk space
- Load the best checkpoint after training completes
Example with Checkpointing:
var config = new DistillationCheckpointConfig
{
    SaveEveryEpochs = 5,
    KeepBestN = 3
};
var trainer = new KnowledgeDistillationTrainer(teacher, strategy, checkpointConfig: config);
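The "keep only the best N checkpoints" behavior can be pictured with a small bookkeeping class. This is a sketch under assumptions, not the library's checkpoint manager: BestCheckpointTracker is a hypothetical name, "best" is assumed to mean lowest validation loss, and no files are actually written or deleted.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical helper (not an AiDotNet type): decides which checkpoints to keep.
public sealed class BestCheckpointTracker
{
    private readonly int _keepBestN;
    private readonly List<(int Epoch, double Loss)> _kept = new List<(int Epoch, double Loss)>();

    public BestCheckpointTracker(int keepBestN)
    {
        if (keepBestN < 1) throw new ArgumentOutOfRangeException(nameof(keepBestN));
        _keepBestN = keepBestN;
    }

    // Epoch of the best checkpoint so far, or null if none was recorded.
    public int? BestEpoch => _kept.Count > 0 ? _kept[0].Epoch : (int?)null;

    // Records a checkpoint and returns the epochs whose saved files could be deleted.
    public IReadOnlyList<int> Add(int epoch, double validationLoss)
    {
        _kept.Add((epoch, validationLoss));
        _kept.Sort((a, b) => a.Loss.CompareTo(b.Loss)); // lowest loss first

        // Everything beyond the first N entries is evicted.
        List<int> evicted = _kept.Skip(_keepBestN).Select(c => c.Epoch).ToList();
        if (evicted.Count > 0)
        {
            _kept.RemoveRange(_keepBestN, evicted.Count);
        }
        return evicted;
    }
}
```

With keepBestN = 3 and SaveEveryEpochs = 5, such a tracker would hold at most three entries at any time, and BestEpoch would identify the checkpoint to reload once training finishes.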
Fields
NumOps
Gets the numeric operations helper for the type T.
protected readonly INumericOperations<T> NumOps
Field Value
- INumericOperations<T>
Random
Gets the random number generator for data shuffling.
protected readonly Random Random
Field Value
- Random
Properties
DistillationStrategy
Gets the distillation strategy for computing loss and gradients.
public IDistillationStrategy<T> DistillationStrategy { get; protected set; }
Property Value
- IDistillationStrategy<T>
Teacher
Gets the teacher model used for distillation.
public ITeacherModel<TInput, TOutput> Teacher { get; protected set; }
Property Value
- ITeacherModel<TInput, TOutput>
Methods
ArgMax(Vector<T>)
Finds the index of the maximum value in a vector (argmax).
protected int ArgMax(Vector<T> vector)
Parameters
vector (Vector<T>): The vector to search.
Returns
- int
Index of the maximum value.
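As a rough picture of what an argmax does, here is a minimal sketch over a plain double[] rather than Vector<T>. The class name ArgMaxSketch, the first-index-wins tie rule, and the exception on empty input are assumptions made for the illustration; the documented method may behave differently in those edge cases.

```csharp
using System;

public static class ArgMaxSketch
{
    // Returns the index of the largest element; the first one wins on ties (an assumption).
    public static int ArgMax(double[] values)
    {
        if (values == null || values.Length == 0)
        {
            throw new ArgumentException("Input must contain at least one element.", nameof(values));
        }

        int best = 0;
        for (int i = 1; i < values.Length; i++)
        {
            if (values[i] > values[best])
            {
                best = i;
            }
        }
        return best;
    }
}
```

For class scores { 0.1, 0.7, 0.2 } this returns 1, the index of the predicted class.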
Evaluate(Func<TInput, TOutput>, Vector<TInput>, Vector<TOutput>)
Evaluates the student model's accuracy on a dataset.
public virtual double Evaluate(Func<TInput, TOutput> studentForward, Vector<TInput> inputs, Vector<TOutput> trueLabels)
Parameters
studentForward (Func<TInput, TOutput>): Function to perform forward pass through student model.
inputs (Vector<TInput>): Evaluation input data.
trueLabels (Vector<TOutput>): True labels for evaluation.
Returns
- double
Accuracy as a fraction (0-1).
Remarks
For Beginners: Evaluation measures how well the student performs:
- For each input, get student's prediction
- Compare with true label (argmax for classification)
- Calculate fraction of correct predictions
This is used to monitor training progress and detect overfitting.
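The accuracy computation described above can be sketched with plain arrays. This is an illustration only: EvaluateSketch and Accuracy are hypothetical names, double[] stands in for TInput and TOutput, labels are assumed to be one-hot vectors, and returning 0 for an empty dataset is an arbitrary choice.

```csharp
using System;

public static class EvaluateSketch
{
    // Index of the largest element (first index wins on ties).
    private static int ArgMax(double[] values)
    {
        int best = 0;
        for (int i = 1; i < values.Length; i++)
        {
            if (values[i] > values[best]) best = i;
        }
        return best;
    }

    // Fraction of inputs whose predicted class matches the labeled class.
    public static double Accuracy(
        Func<double[], double[]> studentForward,
        double[][] inputs,
        double[][] oneHotLabels)
    {
        if (inputs.Length != oneHotLabels.Length)
        {
            throw new ArgumentException("Inputs and labels must have the same length.");
        }
        if (inputs.Length == 0) return 0.0; // assumption: empty dataset scores 0

        int correct = 0;
        for (int i = 0; i < inputs.Length; i++)
        {
            double[] prediction = studentForward(inputs[i]);
            if (ArgMax(prediction) == ArgMax(oneHotLabels[i])) correct++;
        }
        return (double)correct / inputs.Length;
    }
}
```

Comparing this value on training and validation data is one way to spot overfitting: training accuracy keeps rising while validation accuracy stalls or drops.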
FisherYatesShuffle(int)
Generates a random permutation of indices using Fisher-Yates shuffle.
protected int[] FisherYatesShuffle(int length)
Parameters
length (int): The length of the array to shuffle.
Returns
- int[]
An array of shuffled indices.
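A minimal sketch of generating an index permutation with the Fisher-Yates algorithm follows. The class and method names are hypothetical, and the Random instance is passed in explicitly here, whereas the documented method presumably uses the trainer's own Random field.

```csharp
using System;

public static class ShuffleSketch
{
    // Returns the indices 0..length-1 in a uniformly random order.
    public static int[] ShuffledIndices(int length, Random random)
    {
        if (length < 0) throw new ArgumentOutOfRangeException(nameof(length));

        var indices = new int[length];
        for (int i = 0; i < length; i++)
        {
            indices[i] = i;
        }

        // Walk backwards, swapping each slot with a uniformly chosen slot at or before it.
        for (int i = length - 1; i > 0; i--)
        {
            int j = random.Next(i + 1); // 0 <= j <= i
            (indices[i], indices[j]) = (indices[j], indices[i]);
        }
        return indices;
    }
}
```

Each element is swapped at most once per step, so the whole pass is O(n). Passing a Random created from a fixed seed makes the permutation reproducible.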
GetTeacherPredictions(TInput, int)
Gets teacher predictions for a given input. Abstract method to be implemented by derived classes.
protected abstract TOutput GetTeacherPredictions(TInput input, int index)
Parameters
input (TInput): The input data.
index (int): The index of this input in the training batch (useful for caching).
Returns
- TOutput
Teacher's output predictions.
Remarks
For Derived Classes:
- Standard trainers: Call teacher.GetLogits(input)
- Self-distillation: Return cached predictions from previous generation
- Online distillation: Get predictions from dynamically updated teacher
- Ensemble: Combine predictions from multiple teachers
IsCorrectPrediction(TOutput, TOutput)
Determines if a prediction matches the true label.
protected virtual bool IsCorrectPrediction(TOutput prediction, TOutput trueLabel)
Parameters
prediction (TOutput): Student's prediction.
trueLabel (TOutput): True label.
Returns
- bool
True if prediction is correct.
Remarks
For Classification: Uses argmax to find predicted class and compares with true class. The conversion to Vector<T> is handled by ConversionsHelper, so this works for any TOutput type (Vector<T>, Tensor<T>, T[], or scalar T).
Override This: If you need different evaluation criteria (e.g., regression).
OnEpochEnd(int, T)
Called after each epoch completes. Override for custom per-epoch cleanup/logging.
protected virtual void OnEpochEnd(int epoch, T avgLoss)
Parameters
epoch (int): Current epoch number (0-indexed).
avgLoss (T): Average loss for this epoch.
Remarks
Use Cases:
- Log epoch metrics
- Save checkpoints
- Update adaptive parameters
- Implement early stopping logic
IMPORTANT: This base implementation calls Reset() on RelationalDistillationStrategy to flush partial batches and prevent buffer leakage between epochs. Derived classes that override this method should call base.OnEpochEnd(epoch, avgLoss) so that this reset still runs.
Automatic Checkpointing: If checkpoint configuration was provided to the constructor, this method automatically saves checkpoints based on your configuration.
OnEpochStart(int, Vector<TInput>, Vector<TOutput>?)
Called before each epoch starts. Override for custom per-epoch initialization.
protected virtual void OnEpochStart(int epoch, Vector<TInput> trainInputs, Vector<TOutput>? trainLabels)
Parameters
epoch (int): Current epoch number (0-indexed).
trainInputs (Vector<TInput>): Training inputs.
trainLabels (Vector<TOutput>): Training labels.
Remarks
Use Cases:
- Update learning rate schedules
- Adjust temperature schedules
- Update curriculum difficulty
- Refresh teacher in online distillation
OnTrainingEnd(Vector<TInput>, Vector<TOutput>?)
Called after training completes. Override for custom cleanup logic.
protected virtual void OnTrainingEnd(Vector<TInput> trainInputs, Vector<TOutput>? trainLabels)
Parameters
trainInputs (Vector<TInput>): Training inputs.
trainLabels (Vector<TOutput>): Training labels.
Remarks
Use Cases:
- Clear caches
- Save final checkpoints
- Log final metrics
- Free temporary resources
Automatic Checkpointing: If checkpoint configuration was provided to the constructor, this method automatically loads the best checkpoint (based on validation metrics) after training completes.
OnTrainingStart(Vector<TInput>, Vector<TOutput>?)
Called before training starts. Override for custom initialization logic.
protected virtual void OnTrainingStart(Vector<TInput> trainInputs, Vector<TOutput>? trainLabels)
Parameters
trainInputs (Vector<TInput>): Training inputs.
trainLabels (Vector<TOutput>): Training labels.
Remarks
Use Cases:
- Cache teacher predictions (for efficiency)
- Initialize EMA buffers (for self-distillation)
- Setup curriculum schedules
- Allocate temporary buffers
Automatic Checkpointing: If checkpoint configuration was provided to the constructor, this method automatically initializes the checkpoint manager.
OnValidationComplete(int, double)
Called after validation completes for an epoch. Override for custom validation handling.
protected virtual void OnValidationComplete(int epoch, double accuracy)
Parameters
epoch (int): Current epoch number.
accuracy (double): Validation accuracy for this epoch.
Remarks
Use Cases:
- Log validation metrics
- Implement early stopping
- Track best model
- Adjust hyperparameters based on validation performance
Automatic Checkpointing: If CheckpointConfig is set, this method automatically tracks validation metrics for best checkpoint selection.
ShuffleData(Vector<TInput>, Vector<TOutput>?)
Shuffles training data using Fisher-Yates algorithm.
protected virtual (Vector<TInput> shuffledInputs, Vector<TOutput>? shuffledLabels) ShuffleData(Vector<TInput> inputs, Vector<TOutput>? labels)
Parameters
inputs (Vector<TInput>): Input data to shuffle.
labels (Vector<TOutput>): Labels to shuffle (maintains alignment with inputs).
Returns
- (Vector<TInput> shuffledInputs, Vector<TOutput>? shuffledLabels)
Tuple of shuffled inputs and labels.
Remarks
For Beginners: Data shuffling is important because:
- Prevents model from learning order-dependent patterns
- Improves gradient descent convergence
- Reduces overfitting to batch ordering
Fisher-Yates is O(n) compared to O(n log n) for sort-based shuffling.
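Keeping inputs and labels aligned comes down to applying one permutation to both collections. The sketch below shows that step with plain arrays. AlignedReorderSketch and Reorder are hypothetical names, and passing null labels through as null is an assumption based on the nullable labels parameter in the signature.

```csharp
#nullable enable
using System;

public static class AlignedReorderSketch
{
    // Applies the same index permutation to inputs and (optional) labels,
    // so that element i of both results comes from the same original position.
    public static (TIn[] Inputs, TOut[]? Labels) Reorder<TIn, TOut>(
        TIn[] inputs, TOut[]? labels, int[] permutation)
    {
        if (permutation.Length != inputs.Length)
        {
            throw new ArgumentException("Permutation length must match inputs.");
        }
        if (labels != null && labels.Length != inputs.Length)
        {
            throw new ArgumentException("Labels length must match inputs.");
        }

        var shuffledInputs = new TIn[inputs.Length];
        TOut[]? shuffledLabels = labels == null ? null : new TOut[labels.Length];

        for (int i = 0; i < permutation.Length; i++)
        {
            shuffledInputs[i] = inputs[permutation[i]];
            if (shuffledLabels != null)
            {
                shuffledLabels[i] = labels![permutation[i]];
            }
        }
        return (shuffledInputs, shuffledLabels);
    }
}
```

With the permutation { 2, 0, 1 }, inputs { "a", "b", "c" } become { "c", "a", "b" } and labels { 10, 20, 30 } become { 30, 10, 20 }, so each input keeps its own label.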
Train(Func<TInput, TOutput>, Action<TOutput>, Vector<TInput>, Vector<TOutput>?, int, int, Vector<TInput>?, Vector<TOutput>?, ICheckpointableModel?, Action<int, T>?)
Trains the student model for multiple epochs using knowledge distillation.
public virtual void Train(Func<TInput, TOutput> studentForward, Action<TOutput> studentBackward, Vector<TInput> trainInputs, Vector<TOutput>? trainLabels = null, int epochs = 20, int batchSize = 32, Vector<TInput>? validationInputs = null, Vector<TOutput>? validationLabels = null, ICheckpointableModel? student = null, Action<int, T>? onEpochComplete = null)
Parameters
studentForward (Func<TInput, TOutput>): Function to perform forward pass through student model.
studentBackward (Action<TOutput>): Function to perform backward pass and update weights.
trainInputs (Vector<TInput>): Training input data.
trainLabels (Vector<TOutput>): Training labels (optional for pure soft distillation).
epochs (int): Number of training epochs.
batchSize (int): Batch size for mini-batch training.
validationInputs (Vector<TInput>): Optional validation inputs for monitoring.
validationLabels (Vector<TOutput>): Optional validation labels.
student (ICheckpointableModel): Optional student model for automatic checkpointing (must implement ICheckpointableModel).
onEpochComplete (Action<int, T>): Optional callback invoked after each epoch with (epoch, avgLoss).
Remarks
For Beginners: This method orchestrates the complete training process:
1. Initialize (OnTrainingStart)
2. For each epoch:
   a. Shuffle training data
   b. Process data in batches
   c. Validate if requested
   d. Invoke callbacks
3. Cleanup (OnTrainingEnd)
Training Tips:
- Use batch sizes that fit in memory (32-128 typical)
- Monitor validation loss to detect overfitting
- Invoke callbacks to log progress or save checkpoints
Automatic Checkpointing: If a checkpoint configuration was provided to the constructor and you pass the student model parameter, automatic checkpointing will be enabled.
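The epoch and batch bookkeeping in this workflow can be sketched as follows. This is a simplified illustration, not the library's loop: EpochLoopSketch is a hypothetical name, shuffling and validation are omitted, the final batch is assumed to be smaller rather than dropped, and the epoch loss is assumed to be the unweighted mean of the batch losses.

```csharp
#nullable enable
using System;
using System.Collections.Generic;

public static class EpochLoopSketch
{
    // Splits sampleCount items into consecutive (start, count) batches.
    public static List<(int Start, int Count)> BatchRanges(int sampleCount, int batchSize)
    {
        if (batchSize <= 0) throw new ArgumentOutOfRangeException(nameof(batchSize));

        var ranges = new List<(int Start, int Count)>();
        for (int start = 0; start < sampleCount; start += batchSize)
        {
            // The last batch may hold fewer than batchSize items.
            ranges.Add((start, Math.Min(batchSize, sampleCount - start)));
        }
        return ranges;
    }

    // Runs the epoch loop; trainBatch(start, count) returns the loss of one batch.
    public static void Run(
        int sampleCount, int epochs, int batchSize,
        Func<int, int, double> trainBatch,
        Action<int, double>? onEpochComplete = null)
    {
        List<(int Start, int Count)> ranges = BatchRanges(sampleCount, batchSize);
        for (int epoch = 0; epoch < epochs; epoch++)
        {
            double total = 0.0;
            foreach (var (start, count) in ranges)
            {
                total += trainBatch(start, count);
            }
            double avgLoss = ranges.Count > 0 ? total / ranges.Count : 0.0;
            onEpochComplete?.Invoke(epoch, avgLoss);
        }
    }
}
```

For 70 samples and a batch size of 32, BatchRanges yields batches of 32, 32, and 6 items, and the callback fires once per epoch with the 0-based epoch number.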
Train(Func<TInput, TOutput>, Action<TOutput>, Vector<TInput>, Vector<TOutput>?, int, int, Action<int, T>?)
Trains the student model for multiple epochs (interface-compliant overload).
public virtual void Train(Func<TInput, TOutput> studentForward, Action<TOutput> studentBackward, Vector<TInput> trainInputs, Vector<TOutput>? trainLabels, int epochs, int batchSize = 32, Action<int, T>? onEpochComplete = null)
Parameters
studentForward (Func<TInput, TOutput>): Function to perform forward pass on student model.
studentBackward (Action<TOutput>): Function to perform backward pass on student model.
trainInputs (Vector<TInput>): Training input data.
trainLabels (Vector<TOutput>): Training labels (optional). If null, uses only teacher supervision.
epochs (int): Number of epochs to train.
batchSize (int): Batch size for training.
onEpochComplete (Action<int, T>): Optional callback invoked after each epoch with (epoch, avgLoss).
Remarks
For Beginners: This overload matches the interface contract and delegates to the extended version with no validation data.
TrainBatch(Func<TInput, TOutput>, Action<TOutput>, Vector<TInput>, Vector<TOutput>?)
Trains the student model on a single batch using knowledge distillation.
public virtual T TrainBatch(Func<TInput, TOutput> studentForward, Action<TOutput> studentBackward, Vector<TInput> inputs, Vector<TOutput>? trueLabels = null)
Parameters
studentForward (Func<TInput, TOutput>): Function to perform forward pass through student model.
studentBackward (Action<TOutput>): Function to perform backward pass and update weights.
inputs (Vector<TInput>): Batch of input data.
trueLabels (Vector<TOutput>): Optional true labels for the batch.
Returns
- T
Average loss for the batch.
Remarks
For Beginners: Training a batch involves:
1. Get student predictions (forward pass)
2. Get teacher predictions (from teacher model or cache)
3. Compute distillation loss (how different student is from teacher)
4. Compute gradients (how to improve student)
5. Update student weights (backward pass)
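Steps 3 and 4 depend on the configured IDistillationStrategy<T>, whose formulas are not given here. As one common possibility, the sketch below uses the classic temperature-softened KL divergence between teacher and student distributions, scaled by the squared temperature. The class name, method names, and default temperature are assumptions for the illustration, not the library's API.

```csharp
using System;

public static class DistillationLossSketch
{
    // Numerically stable softmax of logits / temperature.
    public static double[] Softmax(double[] logits, double temperature)
    {
        double max = double.NegativeInfinity;
        foreach (double z in logits) max = Math.Max(max, z);

        var probs = new double[logits.Length];
        double sum = 0.0;
        for (int i = 0; i < logits.Length; i++)
        {
            probs[i] = Math.Exp((logits[i] - max) / temperature);
            sum += probs[i];
        }
        for (int i = 0; i < probs.Length; i++) probs[i] /= sum;
        return probs;
    }

    // T^2 * KL(teacher || student) on temperature-softened distributions.
    public static double SoftLoss(double[] studentLogits, double[] teacherLogits, double temperature = 3.0)
    {
        double[] p = Softmax(teacherLogits, temperature); // teacher (target)
        double[] q = Softmax(studentLogits, temperature); // student
        double kl = 0.0;
        for (int i = 0; i < p.Length; i++)
        {
            if (p[i] > 0.0) kl += p[i] * Math.Log(p[i] / q[i]);
        }
        return kl * temperature * temperature;
    }

    // Gradient of SoftLoss with respect to the student logits: T * (q - p).
    public static double[] SoftLossGradient(double[] studentLogits, double[] teacherLogits, double temperature = 3.0)
    {
        double[] p = Softmax(teacherLogits, temperature);
        double[] q = Softmax(studentLogits, temperature);
        var grad = new double[q.Length];
        for (int i = 0; i < q.Length; i++) grad[i] = temperature * (q[i] - p[i]);
        return grad;
    }
}
```

In a loop shaped like TrainBatch, the gradient array would play the role of the value handed to studentBackward, and the loss values would be averaged over the batch. The loss is zero when student and teacher logits match and positive otherwise.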