Class TeacherModelBase<TInput, TOutput, T>
- Namespace
- AiDotNet.KnowledgeDistillation
- Assembly
- AiDotNet.dll
Abstract base class for teacher models used in knowledge distillation. Provides common functionality and utilities for teacher model implementations.
public abstract class TeacherModelBase<TInput, TOutput, T> : ITeacherModel<TInput, TOutput>, IJitCompilable<T>
Type Parameters
TInput: The input data type (e.g., Vector, Matrix, Tensor).
TOutput: The output data type (typically logits as a Vector or Matrix).
T: The numeric type used for calculations (e.g., double, float).
- Inheritance
- object
- TeacherModelBase<TInput, TOutput, T>
- Implements
- ITeacherModel<TInput, TOutput>
- IJitCompilable<T>
Remarks
For Beginners: This base class provides common functionality that all teacher models need, such as numeric operations and input validation. It's a lightweight foundation that derived classes build upon.
Why use a base class?
- **Code Reuse**: Common utilities like numeric operations are available to all implementations.
- **Consistency**: All teachers have access to the same helper methods.
- **Extensibility**: New teacher types inherit core functionality automatically.
- **Maintainability**: Updates to common utilities benefit all implementations.
Architecture Note: This base class is intentionally minimal. Complex operations like temperature scaling are handled by distillation strategies, not teachers. Teachers are responsible only for providing raw logits.
Constructors
TeacherModelBase()
Initializes the base teacher model and sets up numeric operations.
protected TeacherModelBase()
Fields
NumOps
Numeric operations for the specified type T.
protected readonly INumericOperations<T> NumOps
Field Value
- INumericOperations<T>
Remarks
Provides type-specific arithmetic operations (add, multiply, etc.) that work with any numeric type (double, float, decimal, etc.).
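The members of AiDotNet's INumericOperations<T> are not listed on this page, so the sketch below is only an analogy: it uses .NET 7+ built-in generic math (System.Numerics.INumber<T>) to show how a single routine can run over double, float, int, or decimal. The helper name Dot is hypothetical; inside a teacher you would use the corresponding NumOps operations instead.

```csharp
using System;
using System.Numerics;

// Hypothetical helper: a dot product written once for any numeric type.
// Analogy for what NumOps provides; this is .NET generic math, not AiDotNet's API.
static T Dot<T>(T[] a, T[] b) where T : INumber<T>
{
    if (a.Length != b.Length)
        throw new ArgumentException("Vectors must have the same length.");

    T sum = T.Zero;               // additive identity for the concrete T
    for (int i = 0; i < a.Length; i++)
        sum += a[i] * b[i];       // operators resolved for the concrete T
    return sum;
}

Console.WriteLine(Dot(new[] { 1.0, 2.0 }, new[] { 3.0, 4.0 }));  // 11
Console.WriteLine(Dot(new[] { 1f, 2f }, new[] { 3f, 4f }));      // 11
Console.WriteLine(Dot(new[] { 1, 2 }, new[] { 3, 4 }));          // 11
```

The same pattern is why a teacher written against NumOps does not need separate code paths for each numeric type.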
Properties
OutputDimension
Gets the number of output dimensions (e.g., number of classes for classification).
public abstract int OutputDimension { get; }
Property Value
- int
Remarks
For Implementers: Return the size of the output vector. For example, a 10-class classifier should return 10.
SupportsJitCompilation
Gets whether this teacher model supports JIT compilation.
public abstract bool SupportsJitCompilation { get; }
Property Value
- bool
true if the teacher model can be JIT compiled; otherwise, false.
Remarks
Teacher models that wrap other models should delegate to the wrapped model's JIT support. Teacher models using function delegates or cached predictions may not support JIT.
For Implementers: Return true if your teacher model can export its computation as a graph. Models wrapping IJitCompilable implementations should return the wrapped model's SupportsJitCompilation value.
Methods
ApplyTemperatureSoftmax(TOutput, double)
Applies temperature-scaled softmax to logits. Subclasses are expected to override it according to their output type (Vector, Matrix, etc.).
protected virtual TOutput ApplyTemperatureSoftmax(TOutput logits, double temperature)
Parameters
logits (TOutput): Raw model outputs.
temperature (double): Temperature used to scale the logits.
Returns
- TOutput
Probability distribution.
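For reference, the conventional temperature-scaled softmax used in knowledge distillation is given below, with logits z = (z_1, ..., z_K) and temperature τ > 0 (written τ to avoid a clash with the type parameter T). This page does not show AiDotNet's implementation, so treat this as the standard definition rather than a description of the library's code. Subtracting the maximum logit m leaves every p_i unchanged but prevents overflow in the exponential.

```latex
p_i = \frac{\exp(z_i / \tau)}{\sum_{j=1}^{K} \exp(z_j / \tau)}
    = \frac{\exp\big((z_i - m) / \tau\big)}{\sum_{j=1}^{K} \exp\big((z_j - m) / \tau\big)},
\qquad m = \max_{j} z_j, \quad i = 1, \dots, K
```

With τ = 1 this is ordinary softmax; as τ grows, the distribution approaches uniform.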
CheckWrappedModelJitSupport(ITeacherModel<TInput, TOutput>)
Checks if a wrapped teacher model supports JIT compilation.
protected static bool CheckWrappedModelJitSupport(ITeacherModel<TInput, TOutput> wrappedModel)
Parameters
wrappedModel (ITeacherModel<TInput, TOutput>): The wrapped teacher model to check.
Returns
- bool
true if the wrapped model implements IJitCompilable and supports JIT; otherwise, false.
Remarks
Use this helper method in derived classes that wrap another ITeacherModel to implement the SupportsJitCompilation property.
Example:
public override bool SupportsJitCompilation => CheckWrappedModelJitSupport(_baseTeacher);
DelegateJitExport(ITeacherModel<TInput, TOutput>, List<ComputationNode<T>>, string)
Delegates JIT compilation export to a wrapped teacher model.
protected static ComputationNode<T> DelegateJitExport(ITeacherModel<TInput, TOutput> wrappedModel, List<ComputationNode<T>> inputNodes, string wrapperTypeName)
Parameters
wrappedModel (ITeacherModel<TInput, TOutput>): The wrapped teacher model to delegate to.
inputNodes (List<ComputationNode<T>>): List to populate with input computation nodes.
wrapperTypeName (string): Name of the wrapper type (used in error messages).
Returns
- ComputationNode<T>
The output computation node from the wrapped model.
Remarks
Use this helper method in derived classes that wrap another ITeacherModel to implement the ExportComputationGraph method.
Example:
public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
=> DelegateJitExport(_baseTeacher, inputNodes, nameof(AdaptiveTeacherModel<T>));
Exceptions
- NotSupportedException
Thrown when the wrapped model does not implement IJitCompilable or does not support JIT.
ExportComputationGraph(List<ComputationNode<T>>)
Exports the teacher model's computation graph for JIT compilation.
public abstract ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
Parameters
inputNodes (List<ComputationNode<T>>): List to populate with input computation nodes.
Returns
- ComputationNode<T>
The output computation node representing the teacher's logits.
Remarks
For teacher models that wrap other models, this should delegate to the wrapped model's ExportComputationGraph method. For models using function delegates, this may not be supported and should throw NotSupportedException.
For Implementers: If your teacher wraps a model implementing IJitCompilable, delegate to that model's ExportComputationGraph. Otherwise, implement the computation graph directly or throw NotSupportedException with a clear explanation.
Exceptions
- NotSupportedException
Thrown when the teacher model does not support JIT compilation.
GetLogits(TInput)
Gets the teacher's raw logits (pre-softmax outputs) for the given input.
public abstract TOutput GetLogits(TInput input)
Parameters
input (TInput): The input data to process.
Returns
- TOutput
Raw logits before applying softmax.
Remarks
For Implementers: Override this method to extract logits from your specific model type. Ensure you return pre-activation outputs, not probabilities.
Important: Temperature scaling and softmax conversion are handled by the distillation strategy, not by the teacher. Just return raw logits here.
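To see why GetLogits should return raw logits, the sketch below compares softmax applied once to logits with softmax applied to values that are already probabilities. Plain double[] arrays stand in for TOutput, and the Softmax helper is hypothetical. Applying softmax twice squashes the distribution toward uniform, which is what would happen if a teacher returned probabilities and the distillation strategy then applied its own softmax.

```csharp
using System;
using System.Linq;

// Hypothetical helper: plain softmax over a double[] (max subtracted for stability).
static double[] Softmax(double[] z)
{
    double max = z.Max();
    double[] exps = z.Select(v => Math.Exp(v - max)).ToArray();
    double sum = exps.Sum();
    return exps.Select(e => e / sum).ToArray();
}

double[] logits = { 4.0, 1.0, 0.0 };        // what GetLogits should return
double[] once = Softmax(logits);             // correct: the strategy applies softmax once
double[] twice = Softmax(Softmax(logits));   // bug: the teacher already returned probabilities

Console.WriteLine(string.Join(", ", once));
Console.WriteLine(string.Join(", ", twice)); // noticeably flatter than the first line
```

Both results are valid distributions, which is what makes this mistake easy to miss: the only symptom is soft targets that carry much less of the teacher's confidence.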
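Softmax(Vector<T>, double)
Converts a vector of logits into a probability distribution using temperature-scaled softmax.
protected virtual Vector<T> Softmax(Vector<T> logits, double temperature = 1)
Parameters
logits (Vector<T>): The raw logits to normalize.
temperature (double): Softmax temperature; the default of 1 gives standard softmax.
Returns
- Vector<T>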
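The signature points to a temperature-scaled softmax over a vector. The following is a minimal sketch of that computation, assuming plain double[] in place of Vector<T> and the conventional definition (divide each logit by the temperature, then normalize). The library's exact implementation is not shown on this page.

```csharp
using System;

// Hypothetical stand-in for Softmax(Vector<T>, double): temperature-scaled and numerically stable.
static double[] Softmax(double[] logits, double temperature = 1.0)
{
    if (logits.Length == 0)
        throw new ArgumentException("Logits must not be empty.", nameof(logits));
    if (temperature <= 0)
        throw new ArgumentOutOfRangeException(nameof(temperature), "Temperature must be positive.");

    // Subtracting the max logit leaves the result unchanged but keeps Math.Exp from overflowing.
    double max = double.NegativeInfinity;
    foreach (double z in logits) max = Math.Max(max, z);

    var probs = new double[logits.Length];
    double sum = 0.0;
    for (int i = 0; i < logits.Length; i++)
    {
        probs[i] = Math.Exp((logits[i] - max) / temperature);
        sum += probs[i];
    }
    for (int i = 0; i < probs.Length; i++)
        probs[i] /= sum;
    return probs;
}

double[] teacherLogits = { 2.0, 1.0, 0.1 };
double[] sharp = Softmax(teacherLogits);       // temperature 1: standard softmax
double[] soft = Softmax(teacherLogits, 4.0);   // higher temperature: softer distribution
```

Raising the temperature lowers the top-class probability and spreads mass over the other classes, exposing the teacher's relative preferences among them.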
ThrowJitNotSupported(string, string)
Throws a standardized NotSupportedException for teacher models that cannot support JIT compilation.
protected static ComputationNode<T> ThrowJitNotSupported(string teacherTypeName, string reason)
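Parameters
teacherTypeName (string): Name of the teacher type (used in the exception message).
reason (string): Explanation of why JIT compilation is not supported.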
Returns
- ComputationNode<T>
Never returns (always throws).
Remarks
Use this helper method in derived classes that cannot support JIT compilation.
Example:
public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
=> ThrowJitNotSupported(nameof(PretrainedTeacherModel<T>),
"it uses a function delegate which cannot be exported as a computation graph");
Exceptions
- NotSupportedException
Always thrown.
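The helper declares a return type (ComputationNode<T>) even though it never returns, so a call can sit directly on the right of an expression-bodied ExportComputationGraph override, as in the example. A self-contained sketch of the same idea follows. A generic TResult stands in for ComputationNode<T>, and the message format is a hypothetical choice, since the library's actual wording is not documented here.

```csharp
using System;

// Hypothetical stand-in for ThrowJitNotSupported. The generic return type lets a call
// appear wherever a value is expected (e.g. an expression-bodied member), even though
// the method always throws. The message format is an assumption.
static TResult ThrowJitNotSupported<TResult>(string teacherTypeName, string reason)
{
    throw new NotSupportedException(
        $"{teacherTypeName} does not support JIT compilation because {reason}.");
}

// Expression-bodied local function mirroring an ExportComputationGraph override.
static string ExportGraph() =>
    ThrowJitNotSupported<string>("PretrainedTeacherModel",
        "it uses a function delegate which cannot be exported as a computation graph");
```

Centralizing the throw keeps error messages uniform across every teacher that opts out of JIT compilation.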
ValidateInput(TInput?, string)
Validates that the input is not null.
protected void ValidateInput(TInput? input, string paramName = "input")
Parameters
input (TInput): Input to validate.
paramName (string): Parameter name for the exception message.
Remarks
Helper method for derived classes to validate inputs before processing.
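A minimal sketch of such a null check follows. It assumes the method throws ArgumentNullException, which this page does not specify. The generic TIn stands in for TInput, and the helper is a hypothetical stand-in, not the library's code.

```csharp
using System;

// Hypothetical stand-in for ValidateInput: rejects null inputs before processing.
// Assumes ArgumentNullException; for non-nullable value types the check never fires.
static void ValidateInput<TIn>(TIn? input, string paramName = "input")
{
    if (input is null)
        throw new ArgumentNullException(paramName);
}

ValidateInput(new[] { 0.5, 1.5 });   // passes: non-null array
```

Calling it at the top of GetLogits means a null input fails with a clear parameter name instead of a NullReferenceException deeper in the model.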