
Interface IKnowledgeDistillationTrainer<T, TInput, TOutput>

Namespace
AiDotNet.Interfaces
Assembly
AiDotNet.dll

Defines the contract for knowledge distillation trainers that train student models using knowledge transferred from teacher models.

public interface IKnowledgeDistillationTrainer<T, TInput, TOutput>

Type Parameters

T

The numeric type for calculations (e.g., double, float).

TInput

The input data type (e.g., Matrix<T> for tabular data, Tensor<T> for images).

TOutput

The output data type (e.g., Matrix<T> for batch outputs, Tensor<T> for structured outputs).

Remarks

For Beginners: A knowledge distillation trainer orchestrates the process of transferring knowledge from a large, accurate teacher model to a smaller, faster student model.

Why an interface?
- **Flexibility**: Multiple trainer implementations (standard, self-distillation, multi-teacher, etc.) can share one contract.
- **Testability**: Trainers are easy to mock in unit tests.
- **Extensibility**: New training strategies can be added without breaking existing code.
- **Dependency Injection**: A trainer can be injected into other components.

Common Implementations:
- **Standard Trainer**: Single teacher → single student
- **Self-Distillation Trainer**: Model teaches itself (improves calibration)
- **Multi-Teacher Trainer**: Multiple teachers → one student (ensemble distillation)
- **Online Trainer**: Teacher updates during student training
- **Mutual Learning Trainer**: Multiple students learn from each other
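For the multi-teacher case, one common way to turn several teachers into a single set of soft targets is to average their temperature-softened class probabilities. The following is a minimal sketch under assumptions: it works on plain double[] logits, weights every teacher equally, and EnsembleTargets, SoftmaxWithTemperature, and Average are hypothetical names, not AiDotNet APIs. A real multi-teacher trainer may weight or combine teachers differently.

```csharp
using System;
using System.Linq;

// Hypothetical helper (not part of AiDotNet): builds ensemble soft targets
// by uniformly averaging each teacher's temperature-softened probabilities.
public static class EnsembleTargets
{
    // Softmax of logits / temperature; subtracting the max keeps Math.Exp stable.
    public static double[] SoftmaxWithTemperature(double[] logits, double temperature)
    {
        double max = logits.Max();
        double[] exps = logits.Select(z => Math.Exp((z - max) / temperature)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }

    // teacherLogits[k] holds teacher k's logits for the same input.
    public static double[] Average(double[][] teacherLogits, double temperature)
    {
        int classes = teacherLogits[0].Length;
        var avg = new double[classes];
        foreach (double[] logits in teacherLogits)
        {
            double[] probs = SoftmaxWithTemperature(logits, temperature);
            for (int c = 0; c < classes; c++)
                avg[c] += probs[c] / teacherLogits.Length;
        }
        return avg;
    }
}
```

Because each teacher's softened output is a probability distribution, their uniform average is one too, so it can be used directly as the student's soft target.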

Properties

DistillationStrategy

Gets the distillation strategy used for computing loss and gradients.
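A distillation strategy typically supplies both a loss value and the gradient the student backpropagates. As one illustration, assuming the classic soft-target formulation (KL divergence from the temperature-softened teacher distribution to the student's, scaled by the squared temperature), the loss and its gradient with respect to the student's logits can be sketched as below. SoftTargetLoss and its methods are hypothetical; this is not AiDotNet's IDistillationStrategy<T> implementation, and a given strategy may use a different loss.

```csharp
using System;

// Hypothetical sketch: loss = temperature^2 * KL(p_teacher || p_student),
// where both distributions are softmax(logits / temperature).
public static class SoftTargetLoss
{
    static double[] Softmax(double[] logits, double temperature)
    {
        double max = double.NegativeInfinity;
        foreach (double z in logits) max = Math.Max(max, z);
        var p = new double[logits.Length];
        double sum = 0;
        for (int i = 0; i < logits.Length; i++)
        {
            p[i] = Math.Exp((logits[i] - max) / temperature); // max-shift for stability
            sum += p[i];
        }
        for (int i = 0; i < p.Length; i++) p[i] /= sum;
        return p;
    }

    public static double Loss(double[] studentLogits, double[] teacherLogits, double temperature)
    {
        double[] ps = Softmax(studentLogits, temperature);
        double[] pt = Softmax(teacherLogits, temperature);
        double kl = 0;
        for (int i = 0; i < ps.Length; i++)
            if (pt[i] > 0) kl += pt[i] * Math.Log(pt[i] / ps[i]);
        return temperature * temperature * kl;
    }

    // Gradient of Loss with respect to the student's logits:
    // temperature * (p_student - p_teacher).
    public static double[] Gradient(double[] studentLogits, double[] teacherLogits, double temperature)
    {
        double[] ps = Softmax(studentLogits, temperature);
        double[] pt = Softmax(teacherLogits, temperature);
        var grad = new double[ps.Length];
        for (int i = 0; i < ps.Length; i++)
            grad[i] = temperature * (ps[i] - pt[i]);
        return grad;
    }
}
```

Scaling by the squared temperature keeps the gradient magnitude comparable across temperatures, which matters when this soft loss is blended with an ordinary hard-label loss.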

IDistillationStrategy<T> DistillationStrategy { get; }

Property Value

IDistillationStrategy<T>

Teacher

Gets the teacher model used for distillation.

ITeacherModel<TInput, TOutput> Teacher { get; }

Property Value

ITeacherModel<TInput, TOutput>

Methods

Evaluate(Func<TInput, TOutput>, Vector<TInput>, Vector<TOutput>)

Evaluates the student model's accuracy on test data.

double Evaluate(Func<TInput, TOutput> studentForward, Vector<TInput> testInputs, Vector<TOutput> testLabels)

Parameters

studentForward Func<TInput, TOutput>

Function that performs a forward pass through the student model.

testInputs Vector<TInput>

Test input data.

testLabels Vector<TOutput>

Test labels (one-hot encoded).

Returns

double

Classification accuracy as a percentage (0-100).

Remarks

For Beginners: This measures how well the student has learned: the percentage of test examples it classifies correctly, from 0 to 100, where 100.0 means every prediction is correct.
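Because the labels are one-hot encoded, classification accuracy is usually computed by comparing the index of the student's largest output with the index of the 1 in each label. The following is a minimal sketch assuming predictions and labels are plain double[] arrays; AccuracyMetric and Percent are hypothetical names, and the actual Evaluate implementation may differ (for example, in tie-breaking or in how it treats an empty test set).

```csharp
using System;

// Hypothetical helper: percentage of examples whose predicted class
// (argmax of the student output) equals the label's hot index.
public static class AccuracyMetric
{
    // Returns the first index of the maximum value (ties go to the lower index).
    static int ArgMax(double[] v)
    {
        int best = 0;
        for (int i = 1; i < v.Length; i++)
            if (v[i] > v[best]) best = i;
        return best;
    }

    public static double Percent(Func<double[], double[]> studentForward,
                                 double[][] testInputs, double[][] oneHotLabels)
    {
        if (testInputs.Length != oneHotLabels.Length)
            throw new ArgumentException("inputs and labels must have the same count");
        if (testInputs.Length == 0)
            throw new ArgumentException("test set is empty"); // assumption of this sketch
        int correct = 0;
        for (int n = 0; n < testInputs.Length; n++)
            if (ArgMax(studentForward(testInputs[n])) == ArgMax(oneHotLabels[n]))
                correct++;
        return 100.0 * correct / testInputs.Length;
    }
}
```

For example, three correct predictions out of four test examples give 75.0.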

Train(Func<TInput, TOutput>, Action<TOutput>, Vector<TInput>, Vector<TOutput>?, int, int, Action<int, T>?)

Trains the student model for multiple epochs.

void Train(Func<TInput, TOutput> studentForward, Action<TOutput> studentBackward, Vector<TInput> trainInputs, Vector<TOutput>? trainLabels, int epochs, int batchSize = 32, Action<int, T>? onEpochComplete = null)

Parameters

studentForward Func<TInput, TOutput>

Function that performs a forward pass through the student model.

studentBackward Action<TOutput>

Function that performs a backward pass through the student model; it receives the gradient of the loss with respect to the student's output.

trainInputs Vector<TInput>

Training input data.

trainLabels Vector<TOutput>

Training labels (optional). If null, the student is trained with teacher supervision only.

epochs int

Number of epochs to train.

batchSize int

Number of examples per training batch (default: 32).

onEpochComplete Action<int, T>

Optional callback invoked after each epoch with (epoch, avgLoss).

Remarks

For Beginners: An epoch is one complete pass through the training data. This method runs the complete training process for the specified number of epochs.
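The epoch loop can be sketched as follows, assuming batches are taken in order without shuffling, the epoch index passed to the callback starts at 0, and the reported loss is the unweighted mean of the per-batch losses. EpochLoop.Run and its trainBatch delegate (which stands in for TrainBatch) are hypothetical; AiDotNet's Train may shuffle the data, number epochs differently, or weight the average by batch size.

```csharp
using System;

// Hypothetical sketch of a Train(...)-style loop: split the data into
// mini-batches of at most batchSize examples, train on each, then report
// the epoch's average batch loss through the optional callback.
public static class EpochLoop
{
    public static void Run<TIn>(
        TIn[] trainInputs,
        int epochs,
        int batchSize,
        Func<TIn[], double> trainBatch,           // returns the batch's average loss
        Action<int, double>? onEpochComplete = null)
    {
        if (batchSize <= 0) throw new ArgumentOutOfRangeException(nameof(batchSize));
        for (int epoch = 0; epoch < epochs; epoch++)
        {
            double lossSum = 0;
            int batches = 0;
            for (int start = 0; start < trainInputs.Length; start += batchSize)
            {
                // The last batch may be smaller than batchSize.
                int size = Math.Min(batchSize, trainInputs.Length - start);
                var batch = new TIn[size];
                Array.Copy(trainInputs, start, batch, 0, size);
                lossSum += trainBatch(batch);
                batches++;
            }
            onEpochComplete?.Invoke(epoch, batches == 0 ? 0.0 : lossSum / batches);
        }
    }
}
```

With 10 examples and a batch size of 4, each epoch runs batches of 4, 4, and 2 examples.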

TrainBatch(Func<TInput, TOutput>, Action<TOutput>, Vector<TInput>, Vector<TOutput>?)

Trains the student model on a single batch of data.

T TrainBatch(Func<TInput, TOutput> studentForward, Action<TOutput> studentBackward, Vector<TInput> inputs, Vector<TOutput>? trueLabels = null)

Parameters

studentForward Func<TInput, TOutput>

Function that performs a forward pass through the student model.

studentBackward Action<TOutput>

Function that performs a backward pass through the student model; it receives the gradient of the loss with respect to the student's output.

inputs Vector<TInput>

Input batch.

trueLabels Vector<TOutput>

Ground-truth labels (optional). If null, the student is trained with teacher supervision only.

Returns

T

Average loss for the batch.

Remarks

For Beginners: This method processes one batch of training data:
1. Get predictions from both the teacher and the student.
2. Compute the distillation loss.
3. Update the student's weights via backpropagation.
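The three steps of a batch update can be sketched with plain arrays. This is an illustration under stated assumptions, not AiDotNet's implementation: a mean-squared logit-matching loss stands in for the configured IDistillationStrategy<T>, the hypothetical alpha parameter balances teacher and ground-truth supervision when labels are supplied, and studentBackward is called once per example with the gradient of that example's loss with respect to the student's output. BatchStep.Run and the teacherForward delegate are hypothetical names.

```csharp
using System;

// Hypothetical sketch of one distillation batch step. A weighted MSE loss
// stands in for the real distillation strategy.
public static class BatchStep
{
    public static double Run(
        Func<double[], double[]> studentForward,
        Action<double[]> studentBackward,        // receives dLoss/dOutput for one example
        Func<double[], double[]> teacherForward,
        double[][] inputs,
        double[][]? trueLabels = null,
        double alpha = 0.5)                      // teacher weight when labels are given
    {
        double total = 0;
        for (int n = 0; n < inputs.Length; n++)
        {
            // 1. Predictions from both teacher and student.
            double[] s = studentForward(inputs[n]);
            double[] t = teacherForward(inputs[n]);
            // Without labels, only the teacher supervises the student.
            double wTeacher = trueLabels == null ? 1.0 : alpha;
            double loss = 0;
            var grad = new double[s.Length];
            // 2. Loss: weighted mean squared error to the teacher (and to labels).
            for (int i = 0; i < s.Length; i++)
            {
                double dt = s[i] - t[i];
                loss += wTeacher * dt * dt / s.Length;
                grad[i] += 2 * wTeacher * dt / s.Length;
                if (trueLabels != null)
                {
                    double dy = s[i] - trueLabels[n][i];
                    loss += (1 - alpha) * dy * dy / s.Length;
                    grad[i] += 2 * (1 - alpha) * dy / s.Length;
                }
            }
            // 3. Backpropagate; the student applies its own weight update.
            studentBackward(grad);
            total += loss;
        }
        return inputs.Length == 0 ? 0.0 : total / inputs.Length; // average loss
    }
}
```

Keeping the weight update inside studentBackward lets the trainer stay independent of how the student stores its parameters, which matches the delegate-based signature of TrainBatch.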