Class TeacherModelWrapper<T>
- Namespace
- AiDotNet.KnowledgeDistillation
- Assembly
- AiDotNet.dll
Wraps an existing trained IFullModel to act as a teacher for knowledge distillation.
public class TeacherModelWrapper<T> : ITeacherModel<Vector<T>, Vector<T>>
Type Parameters
T: The numeric type for calculations (e.g., double, float).
- Inheritance
- TeacherModelWrapper<T>
- Implements
- ITeacherModel<Vector<T>, Vector<T>>
Remarks
For Beginners: This class takes any trained IFullModel and adapts it to work as a teacher in knowledge distillation. The teacher model should already be trained and perform well on your task.
Architecture Note: This is a lightweight adapter that bridges IFullModel to ITeacherModel. It simply delegates GetLogits() to the underlying model's Predict() method, since in this architecture, predictions and logits are equivalent.
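The adapter idea in the note above can be sketched in isolation. This is a minimal illustration under stated assumptions, not the AiDotNet source: ITeacherSketch and FuncTeacherSketch are hypothetical stand-ins, and double[] replaces Vector<T> so the snippet compiles without the library. The length check against OutputDimension is also an assumption, not documented behavior of TeacherModelWrapper.

```csharp
using System;

// Hypothetical stand-in for a teacher interface; NOT the AiDotNet ITeacherModel.
public interface ITeacherSketch
{
    int OutputDimension { get; }
    double[] GetLogits(double[] input);
}

// Adapter: exposes any forward (prediction) function through the teacher interface.
public sealed class FuncTeacherSketch : ITeacherSketch
{
    private readonly Func<double[], double[]> _forward;

    public FuncTeacherSketch(Func<double[], double[]> forward, int outputDimension)
    {
        _forward = forward ?? throw new ArgumentNullException(nameof(forward));
        if (outputDimension <= 0)
            throw new ArgumentOutOfRangeException(nameof(outputDimension));
        OutputDimension = outputDimension;
    }

    public int OutputDimension { get; }

    // Pure delegation: the wrapped function's output is treated as the logits.
    public double[] GetLogits(double[] input)
    {
        if (input == null) throw new ArgumentNullException(nameof(input));

        double[] logits = _forward(input);

        // Assumed sanity check for this sketch (not documented library behavior).
        if (logits.Length != OutputDimension)
            throw new InvalidOperationException(
                $"Expected {OutputDimension} outputs but got {logits.Length}.");

        return logits;
    }
}
```

Because the adapter holds only a delegate and a dimension, wrapping a model adds no training state of its own; the wrapped model stays the single source of predictions.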
Real-world Example: Imagine you have a large, accurate neural network trained on your dataset. You can wrap it with TeacherModelWrapper and use it to train a smaller, faster student model that retains most of the accuracy but runs much faster.
Common teacher-student scenarios:
- Large neural network (teacher) → Smaller network (student): 40-60% smaller, 95-97% of performance
- Deep network (teacher) → Shallow network (student): 10x faster inference
- Ensemble (teacher) → Single model (student): Deployable on resource-constrained devices
Constructors
TeacherModelWrapper(Func<Vector<T>, Vector<T>>, int)
Initializes a new instance of the TeacherModelWrapper class from a forward function.
public TeacherModelWrapper(Func<Vector<T>, Vector<T>> forwardFunc, int outputDimension)
Parameters
forwardFunc (Func<Vector<T>, Vector<T>>): Function that performs the forward pass and returns logits.
outputDimension (int): The number of output dimensions (classes).
Remarks
For Beginners: This constructor lets you create a teacher from any prediction function. The forward function should take input and return logits (raw outputs).
Example usage:
var teacher = new TeacherModelWrapper<double>(
    forwardFunc: input => myTrainedModel.Predict(input),
    outputDimension: 10 // 10 classes (e.g., CIFAR-10)
);
Properties
OutputDimension
Gets the number of output dimensions (e.g., number of classes for classification).
public int OutputDimension { get; }
Property Value
- int
Methods
GetLogits(Vector<T>)
Gets the teacher's raw logits (pre-softmax outputs) for the given input.
public Vector<T> GetLogits(Vector<T> input)
Parameters
input (Vector<T>): The input data to process.
Returns
- Vector<T>
Raw logits before applying softmax.
Remarks
For Beginners: Logits are the raw numerical outputs from a neural network before converting them to probabilities. They're preferred for distillation because:
1. They preserve more information than probabilities
2. They're numerically more stable
3. Temperature scaling works better on logits
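To make the temperature-scaling point concrete, here is a minimal sketch of a temperature-scaled softmax, p_i = exp(z_i / T) / sum_j exp(z_j / T), applied to raw logits. DistillationSketch and SoftmaxWithTemperature are hypothetical names for illustration and are not part of the documented TeacherModelWrapper API; plain double[] arrays stand in for Vector<T>.

```csharp
using System;
using System.Linq;

// Hypothetical helper for illustration; not part of AiDotNet.
public static class DistillationSketch
{
    // Converts logits to probabilities, softened by a temperature T > 0.
    // Higher T flattens the distribution; T = 1 is the ordinary softmax.
    public static double[] SoftmaxWithTemperature(double[] logits, double temperature)
    {
        if (logits == null || logits.Length == 0)
            throw new ArgumentException("Logits must be non-empty.", nameof(logits));
        if (temperature <= 0)
            throw new ArgumentOutOfRangeException(nameof(temperature));

        // Divide every logit by the temperature.
        double[] scaled = logits.Select(z => z / temperature).ToArray();

        // Subtract the maximum before exponentiating to avoid overflow.
        double max = scaled.Max();
        double[] exps = scaled.Select(z => Math.Exp(z - max)).ToArray();

        // Normalize so the outputs sum to 1.
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }
}
```

For logits {2.0, 1.0, 0.1}, a temperature of 1 gives roughly (0.66, 0.24, 0.10), while a temperature of 4 gives roughly (0.42, 0.32, 0.26). The softer distribution exposes how the teacher ranks the non-top classes, which is the extra signal a student learns from.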
Architecture Note: This method simply delegates to the wrapped model's Predict() method. In this architecture, predictions and logits are equivalent.