Class TeacherModelWrapper<T>

Namespace
AiDotNet.KnowledgeDistillation
Assembly
AiDotNet.dll

Wraps an existing trained IFullModel to act as a teacher for knowledge distillation.

public class TeacherModelWrapper<T> : ITeacherModel<Vector<T>, Vector<T>>

Type Parameters

T

The numeric type for calculations (e.g., double, float).

Inheritance
TeacherModelWrapper<T>
Implements
ITeacherModel<Vector<T>, Vector<T>>

Remarks

For Beginners: This class takes any trained IFullModel and adapts it to work as a teacher in knowledge distillation. The teacher model should already be trained and perform well on your task.

Architecture Note: This is a lightweight adapter that bridges IFullModel to ITeacherModel. It simply delegates GetLogits() to the underlying model's Predict() method, since in this architecture, predictions and logits are equivalent.
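The delegation described in the architecture note can be pictured with a minimal, self-contained sketch. The types `IPredictor`, `ITeacher`, `TeacherAdapter`, and `FixedLinearModel` below are hypothetical stand-ins written for illustration; they are not the AiDotNet `IFullModel` or `ITeacherModel` interfaces, and `double[]` stands in for `Vector<T>`.

```csharp
using System;

// Usage: wrap a "trained" model and ask the teacher for logits.
IPredictor model = new FixedLinearModel();
ITeacher teacher = new TeacherAdapter(model, outputDimension: 2);
double[] logits = teacher.GetLogits(new[] { 1.0, 2.0 });
Console.WriteLine(string.Join(", ", logits));

// Hypothetical stand-in for a trained model that exposes Predict().
interface IPredictor
{
    double[] Predict(double[] input);
}

// Hypothetical stand-in for the teacher contract used during distillation.
interface ITeacher
{
    int OutputDimension { get; }
    double[] GetLogits(double[] input);
}

// The adapter: GetLogits() forwards to Predict() unchanged.
sealed class TeacherAdapter : ITeacher
{
    private readonly IPredictor _model;

    public TeacherAdapter(IPredictor model, int outputDimension)
    {
        _model = model ?? throw new ArgumentNullException(nameof(model));
        OutputDimension = outputDimension;
    }

    public int OutputDimension { get; }

    public double[] GetLogits(double[] input) => _model.Predict(input);
}

// Toy model with fixed weights, used only to make the sketch runnable.
sealed class FixedLinearModel : IPredictor
{
    public double[] Predict(double[] input) =>
        new[] { input[0] + input[1], input[0] - input[1] };
}
```

The adapter adds no logic of its own, which is the point: whatever the model's prediction method returns is treated as the teacher's logits.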

Real-world Example: Imagine you have a large, accurate neural network trained on your dataset. You can wrap it with TeacherModelWrapper and use it to train a smaller student model that retains most of the accuracy but runs much faster.

Common teacher-student scenarios:

- Large neural network (teacher) → Smaller network (student): 40-60% smaller, 95-97% of performance
- Deep network (teacher) → Shallow network (student): 10x faster inference
- Ensemble (teacher) → Single model (student): Deployable on resource-constrained devices

Constructors

TeacherModelWrapper(Func<Vector<T>, Vector<T>>, int)

Initializes a new instance of the TeacherModelWrapper class from a forward function.

public TeacherModelWrapper(Func<Vector<T>, Vector<T>> forwardFunc, int outputDimension)

Parameters

forwardFunc Func<Vector<T>, Vector<T>>

Function that performs a forward pass and returns logits.

outputDimension int

The number of output dimensions (classes).

Remarks

For Beginners: This constructor lets you create a teacher from any prediction function. The forward function should take an input vector and return logits (raw outputs).

Example usage:

var teacher = new TeacherModelWrapper<double>(
    forwardFunc: input => myTrainedModel.Predict(input),
    outputDimension: 10 // 10 classes (e.g., CIFAR-10)
);
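One thing worth checking when building a teacher from a function is that the function's output length matches the declared `outputDimension`. This page does not say whether `TeacherModelWrapper` validates this, so the sketch below performs the check in a hypothetical helper, `WithDimensionCheck`, using `double[]` in place of `Vector<T>` to stay self-contained.

```csharp
using System;

// Hypothetical helper: wraps a forward function so that a mismatch between
// the produced length and the declared output dimension fails fast.
static Func<double[], double[]> WithDimensionCheck(
    Func<double[], double[]> forwardFunc, int outputDimension)
{
    if (forwardFunc is null) throw new ArgumentNullException(nameof(forwardFunc));
    if (outputDimension <= 0) throw new ArgumentOutOfRangeException(nameof(outputDimension));

    return input =>
    {
        double[] logits = forwardFunc(input);
        if (logits.Length != outputDimension)
        {
            throw new InvalidOperationException(
                $"Expected {outputDimension} logits but got {logits.Length}.");
        }
        return logits;
    };
}

// Toy forward function standing in for myTrainedModel.Predict.
Func<double[], double[]> toyForward = input => new[] { input[0], -input[0], 0.0 };

var checkedForward = WithDimensionCheck(toyForward, outputDimension: 3);
Console.WriteLine(checkedForward(new[] { 2.0 }).Length);
```

The same idea could be written against `Vector<T>` and the checked function passed as `forwardFunc`, so that a wrong class count surfaces on the first call rather than deep inside a training loop.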

Properties

OutputDimension

Gets the number of output dimensions (e.g., number of classes for classification).

public int OutputDimension { get; }

Property Value

int

Methods

GetLogits(Vector<T>)

Gets the teacher's raw logits (pre-softmax outputs) for the given input.

public Vector<T> GetLogits(Vector<T> input)

Parameters

input Vector<T>

The input data to process.

Returns

Vector<T>

Raw logits before applying softmax.

Remarks

For Beginners: Logits are the raw numerical outputs from a neural network before converting them to probabilities. They're preferred for distillation because:

1. They preserve more information than probabilities
2. They're numerically more stable
3. Temperature scaling works better on logits
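To make the temperature point concrete, here is a small sketch of temperature-scaled softmax applied to logits. It is a generic illustration rather than AiDotNet code: `SoftmaxWithTemperature` is a hypothetical helper, the logit values are made up, and `double[]` stands in for `Vector<T>`.

```csharp
using System;
using System.Linq;

// Converts logits to probabilities, dividing by the temperature first.
// Subtracting the maximum before exponentiating avoids overflow.
static double[] SoftmaxWithTemperature(double[] logits, double temperature)
{
    if (temperature <= 0) throw new ArgumentOutOfRangeException(nameof(temperature));

    double[] scaled = logits.Select(z => z / temperature).ToArray();
    double max = scaled.Max();
    double[] exps = scaled.Select(z => Math.Exp(z - max)).ToArray();
    double sum = exps.Sum();
    return exps.Select(e => e / sum).ToArray();
}

// Made-up teacher logits for a three-class problem.
double[] teacherLogits = { 4.0, 1.0, 0.5 };

// T = 1 gives the ordinary softmax; a higher T spreads probability mass
// across the non-top classes while keeping the same ranking.
double[] sharp = SoftmaxWithTemperature(teacherLogits, 1.0);
double[] soft = SoftmaxWithTemperature(teacherLogits, 4.0);

Console.WriteLine(string.Join(", ", sharp.Select(p => p.ToString("F3"))));
Console.WriteLine(string.Join(", ", soft.Select(p => p.ToString("F3"))));
```

The softened distribution keeps the teacher's ranking but makes its opinion of the less likely classes visible, which is the extra signal a student learns from during distillation.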

Architecture Note: This method delegates to the forward function supplied to the constructor, typically the wrapped model's Predict() method. In this architecture, predictions and logits are equivalent.