Class KnowledgeDistillationTrainerBase<T, TInput, TOutput>

Namespace
AiDotNet.KnowledgeDistillation
Assembly
AiDotNet.dll

Abstract base class for all knowledge distillation trainers. Provides common functionality for training loops, data shuffling, validation, and evaluation.

public abstract class KnowledgeDistillationTrainerBase<T, TInput, TOutput> : IKnowledgeDistillationTrainer<T, TInput, TOutput>

Type Parameters

T

The numeric type for calculations (e.g., double, float).

TInput

The input data type.

TOutput

The output data type.

Inheritance
KnowledgeDistillationTrainerBase<T, TInput, TOutput>
Implements
IKnowledgeDistillationTrainer<T, TInput, TOutput>

Remarks

For Beginners: This base class implements the common training workflow shared by all distillation trainers. Specific trainer types (standard, self-distillation, online, etc.) inherit from this and customize only the parts that differ.

Design Pattern: Template Method Pattern - the base class defines the training algorithm structure, and derived classes fill in specific steps.

Common Functionality Provided:
- Data shuffling using Fisher-Yates algorithm (O(n) efficiency)
- Epoch and batch management
- Validation after each epoch
- Progress callbacks
- Evaluation metrics (accuracy, loss)
- Teacher and strategy property management

Derived Classes Override:
- GetTeacherPredictions(): How to obtain teacher outputs for training
- OnEpochStart(): Custom logic before each epoch
- OnEpochEnd(): Custom logic after each epoch
- OnTrainingStart(): Custom logic before training begins
- OnTrainingEnd(): Custom logic after training completes
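The Template Method structure described above can be sketched as follows. This is a minimal, self-contained illustration and not the library's implementation: `TrainerSketch`, `CachedTeacherTrainerSketch`, and the string-based `Log` are hypothetical stand-ins, and the hook signatures are simplified compared with the real members documented on this page.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for the base class: it fixes the order of the steps.
public abstract class TrainerSketch
{
    public List<string> Log { get; } = new List<string>();

    // The template method: derived classes cannot reorder these calls.
    public void Train(int epochs)
    {
        OnTrainingStart();
        for (int epoch = 0; epoch < epochs; epoch++)
        {
            OnEpochStart(epoch);
            Log.Add($"teacher:{GetTeacherPrediction(epoch)}");
            OnEpochEnd(epoch);
        }
        OnTrainingEnd();
    }

    // Required step: every trainer must say where teacher outputs come from.
    protected abstract string GetTeacherPrediction(int index);

    // Optional hooks with default behavior.
    protected virtual void OnTrainingStart() => Log.Add("start");
    protected virtual void OnEpochStart(int epoch) { }
    protected virtual void OnEpochEnd(int epoch) => Log.Add($"epoch-end:{epoch}");
    protected virtual void OnTrainingEnd() => Log.Add("end");
}

// A derived trainer overrides only the parts that differ.
public sealed class CachedTeacherTrainerSketch : TrainerSketch
{
    protected override string GetTeacherPrediction(int index) => $"cached[{index}]";

    protected override void OnEpochEnd(int epoch)
    {
        base.OnEpochEnd(epoch); // keep the base behavior, then extend it
        Log.Add("custom-logging");
    }
}
```

The abstract member is the one step every derived class has to supply, while the virtual hooks have defaults, so a subclass stays small.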

Constructors

KnowledgeDistillationTrainerBase(ITeacherModel<TInput, TOutput>, IDistillationStrategy<T>, DistillationCheckpointConfig?, bool, double, int, int?)

Initializes a new instance of the KnowledgeDistillationTrainerBase class.

protected KnowledgeDistillationTrainerBase(ITeacherModel<TInput, TOutput> teacher, IDistillationStrategy<T> distillationStrategy, DistillationCheckpointConfig? checkpointConfig = null, bool useEarlyStopping = false, double earlyStoppingMinDelta = 0.001, int earlyStoppingPatience = 10, int? seed = null)

Parameters

teacher ITeacherModel<TInput, TOutput>

The teacher model.

distillationStrategy IDistillationStrategy<T>

The distillation strategy.

checkpointConfig DistillationCheckpointConfig

Optional checkpoint configuration for automatic model saving during training.

useEarlyStopping bool

Enable early stopping based on validation loss.

earlyStoppingMinDelta double

Minimum improvement required to count as progress.

earlyStoppingPatience int

Number of epochs without improvement before stopping.

seed int?

Optional random seed for reproducibility.

Remarks

For Beginners: The teacher and strategy are the core components:
- Teacher: Provides the "expert" knowledge to transfer
- Strategy: Defines how to measure and optimize the knowledge transfer

Automatic Checkpointing: To enable automatic checkpointing, pass a DistillationCheckpointConfig instance. If null (default), no automatic checkpointing occurs. When enabled, the trainer will automatically:
- Save checkpoints based on your configuration (e.g., every 5 epochs)
- Keep only the best N checkpoints to save disk space
- Load the best checkpoint after training completes

Example with Checkpointing:

var config = new DistillationCheckpointConfig
{
    SaveEveryEpochs = 5,
    KeepBestN = 3
};
var trainer = new KnowledgeDistillationTrainer(teacher, strategy, checkpointConfig: config);
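The early-stopping parameters (earlyStoppingMinDelta and earlyStoppingPatience) follow a common patience-based scheme, which might look like the sketch below. `EarlyStoppingSketch` and `ShouldStop` are hypothetical names introduced for illustration; the trainer's internal logic is not shown on this page and may differ.

```csharp
using System;

// Hypothetical helper illustrating patience-based early stopping.
public sealed class EarlyStoppingSketch
{
    private readonly double _minDelta;
    private readonly int _patience;
    private double _bestLoss = double.PositiveInfinity;
    private int _epochsWithoutImprovement;

    public EarlyStoppingSketch(double minDelta = 0.001, int patience = 10)
    {
        _minDelta = minDelta;
        _patience = patience;
    }

    // Call once per epoch with the validation loss; returns true when training should stop.
    public bool ShouldStop(double validationLoss)
    {
        if (validationLoss < _bestLoss - _minDelta)
        {
            // Improvement large enough to count as progress.
            _bestLoss = validationLoss;
            _epochsWithoutImprovement = 0;
        }
        else
        {
            _epochsWithoutImprovement++;
        }
        return _epochsWithoutImprovement >= _patience;
    }
}
```

A larger minimum delta ignores small fluctuations in the loss, while a larger patience tolerates longer plateaus before stopping.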

Fields

NumOps

Gets the numeric operations helper for the type T.

protected readonly INumericOperations<T> NumOps

Field Value

INumericOperations<T>

Random

Gets the random number generator for data shuffling.

protected readonly Random Random

Field Value

Random

Properties

DistillationStrategy

Gets the distillation strategy for computing loss and gradients.

public IDistillationStrategy<T> DistillationStrategy { get; protected set; }

Property Value

IDistillationStrategy<T>

Teacher

Gets the teacher model used for distillation.

public ITeacherModel<TInput, TOutput> Teacher { get; protected set; }

Property Value

ITeacherModel<TInput, TOutput>

Methods

ArgMax(Vector<T>)

Finds the index of the maximum value in a vector (argmax).

protected int ArgMax(Vector<T> vector)

Parameters

vector Vector<T>

The vector to search.

Returns

int

Index of the maximum value.
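A minimal sketch of an argmax over a plain array is shown below. It uses `double[]` instead of `Vector<T>` so that it stays self-contained. The tie-breaking rule (first occurrence wins) and the exception on empty input are assumptions, since this page does not specify either behavior.

```csharp
using System;

public static class ArgMaxSketch
{
    // Returns the index of the largest element; ties resolve to the first occurrence.
    public static int ArgMax(double[] values)
    {
        if (values == null || values.Length == 0)
            throw new ArgumentException("Vector must contain at least one element.", nameof(values));

        int best = 0;
        for (int i = 1; i < values.Length; i++)
        {
            if (values[i] > values[best]) best = i;
        }
        return best;
    }
}
```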

Evaluate(Func<TInput, TOutput>, Vector<TInput>, Vector<TOutput>)

Evaluates the student model's accuracy on a dataset.

public virtual double Evaluate(Func<TInput, TOutput> studentForward, Vector<TInput> inputs, Vector<TOutput> trueLabels)

Parameters

studentForward Func<TInput, TOutput>

Function to perform forward pass through student model.

inputs Vector<TInput>

Evaluation input data.

trueLabels Vector<TOutput>

True labels for evaluation.

Returns

double

Accuracy as a fraction (0-1).

Remarks

For Beginners: Evaluation measures how well the student performs:
- For each input, get student's prediction
- Compare with true label (argmax for classification)
- Calculate fraction of correct predictions

This is used to monitor training progress and detect overfitting.
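The evaluation steps above can be sketched with plain arrays. `EvaluateSketch.Accuracy` is a hypothetical helper, not the library method. It assumes one-hot (or score-vector) labels, and it returns 0 for an empty dataset, which is a choice made here rather than documented behavior.

```csharp
using System;

public static class EvaluateSketch
{
    private static int ArgMax(double[] v)
    {
        int best = 0;
        for (int i = 1; i < v.Length; i++)
            if (v[i] > v[best]) best = i;
        return best;
    }

    // Fraction of samples whose predicted class matches the labeled class.
    public static double Accuracy(
        Func<double[], double[]> studentForward,
        double[][] inputs,
        double[][] oneHotLabels)
    {
        if (inputs.Length != oneHotLabels.Length)
            throw new ArgumentException("Inputs and labels must have the same length.");
        if (inputs.Length == 0) return 0.0;

        int correct = 0;
        for (int i = 0; i < inputs.Length; i++)
        {
            double[] prediction = studentForward(inputs[i]);
            if (ArgMax(prediction) == ArgMax(oneHotLabels[i])) correct++;
        }
        return (double)correct / inputs.Length;
    }
}
```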

FisherYatesShuffle(int)

Generates a random permutation of indices using Fisher-Yates shuffle.

protected int[] FisherYatesShuffle(int length)

Parameters

length int

The length of the array to shuffle.

Returns

int[]

An array of shuffled indices.
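For reference, a Fisher-Yates permutation of indices can be written as below. This sketch takes the `Random` instance as a parameter so that it is self-contained, whereas the documented method takes only the length and presumably uses the protected `Random` field. `ShuffleSketch` is a hypothetical name.

```csharp
using System;

public static class ShuffleSketch
{
    // Returns a uniformly random permutation of 0..length-1 in O(n) time.
    public static int[] FisherYatesShuffle(int length, Random random)
    {
        var indices = new int[length];
        for (int i = 0; i < length; i++) indices[i] = i;

        // Walk backwards, swapping each slot with a random slot at or before it.
        for (int i = length - 1; i > 0; i--)
        {
            int j = random.Next(i + 1); // 0 <= j <= i
            (indices[i], indices[j]) = (indices[j], indices[i]);
        }
        return indices;
    }
}
```

Passing a seeded `Random` makes the permutation reproducible, which is what the constructor's seed parameter is for.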

GetTeacherPredictions(TInput, int)

Gets teacher predictions for a given input. Abstract method to be implemented by derived classes.

protected abstract TOutput GetTeacherPredictions(TInput input, int index)

Parameters

input TInput

The input data.

index int

The index of this input in the training batch (useful for caching).

Returns

TOutput

Teacher's output predictions.

Remarks

For Derived Classes:
- Standard trainers: Call teacher.GetLogits(input)
- Self-distillation: Return cached predictions from previous generation
- Online distillation: Get predictions from dynamically updated teacher
- Ensemble: Combine predictions from multiple teachers

IsCorrectPrediction(TOutput, TOutput)

Determines if a prediction matches the true label.

protected virtual bool IsCorrectPrediction(TOutput prediction, TOutput trueLabel)

Parameters

prediction TOutput

Student's prediction.

trueLabel TOutput

True label.

Returns

bool

True if prediction is correct.

Remarks

For Classification: Uses argmax to find predicted class and compares with true class. The conversion to Vector<T> is handled by ConversionsHelper, so this works for any TOutput type (Vector<T>, Tensor<T>, T[], or scalar T).

Override This: If you need different evaluation criteria (e.g., regression).
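As an illustration of a different evaluation criterion, a regression-style check might treat a prediction as correct when it falls within a tolerance of the target. The sketch below is a standalone helper over scalars, not an actual override. `RegressionCorrectnessSketch` and the default tolerance of 0.05 are assumptions made for the example.

```csharp
using System;

public static class RegressionCorrectnessSketch
{
    // A prediction counts as correct when it lies within `tolerance` of the true value.
    public static bool IsCorrectPrediction(double prediction, double trueValue, double tolerance = 0.05)
    {
        if (tolerance < 0)
            throw new ArgumentOutOfRangeException(nameof(tolerance), "Tolerance must be non-negative.");
        return Math.Abs(prediction - trueValue) <= tolerance;
    }
}
```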

OnEpochEnd(int, T)

Called after each epoch completes. Override for custom per-epoch cleanup/logging.

protected virtual void OnEpochEnd(int epoch, T avgLoss)

Parameters

epoch int

Current epoch number (0-indexed).

avgLoss T

Average loss for this epoch.

Remarks

Use Cases:
- Log epoch metrics
- Save checkpoints
- Update adaptive parameters
- Implement early stopping logic

IMPORTANT: The base implementation calls Reset() on RelationalDistillationStrategy to flush partial batches and prevent buffer leakage between epochs. Derived classes that override this method should call base.OnEpochEnd(epoch, avgLoss).

Automatic Checkpointing: If checkpoint configuration was provided to the constructor, this method automatically saves checkpoints based on your configuration.

OnEpochStart(int, Vector<TInput>, Vector<TOutput>?)

Called before each epoch starts. Override for custom per-epoch initialization.

protected virtual void OnEpochStart(int epoch, Vector<TInput> trainInputs, Vector<TOutput>? trainLabels)

Parameters

epoch int

Current epoch number (0-indexed).

trainInputs Vector<TInput>

Training inputs.

trainLabels Vector<TOutput>

Training labels.

Remarks

Use Cases:
- Update learning rate schedules
- Adjust temperature schedules
- Update curriculum difficulty
- Refresh teacher in online distillation

OnTrainingEnd(Vector<TInput>, Vector<TOutput>?)

Called after training completes. Override for custom cleanup logic.

protected virtual void OnTrainingEnd(Vector<TInput> trainInputs, Vector<TOutput>? trainLabels)

Parameters

trainInputs Vector<TInput>

Training inputs.

trainLabels Vector<TOutput>

Training labels.

Remarks

Use Cases:
- Clear caches
- Save final checkpoints
- Log final metrics
- Free temporary resources

Automatic Checkpointing: If checkpoint configuration was provided to the constructor, this method automatically loads the best checkpoint (based on validation metrics) after training completes.

OnTrainingStart(Vector<TInput>, Vector<TOutput>?)

Called before training starts. Override for custom initialization logic.

protected virtual void OnTrainingStart(Vector<TInput> trainInputs, Vector<TOutput>? trainLabels)

Parameters

trainInputs Vector<TInput>

Training inputs.

trainLabels Vector<TOutput>

Training labels.

Remarks

Use Cases:
- Cache teacher predictions (for efficiency)
- Initialize EMA buffers (for self-distillation)
- Setup curriculum schedules
- Allocate temporary buffers

Automatic Checkpointing: If checkpoint configuration was provided to the constructor, this method automatically initializes the checkpoint manager.

OnValidationComplete(int, double)

Called after validation completes for an epoch. Override for custom validation handling.

protected virtual void OnValidationComplete(int epoch, double accuracy)

Parameters

epoch int

Current epoch number (0-indexed).

accuracy double

Validation accuracy (0-100).

Remarks

Use Cases:
- Log validation metrics
- Implement early stopping
- Track best model
- Adjust hyperparameters based on validation performance

Automatic Checkpointing: If CheckpointConfig is set, this method automatically tracks validation metrics for best checkpoint selection.

ShuffleData(Vector<TInput>, Vector<TOutput>?)

Shuffles training data using Fisher-Yates algorithm.

protected virtual (Vector<TInput> shuffledInputs, Vector<TOutput>? shuffledLabels) ShuffleData(Vector<TInput> inputs, Vector<TOutput>? labels)

Parameters

inputs Vector<TInput>

Input data to shuffle.

labels Vector<TOutput>

Labels to shuffle (maintains alignment with inputs).

Returns

(Vector<TInput> shuffledInputs, Vector<TOutput> shuffledLabels)

Tuple of shuffled inputs and labels.

Remarks

For Beginners: Data shuffling is important because:
- Prevents model from learning order-dependent patterns
- Improves gradient descent convergence
- Reduces overfitting to batch ordering

Fisher-Yates is O(n) compared to O(n log n) for sort-based shuffling.
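Keeping inputs and labels aligned comes down to applying the same permutation to both. The following sketch shows one way to do that with plain arrays. `AlignedShuffleSketch` is a hypothetical helper, and passing `null` for the labels mirrors the optional labels parameter of the documented method.

```csharp
using System;

public static class AlignedShuffleSketch
{
    // Shuffles inputs and (optionally) labels with one shared permutation.
    // `labels` may be null, in which case the returned labels are null too.
    public static (TIn[] inputs, TOut[] labels) Shuffle<TIn, TOut>(
        TIn[] inputs, TOut[] labels, Random random)
    {
        if (labels != null && labels.Length != inputs.Length)
            throw new ArgumentException("Inputs and labels must have the same length.");

        int n = inputs.Length;
        var order = new int[n];
        for (int i = 0; i < n; i++) order[i] = i;
        for (int i = n - 1; i > 0; i--)
        {
            int j = random.Next(i + 1);
            (order[i], order[j]) = (order[j], order[i]);
        }

        var shuffledInputs = new TIn[n];
        TOut[] shuffledLabels = labels == null ? null : new TOut[n];
        for (int k = 0; k < n; k++)
        {
            shuffledInputs[k] = inputs[order[k]];
            if (shuffledLabels != null) shuffledLabels[k] = labels[order[k]];
        }
        return (shuffledInputs, shuffledLabels);
    }
}
```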

Train(Func<TInput, TOutput>, Action<TOutput>, Vector<TInput>, Vector<TOutput>?, int, int, Vector<TInput>?, Vector<TOutput>?, ICheckpointableModel?, Action<int, T>?)

Trains the student model for multiple epochs using knowledge distillation.

public virtual void Train(Func<TInput, TOutput> studentForward, Action<TOutput> studentBackward, Vector<TInput> trainInputs, Vector<TOutput>? trainLabels = null, int epochs = 20, int batchSize = 32, Vector<TInput>? validationInputs = null, Vector<TOutput>? validationLabels = null, ICheckpointableModel? student = null, Action<int, T>? onEpochComplete = null)

Parameters

studentForward Func<TInput, TOutput>

Function to perform forward pass through student model.

studentBackward Action<TOutput>

Function to perform backward pass and update weights.

trainInputs Vector<TInput>

Training input data.

trainLabels Vector<TOutput>

Training labels (optional for pure soft distillation).

epochs int

Number of training epochs.

batchSize int

Batch size for mini-batch training.

validationInputs Vector<TInput>

Optional validation inputs for monitoring.

validationLabels Vector<TOutput>

Optional validation labels.

student ICheckpointableModel

Optional student model for automatic checkpointing (must implement ICheckpointableModel).

onEpochComplete Action<int, T>

Optional callback invoked after each epoch with (epoch, avgLoss).

Remarks

For Beginners: This method orchestrates the complete training process:
1. Initialize (OnTrainingStart)
2. For each epoch:
   a. Shuffle training data
   b. Process data in batches
   c. Validate if requested
   d. Invoke callbacks
3. Cleanup (OnTrainingEnd)

Training Tips:
- Use batch sizes that fit in memory (32-128 typical)
- Monitor validation loss to detect overfitting
- Invoke callbacks to log progress or save checkpoints

Automatic Checkpointing: If a checkpoint configuration was provided to the constructor and you pass the student model parameter, automatic checkpointing will be enabled.
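The epoch and batch structure of the training process can be sketched roughly as below. This is an illustration under simplifying assumptions, not the library's loop: `TrainingLoopSketch` is hypothetical, the per-batch work is abstracted into a `trainBatch` delegate that returns the batch's average loss, and validation, checkpointing, and the lifecycle hooks are omitted.

```csharp
using System;
using System.Collections.Generic;

public static class TrainingLoopSketch
{
    // Runs `epochs` passes over the data and returns the average loss per epoch.
    public static List<double> Train(
        Func<double[][], double> trainBatch, // hypothetical: trains on one batch, returns its loss
        double[][] trainInputs,
        int epochs = 20,
        int batchSize = 32,
        Action<int, double> onEpochComplete = null,
        int? seed = null)
    {
        if (batchSize <= 0) throw new ArgumentOutOfRangeException(nameof(batchSize));

        var random = seed.HasValue ? new Random(seed.Value) : new Random();
        var history = new List<double>();
        int n = trainInputs.Length;

        for (int epoch = 0; epoch < epochs; epoch++)
        {
            // a. Shuffle: build a fresh Fisher-Yates permutation each epoch.
            var order = new int[n];
            for (int i = 0; i < n; i++) order[i] = i;
            for (int i = n - 1; i > 0; i--)
            {
                int j = random.Next(i + 1);
                (order[i], order[j]) = (order[j], order[i]);
            }

            // b. Process the data in batches; the last batch may be smaller.
            double lossSum = 0.0;
            int batches = 0;
            for (int start = 0; start < n; start += batchSize)
            {
                int count = Math.Min(batchSize, n - start);
                var batch = new double[count][];
                for (int k = 0; k < count; k++) batch[k] = trainInputs[order[start + k]];
                lossSum += trainBatch(batch);
                batches++;
            }

            double avgLoss = batches > 0 ? lossSum / batches : 0.0;
            history.Add(avgLoss);

            // d. Report progress with (epoch, avgLoss), epoch being 0-indexed.
            onEpochComplete?.Invoke(epoch, avgLoss);
        }
        return history;
    }
}
```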

Train(Func<TInput, TOutput>, Action<TOutput>, Vector<TInput>, Vector<TOutput>?, int, int, Action<int, T>?)

Trains the student model for multiple epochs (interface-compliant overload).

public virtual void Train(Func<TInput, TOutput> studentForward, Action<TOutput> studentBackward, Vector<TInput> trainInputs, Vector<TOutput>? trainLabels, int epochs, int batchSize = 32, Action<int, T>? onEpochComplete = null)

Parameters

studentForward Func<TInput, TOutput>

Function to perform forward pass on student model.

studentBackward Action<TOutput>

Function to perform backward pass on student model.

trainInputs Vector<TInput>

Training input data.

trainLabels Vector<TOutput>

Training labels (optional). If null, uses only teacher supervision.

epochs int

Number of epochs to train.

batchSize int

Batch size for training.

onEpochComplete Action<int, T>

Optional callback invoked after each epoch with (epoch, avgLoss).

Remarks

For Beginners: This overload matches the interface contract and delegates to the extended version with no validation data.

TrainBatch(Func<TInput, TOutput>, Action<TOutput>, Vector<TInput>, Vector<TOutput>?)

Trains the student model on a single batch using knowledge distillation.

public virtual T TrainBatch(Func<TInput, TOutput> studentForward, Action<TOutput> studentBackward, Vector<TInput> inputs, Vector<TOutput>? trueLabels = null)

Parameters

studentForward Func<TInput, TOutput>

Function to perform forward pass through student model.

studentBackward Action<TOutput>

Function to perform backward pass and update weights.

inputs Vector<TInput>

Batch of input data.

trueLabels Vector<TOutput>

Optional true labels for the batch.

Returns

T

Average loss for the batch.

Remarks

For Beginners: Training a batch involves:
1. Get student predictions (forward pass)
2. Get teacher predictions (from teacher model or cache)
3. Compute distillation loss (how different student is from teacher)
4. Compute gradients (how to improve student)
5. Update student weights (backward pass)
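One common way to measure "how different student is from teacher" is the KL divergence between temperature-softened softmax outputs. The sketch below illustrates that idea only. The actual loss is defined by the configured IDistillationStrategy<T> and may differ, and `DistillationLossSketch`, the default temperature of 2.0, and the absence of any temperature-squared scaling are assumptions made for the example.

```csharp
using System;

public static class DistillationLossSketch
{
    // Numerically stable softmax of logits divided by a temperature.
    public static double[] Softmax(double[] logits, double temperature)
    {
        if (temperature <= 0) throw new ArgumentOutOfRangeException(nameof(temperature));

        double max = double.NegativeInfinity;
        foreach (double l in logits) max = Math.Max(max, l / temperature);

        var probs = new double[logits.Length];
        double sum = 0.0;
        for (int i = 0; i < logits.Length; i++)
        {
            probs[i] = Math.Exp(logits[i] / temperature - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.Length; i++) probs[i] /= sum;
        return probs;
    }

    // KL(teacher || student) on the softened distributions; 0 when they match.
    public static double Loss(double[] studentLogits, double[] teacherLogits, double temperature = 2.0)
    {
        if (studentLogits.Length != teacherLogits.Length)
            throw new ArgumentException("Student and teacher logits must have the same length.");

        double[] p = Softmax(teacherLogits, temperature);
        double[] q = Softmax(studentLogits, temperature);

        double kl = 0.0;
        for (int i = 0; i < p.Length; i++)
        {
            // Skip zero-probability teacher terms; clamp q to avoid division by zero.
            if (p[i] > 0) kl += p[i] * Math.Log(p[i] / Math.Max(q[i], 1e-12));
        }
        return kl;
    }
}
```

A higher temperature flattens both distributions, so the student also learns from the teacher's relative preferences among the non-top classes.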