Class AdaptiveTeacherModel<T>
- Namespace
- AiDotNet.KnowledgeDistillation.Teachers
- Assembly
- AiDotNet.dll
Adaptive teacher model that wraps a base teacher and provides its logits.
public class AdaptiveTeacherModel<T> : TeacherModelBase<Vector<T>, Vector<T>, T>, ITeacherModel<Vector<T>, Vector<T>>, IJitCompilable<T>
Type Parameters
T
The numeric type used for calculations (e.g., double, float).
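- Inheritance
- TeacherModelBase<Vector<T>, Vector<T>, T>
- AdaptiveTeacherModel<T>
- Implements
- ITeacherModel<Vector<T>, Vector<T>>
- IJitCompilable<T>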
Remarks
Architecture Note: This class has been simplified to match the current architecture, in which temperature scaling is handled by distillation strategies rather than by teachers. Its adaptive features (dynamic temperature adjustment based on student performance) have been removed because they belong in the strategy layer.
For adaptive temperature scaling, implement a custom IDistillationStrategy that monitors student performance and adjusts the temperature accordingly.
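The kind of policy such a strategy might apply can be sketched as a standalone helper. This is a minimal sketch under stated assumptions: AdaptiveTemperatureSchedule is a hypothetical class, not part of AiDotNet and not an implementation of IDistillationStrategy (whose members are not documented on this page). The "student agreement" metric, the target of 0.8, the bounds, and the step size are illustrative choices, and double[] stands in for Vector<T>.

```csharp
using System;
using System.Linq;

// Hypothetical helper, NOT part of AiDotNet: adjusts a distillation temperature
// from a running measure of student performance.
public sealed class AdaptiveTemperatureSchedule
{
    private readonly double _min;
    private readonly double _max;
    private readonly double _step;

    public double Temperature { get; private set; }

    public AdaptiveTemperatureSchedule(double initial, double min, double max, double step)
    {
        if (min <= 0 || min > max) throw new ArgumentOutOfRangeException(nameof(min));
        if (step <= 0) throw new ArgumentOutOfRangeException(nameof(step));
        _min = min;
        _max = max;
        _step = step;
        Temperature = Math.Clamp(initial, min, max);
    }

    // Assumed policy: if the student agrees with the teacher less often than
    // the target, raise the temperature (softer targets); otherwise lower it
    // toward sharper targets. The temperature stays within [min, max].
    public void Update(double studentAgreement, double target = 0.8)
    {
        Temperature = studentAgreement < target
            ? Math.Min(_max, Temperature + _step)
            : Math.Max(_min, Temperature - _step);
    }

    // Temperature-scaled softmax: p_i = exp(z_i / T) / sum_j exp(z_j / T).
    // Subtracting the max logit first avoids overflow without changing the result.
    public double[] Soften(double[] logits)
    {
        double maxLogit = logits.Max();
        double[] exps = logits.Select(z => Math.Exp((z - maxLogit) / Temperature)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }
}
```

In this design the teacher only supplies raw logits; everything temperature-related lives in the strategy-side helper, which is the separation the architecture note describes.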
Constructors
AdaptiveTeacherModel(ITeacherModel<Vector<T>, Vector<T>>)
Initializes a new instance of the AdaptiveTeacherModel class.
public AdaptiveTeacherModel(ITeacherModel<Vector<T>, Vector<T>> baseTeacher)
Parameters
baseTeacher ITeacherModel<Vector<T>, Vector<T>>
The underlying teacher model to wrap.
Properties
OutputDimension
Gets the output dimension.
public override int OutputDimension { get; }
Property Value
- int
SupportsJitCompilation
Gets whether this teacher supports JIT compilation.
public override bool SupportsJitCompilation { get; }
Property Value
- bool
true if the base teacher implements IJitCompilable<T> and supports JIT compilation; otherwise, false.
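The check described above can be sketched with C# type pattern matching. In this minimal sketch, IJitCapable and ISimpleTeacher are hypothetical stand-ins for IJitCompilable<T> and ITeacherModel<Vector<T>, Vector<T>>, whose real members may differ, and double[] stands in for Vector<T>.

```csharp
using System;

// Hypothetical stand-ins, NOT the real AiDotNet interfaces: just enough
// shape to show the type check described above.
public interface IJitCapable
{
    bool SupportsJitCompilation { get; }
}

public interface ISimpleTeacher
{
    double[] GetLogits(double[] input);
}

public sealed class WrappingTeacher
{
    private readonly ISimpleTeacher _baseTeacher;

    public WrappingTeacher(ISimpleTeacher baseTeacher) =>
        _baseTeacher = baseTeacher ?? throw new ArgumentNullException(nameof(baseTeacher));

    // true only if the wrapped teacher implements the JIT interface
    // AND reports that it supports JIT; otherwise false.
    public bool SupportsJitCompilation =>
        _baseTeacher is IJitCapable jit && jit.SupportsJitCompilation;
}

// A base teacher that implements the JIT interface.
public sealed class JitTeacher : ISimpleTeacher, IJitCapable
{
    public bool SupportsJitCompilation => true;
    public double[] GetLogits(double[] input) => input;
}

// A base teacher that does not.
public sealed class PlainTeacher : ISimpleTeacher
{
    public double[] GetLogits(double[] input) => input;
}
```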
Methods
ExportComputationGraph(List<ComputationNode<T>>)
Exports the computation graph by delegating to the base teacher.
public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
Parameters
inputNodes List<ComputationNode<T>>
The list to populate with input computation nodes.
Returns
- ComputationNode<T>
The output computation node from the base teacher.
Exceptions
- NotSupportedException
Thrown when the base teacher does not support JIT compilation.
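The delegation and its NotSupportedException path can be sketched as follows. GraphNode, ISimpleTeacher, and IGraphExporter are hypothetical stand-ins for ComputationNode<T>, ITeacherModel<Vector<T>, Vector<T>>, and IJitCompilable<T>; the exact condition AdaptiveTeacherModel<T> checks before throwing is an assumption based on the exception description above.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-ins, NOT AiDotNet types: GraphNode plays the role of
// ComputationNode<T>, and IGraphExporter the role of IJitCompilable<T>.
public sealed class GraphNode
{
    public string Name { get; }
    public GraphNode(string name) => Name = name;
}

public interface ISimpleTeacher
{
    double[] GetLogits(double[] input);
}

public interface IGraphExporter
{
    bool SupportsJitCompilation { get; }
    GraphNode ExportComputationGraph(List<GraphNode> inputNodes);
}

public sealed class DelegatingTeacher
{
    private readonly ISimpleTeacher _baseTeacher;

    public DelegatingTeacher(ISimpleTeacher baseTeacher) =>
        _baseTeacher = baseTeacher ?? throw new ArgumentNullException(nameof(baseTeacher));

    public GraphNode ExportComputationGraph(List<GraphNode> inputNodes)
    {
        // Delegate only when the base teacher can export its graph.
        if (_baseTeacher is IGraphExporter exporter && exporter.SupportsJitCompilation)
        {
            // The base teacher fills inputNodes and returns its output node.
            return exporter.ExportComputationGraph(inputNodes);
        }

        throw new NotSupportedException("The base teacher does not support JIT compilation.");
    }
}

// Base teacher that can export a toy one-input graph.
public sealed class ExportingTeacher : ISimpleTeacher, IGraphExporter
{
    public bool SupportsJitCompilation => true;
    public double[] GetLogits(double[] input) => input;

    public GraphNode ExportComputationGraph(List<GraphNode> inputNodes)
    {
        inputNodes.Add(new GraphNode("input"));
        return new GraphNode("logits");
    }
}

// Base teacher without graph export.
public sealed class OpaqueTeacher : ISimpleTeacher
{
    public double[] GetLogits(double[] input) => input;
}
```

Throwing instead of returning null keeps a caller from compiling a graph that silently omits the base teacher; callers can check SupportsJitCompilation first to avoid the exception.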
GetLogits(Vector<T>)
Gets logits from the base teacher.
public override Vector<T> GetLogits(Vector<T> input)
Parameters
input Vector<T>
The input to pass to the base teacher.
Returns
- Vector<T>
The logits produced by the base teacher.
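A minimal sketch of the wrapping pattern, assuming GetLogits and OutputDimension simply forward to the base teacher. This page documents delegation for GetLogits; forwarding OutputDimension is an assumption. IMinimalTeacher, PassThroughTeacher, and LinearTeacher are hypothetical, and double[] stands in for Vector<T>.

```csharp
using System;

// Hypothetical minimal teacher contract standing in for
// ITeacherModel<Vector<T>, Vector<T>>; double[] stands in for Vector<T>.
public interface IMinimalTeacher
{
    int OutputDimension { get; }
    double[] GetLogits(double[] input);
}

// Wrapper that forwards to the base teacher unchanged:
// no temperature scaling happens here.
public sealed class PassThroughTeacher : IMinimalTeacher
{
    private readonly IMinimalTeacher _baseTeacher;

    public PassThroughTeacher(IMinimalTeacher baseTeacher) =>
        _baseTeacher = baseTeacher ?? throw new ArgumentNullException(nameof(baseTeacher));

    public int OutputDimension => _baseTeacher.OutputDimension;

    public double[] GetLogits(double[] input) => _baseTeacher.GetLogits(input);
}

// Toy base teacher: a fixed 2x3 linear layer, logits = W * input.
public sealed class LinearTeacher : IMinimalTeacher
{
    private readonly double[,] _weights = { { 1.0, 0.0, -1.0 }, { 0.5, 0.5, 0.5 } };

    public int OutputDimension => _weights.GetLength(0);

    public double[] GetLogits(double[] input)
    {
        if (input.Length != _weights.GetLength(1))
            throw new ArgumentException("Input length must match the weight columns.", nameof(input));

        var logits = new double[OutputDimension];
        for (int i = 0; i < OutputDimension; i++)
            for (int j = 0; j < input.Length; j++)
                logits[i] += _weights[i, j] * input[j];
        return logits;
    }
}
```

Because the wrapper adds no scaling, any temperature is applied downstream by the distillation strategy, consistent with the architecture note in Remarks.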