Class PretrainedTeacherModel<T>
Namespace: AiDotNet.KnowledgeDistillation.Teachers
Assembly: AiDotNet.dll
A teacher model that wraps a model pretrained elsewhere (e.g., an ImageNet-pretrained network or BERT).
public class PretrainedTeacherModel<T> : TeacherModelBase<Vector<T>, Vector<T>, T>, ITeacherModel<Vector<T>, Vector<T>>, IJitCompilable<T>
Type Parameters
T
The numeric element type of the input and output vectors.
- Inheritance: TeacherModelBase<Vector<T>, Vector<T>, T> → PretrainedTeacherModel<T>
- Implements: ITeacherModel<Vector<T>, Vector<T>>, IJitCompilable<T>
Remarks
Architecture Note: This class supports two construction modes:
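- Function delegate mode: the forward pass is a Func<Vector<T>, Vector<T>> delegate. This mode is not JIT-compilable.
- IJitCompilable mode: the forward pass is routed to an IJitCompilable<T> model, which makes the teacher JIT-compilable when that model supports it.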
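A minimal sketch of how the two modes might be dispatched, assuming a class that stores either a delegate or a model and routes the forward pass to whichever one it received. `IForwardModel`, `DualModeTeacherSketch`, the `UsesModel` flag, the dimension checks, and the use of `double[]` in place of `Vector<T>` are all hypothetical simplifications, not the AiDotNet implementation.

```csharp
#nullable enable
using System;

// Hypothetical stand-in for IJitCompilable<T>: anything that can run a forward pass.
public interface IForwardModel
{
    double[] Predict(double[] input);
}

// Simplified, hypothetical sketch of a teacher with two construction modes.
public sealed class DualModeTeacherSketch
{
    private readonly Func<double[], double[]>? _forward; // set in delegate mode
    private readonly IForwardModel? _model;              // set in model mode

    public int InputDimension { get; }
    public int OutputDimension { get; }

    // Model mode: the forward pass is routed to a (potentially JIT-compilable) model.
    public DualModeTeacherSketch(IForwardModel model, int inputDimension, int outputDimension)
    {
        _model = model ?? throw new ArgumentNullException(nameof(model));
        InputDimension = inputDimension;
        OutputDimension = outputDimension;
    }

    // Delegate mode: the forward pass is an opaque function, so it cannot be JIT-compiled.
    public DualModeTeacherSketch(Func<double[], double[]> forward, int inputDimension, int outputDimension)
    {
        _forward = forward ?? throw new ArgumentNullException(nameof(forward));
        InputDimension = inputDimension;
        OutputDimension = outputDimension;
    }

    // True when constructed in model mode.
    public bool UsesModel => _model != null;

    // Returns raw logits from whichever forward pass was supplied, checking dimensions.
    public double[] GetLogits(double[] input)
    {
        if (input.Length != InputDimension)
            throw new ArgumentException($"Expected {InputDimension} inputs, got {input.Length}.", nameof(input));

        double[] logits = _model != null ? _model.Predict(input) : _forward!(input);

        if (logits.Length != OutputDimension)
            throw new InvalidOperationException($"Expected {OutputDimension} outputs, got {logits.Length}.");
        return logits;
    }
}
```

Storing both fields as nullable and branching once in `GetLogits` keeps the two modes behind a single public surface, which mirrors how a caller sees one teacher type regardless of how it was constructed.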
Constructors
PretrainedTeacherModel(IJitCompilable<T>, int, int)
Initializes a new instance using a JIT-compilable model.
public PretrainedTeacherModel(IJitCompilable<T> jitCompilableModel, int inputDimension, int outputDimension)
Parameters
jitCompilableModel (IJitCompilable<T>): A JIT-compilable model that performs the forward pass.
inputDimension (int): The number of input dimensions.
outputDimension (int): The number of output dimensions.
Remarks
JIT Support: This constructor enables JIT compilation when the underlying model supports it. Prefer this constructor when inference performance matters.
PretrainedTeacherModel(Func<Vector<T>, Vector<T>>, int, int)
Initializes a new instance using a function delegate (not JIT-compilable).
public PretrainedTeacherModel(Func<Vector<T>, Vector<T>> pretrainedForward, int inputDimension, int outputDimension)
Parameters
pretrainedForward (Func<Vector<T>, Vector<T>>): A function that performs the forward pass.
inputDimension (int): The number of input dimensions.
outputDimension (int): The number of output dimensions.
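For example, the weights of a pretrained linear classifier exported from another framework could be wrapped as the `pretrainedForward` delegate. The sketch below uses `double[]` in place of `Vector<T>` so it runs without AiDotNet; the `PretrainedLinearForward` helper and its weight layout (one row per output class) are hypothetical.

```csharp
using System;

public static class PretrainedLinearForward
{
    // Builds a forward-pass delegate computing logits = W * x + b,
    // where W has one row per output class and one column per input feature.
    public static Func<double[], double[]> Create(double[,] weights, double[] bias)
    {
        int outputs = weights.GetLength(0);
        int inputs = weights.GetLength(1);
        if (bias.Length != outputs)
            throw new ArgumentException("Bias length must equal the number of weight rows.", nameof(bias));

        return x =>
        {
            if (x.Length != inputs)
                throw new ArgumentException($"Expected {inputs} inputs, got {x.Length}.", nameof(x));

            var logits = new double[outputs];
            for (int o = 0; o < outputs; o++)
            {
                double sum = bias[o];
                for (int i = 0; i < inputs; i++)
                    sum += weights[o, i] * x[i];
                logits[o] = sum;
            }
            return logits;
        };
    }
}
```

A delegate built the same way over `Vector<T>` would be passed as the first constructor argument, with `inputDimension` and `outputDimension` matching the columns and rows of `W`. Because the delegate is opaque, a teacher constructed this way reports `SupportsJitCompilation` as false.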
Properties
OutputDimension
Gets the number of output dimensions (e.g., number of classes for classification).
public override int OutputDimension { get; }
Property Value
- int
Remarks
For Implementers: Return the size of the output vector. For example, a 10-class classifier should return 10.
SupportsJitCompilation
Gets whether this teacher supports JIT compilation.
public override bool SupportsJitCompilation { get; }
Property Value
- bool
true if constructed with a JIT-compilable model that itself supports JIT compilation; otherwise, false.
Methods
ExportComputationGraph(List<ComputationNode<T>>)
Exports the computation graph for JIT compilation.
public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
Parameters
inputNodes (List<ComputationNode<T>>): A list that the method populates with the input computation nodes.
Returns
- ComputationNode<T>
The output computation node.
Exceptions
- NotSupportedException
Thrown when the instance was constructed with a function delegate rather than a JIT-compilable model.
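The documented behavior of `SupportsJitCompilation` and `ExportComputationGraph` amounts to a guard: only the model mode can export a graph, and the JIT flag also depends on the wrapped model. A minimal sketch, assuming hypothetical stand-ins `GraphNode` and `IGraphExportable` in place of `ComputationNode<T>` and `IJitCompilable<T>`; passing `null` to `JitGuardSketch` stands in for delegate mode.

```csharp
#nullable enable
using System;
using System.Collections.Generic;

// Hypothetical stand-in for ComputationNode<T>.
public sealed class GraphNode
{
    public string Name { get; }
    public GraphNode(string name) => Name = name;
}

// Hypothetical stand-in for IJitCompilable<T>.
public interface IGraphExportable
{
    bool SupportsJitCompilation { get; }
    GraphNode ExportComputationGraph(List<GraphNode> inputNodes);
}

public sealed class JitGuardSketch
{
    private readonly IGraphExportable? _model; // null stands in for delegate mode

    public JitGuardSketch(IGraphExportable? model) => _model = model;

    // True only when a model was supplied and that model itself supports JIT.
    public bool SupportsJitCompilation => _model?.SupportsJitCompilation ?? false;

    public GraphNode ExportComputationGraph(List<GraphNode> inputNodes)
    {
        if (_model == null)
            throw new NotSupportedException(
                "This teacher was built from a function delegate, which cannot be exported as a computation graph.");

        // Forward graph construction to the wrapped model, which populates inputNodes.
        return _model.ExportComputationGraph(inputNodes);
    }
}
```

Throwing `NotSupportedException` (rather than returning null) makes an attempt to JIT-compile a delegate-backed teacher fail loudly at the call site; callers can check `SupportsJitCompilation` first to avoid it.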
GetLogits(Vector<T>)
Gets logits from the pretrained model.
public override Vector<T> GetLogits(Vector<T> input)
Parameters
input (Vector<T>): The input vector.
Returns
- Vector<T>
The raw (unnormalized) logits produced by the pretrained model.
Remarks
Architecture Note: Returns raw logits. Temperature scaling and softmax are handled by distillation strategies, not by the teacher model.
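Because `GetLogits` returns raw logits, a distillation strategy typically turns them into soft targets with the temperature-scaled softmax used in knowledge distillation, softmax(z / t). A minimal sketch on `double[]`; `SoftTargets` is a hypothetical helper, not part of the AiDotNet API.

```csharp
using System;
using System.Linq;

public static class SoftTargets
{
    // Temperature-scaled softmax: p_i = exp(z_i / t) / sum_j exp(z_j / t), where t is the temperature.
    // Higher temperatures produce softer (more uniform) distributions.
    public static double[] Softmax(double[] logits, double temperature)
    {
        if (temperature <= 0)
            throw new ArgumentOutOfRangeException(nameof(temperature), "Temperature must be positive.");

        double[] scaled = logits.Select(z => z / temperature).ToArray();

        // Subtract the max scaled logit for numerical stability; this does not change the result.
        double max = scaled.Max();
        double[] exps = scaled.Select(s => Math.Exp(s - max)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }
}
```

Keeping this step out of the teacher, as the note above describes, lets each strategy choose its own temperature without changing the teacher.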