Class CurriculumTeacherModel<T>
Namespace: AiDotNet.KnowledgeDistillation.Teachers
Assembly: AiDotNet.dll
Curriculum teacher that wraps a base teacher for curriculum learning scenarios.
public class CurriculumTeacherModel<T> : TeacherModelBase<Vector<T>, Vector<T>, T>, ITeacherModel<Vector<T>, Vector<T>>, IJitCompilable<T>
Type Parameters
T
The numeric type used for calculations (e.g., double, float).
Inheritance: TeacherModelBase<Vector<T>, Vector<T>, T> → CurriculumTeacherModel<T>
Implements: ITeacherModel<Vector<T>, Vector<T>>, IJitCompilable<T>
Remarks
Architecture note: This class is a thin wrapper around a base teacher. Curriculum learning logic (adjusting difficulty over time) belongs in the training loop or distillation strategy, not in the teacher model.
The teacher's only responsibility is to provide predictions (logits). Curriculum decisions, such as which samples to show and how to adjust temperature or alpha, belong in the strategy or trainer layer.
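To illustrate this split of responsibilities, here is a minimal C# sketch, assuming a teacher can be modeled as a plain function from input to logits. The names `teacherLogits`, `CurriculumTemperature`, and `Softmax`, and the linear temperature schedule, are hypothetical and not part of AiDotNet. The point is that the schedule lives in the training loop while the teacher call stays the same every epoch.

```csharp
using System;
using System.Linq;

// Hypothetical stand-in for a teacher: it only maps an input to raw logits.
Func<double[], double[]> teacherLogits = x => x.Select(v => 2.0 * v).ToArray();

// Hypothetical curriculum schedule owned by the trainer, not the teacher:
// temperature moves linearly from startTemp to endTemp over totalEpochs.
double CurriculumTemperature(int epoch, int totalEpochs, double startTemp, double endTemp)
{
    double progress = totalEpochs <= 1 ? 1.0 : (double)epoch / (totalEpochs - 1);
    return startTemp + (endTemp - startTemp) * progress;
}

// Temperature-scaled softmax, applied by the strategy to the teacher's raw logits.
// Subtracting the max scaled logit keeps Math.Exp numerically stable.
double[] Softmax(double[] logits, double temperature)
{
    double max = logits.Max(l => l / temperature);
    double[] exps = logits.Select(l => Math.Exp(l / temperature - max)).ToArray();
    double sum = exps.Sum();
    return exps.Select(e => e / sum).ToArray();
}

const int epochs = 5;
double[] sample = { 1.0, 0.5, -0.5 };
for (int epoch = 0; epoch < epochs; epoch++)
{
    // Only the trainer-side temperature changes; the teacher is queried identically.
    double t = CurriculumTemperature(epoch, epochs, 4.0, 1.0);
    double[] soft = Softmax(teacherLogits(sample), t);
    Console.WriteLine($"epoch {epoch}: T={t:F2}, p0={soft[0]:F3}");
}
```

Swapping in a different schedule (step-wise, exponential, or per-sample difficulty) would change only `CurriculumTemperature`, which is the separation the note above argues for.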
Constructors
CurriculumTeacherModel(ITeacherModel<Vector<T>, Vector<T>>)
Initializes a new instance of the CurriculumTeacherModel class.
public CurriculumTeacherModel(ITeacherModel<Vector<T>, Vector<T>> baseTeacher)
Parameters
baseTeacher ITeacherModel<Vector<T>, Vector<T>>
The underlying teacher model.
Properties
OutputDimension
Gets the output dimension from the base teacher.
public override int OutputDimension { get; }
Property Value
- int
SupportsJitCompilation
Gets whether this teacher supports JIT compilation.
public override bool SupportsJitCompilation { get; }
Property Value
- bool
true if the base teacher implements IJitCompilable and supports JIT compilation; otherwise, false.
Methods
ExportComputationGraph(List<ComputationNode<T>>)
Exports the computation graph by delegating to the base teacher.
public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
Parameters
inputNodes List<ComputationNode<T>>
List to populate with input computation nodes.
Returns
- ComputationNode<T>
The output computation node from the base teacher.
Exceptions
- NotSupportedException
Thrown when the base teacher does not support JIT compilation.
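The two JIT members documented above follow a delegate-or-throw pattern. A minimal C# sketch of that pattern, assuming the base teacher's graph export can be modeled as an optional function: `jitBase`, `plainBase`, `SupportsJit`, and `ExportGraph` are hypothetical names, and plain strings stand in for ComputationNode<T>. It illustrates the documented behavior, not AiDotNet's actual implementation.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for a JIT-capable base teacher's graph export:
// it registers its input node and returns the name of its output node.
Func<List<string>, string>? jitBase = inputs =>
{
    inputs.Add("input");
    return "logits";
};

// A null exporter models a base teacher that does not implement IJitCompilable.
Func<List<string>, string>? plainBase = null;

// Mirrors SupportsJitCompilation: true only when the base teacher can export a graph.
bool SupportsJit(Func<List<string>, string>? exporter) => exporter is not null;

// Mirrors ExportComputationGraph: delegate to the base teacher, or fail loudly.
string ExportGraph(Func<List<string>, string>? exporter, List<string> inputNodes)
{
    if (exporter is null)
        throw new NotSupportedException("The base teacher does not support JIT compilation.");
    return exporter(inputNodes);
}

var nodes = new List<string>();
Console.WriteLine(SupportsJit(jitBase));        // True
Console.WriteLine(ExportGraph(jitBase, nodes)); // logits
Console.WriteLine(SupportsJit(plainBase));      // False
```

Callers would typically check SupportsJitCompilation before calling ExportComputationGraph, so the NotSupportedException path signals a programming error rather than a normal branch.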
GetLogits(Vector<T>)
Gets logits from the base teacher.
public override Vector<T> GetLogits(Vector<T> input)
Parameters
input Vector<T>
The input data.
Returns
- Vector<T>
Raw logits from the base teacher.
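Because the wrapper adds no curriculum behavior of its own, construction and GetLogits amount to storing the base teacher and forwarding calls to it. A minimal C# sketch of that forwarding, assuming a teacher can be modeled as a `Func<double[], double[]>`. `baseTeacher` and `WrapTeacher` are hypothetical, and the null check on the wrapped teacher is an assumption rather than documented behavior.

```csharp
using System;

// Hypothetical base teacher: maps a 3-element input to 2 raw logits.
Func<double[], double[]> baseTeacher = x => new[] { x[0] + x[1], x[1] - x[2] };

// Sketch of the wrapper: validate the base teacher once, then forward every call.
// The null check is an assumption; the documentation does not specify it.
Func<double[], double[]> WrapTeacher(Func<double[], double[]> inner)
{
    if (inner is null) throw new ArgumentNullException(nameof(inner));
    // No curriculum logic: logits pass through exactly as the base teacher produced them.
    return input => inner(input);
}

Func<double[], double[]> curriculumTeacher = WrapTeacher(baseTeacher);
double[] logits = curriculumTeacher(new[] { 1.0, 2.0, 3.0 });
Console.WriteLine(string.Join(", ", logits));
```

Since the returned logits are the base teacher's output unchanged, any temperature scaling or sample selection has to happen in the strategy or trainer that consumes them.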