Class TransformerTeacherModel<T>
- Namespace: AiDotNet.KnowledgeDistillation.Teachers
- Assembly: AiDotNet.dll
Transformer-based teacher model that provides logits from transformer architectures.
public class TransformerTeacherModel<T> : TeacherModelBase<Vector<T>, Vector<T>, T>, ITeacherModel<Vector<T>, Vector<T>>, IJitCompilable<T>
Type Parameters
T: The numeric type used for calculations (e.g., double or float).
- Inheritance: object → TeacherModelBase<Vector<T>, Vector<T>, T> → TransformerTeacherModel<T>
- Implements: ITeacherModel<Vector<T>, Vector<T>>, IJitCompilable<T>
Remarks
Architecture Note: This class supports two construction modes:
- Function delegate mode: a Func<Vector<T>, Vector<T>> performs the forward pass. Teachers built this way are not JIT-compilable.
- IJitCompilable mode: an IJitCompilable<T> model performs the forward pass, so the teacher can be JIT-compiled when that model supports it.
For attention-based distillation strategies that need attention weights, implement a custom IDistillationStrategy that can extract attention from the underlying model.
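The two construction modes can be pictured with the minimal sketch below. It is not the AiDotNet implementation: double[] stands in for Vector<T>, and ICompilableModelSketch, DualModeTeacherSketch and DoublingModelSketch are hypothetical names. It assumes, in line with the SupportsJitCompilation description, that delegate mode never reports JIT support while model mode defers to the wrapped model.

```csharp
#nullable enable
using System;

// Hypothetical stand-in for IJitCompilable<T>: a forward pass plus a capability flag.
// These names are illustrative assumptions, not the AiDotNet API.
public interface ICompilableModelSketch
{
    double[] Forward(double[] input);
    bool SupportsJitCompilation { get; }
}

// Sketch of a teacher built either from a delegate or from a compilable model.
public sealed class DualModeTeacherSketch
{
    private readonly Func<double[], double[]>? _forwardFunc; // set only in delegate mode
    private readonly ICompilableModelSketch? _model;         // set only in model mode

    public DualModeTeacherSketch(Func<double[], double[]> forwardFunc)
    {
        _forwardFunc = forwardFunc ?? throw new ArgumentNullException(nameof(forwardFunc));
    }

    public DualModeTeacherSketch(ICompilableModelSketch model)
    {
        _model = model ?? throw new ArgumentNullException(nameof(model));
    }

    // Delegate mode is never JIT-compilable; model mode reports what the model supports.
    public bool SupportsJitCompilation => _model != null && _model.SupportsJitCompilation;

    // Both modes return raw logits; only the source of the forward pass differs.
    public double[] GetLogits(double[] input) =>
        _model != null ? _model.Forward(input) : _forwardFunc!(input);
}

// Toy compilable model used for illustration: doubles every input value.
public sealed class DoublingModelSketch : ICompilableModelSketch
{
    public double[] Forward(double[] input)
    {
        var output = new double[input.Length];
        for (int i = 0; i < input.Length; i++) output[i] = 2 * input[i];
        return output;
    }

    public bool SupportsJitCompilation => true;
}
```

Keeping the delegate and the model in mutually exclusive fields lets one class serve both modes, with SupportsJitCompilation revealing which one is active.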
Constructors
TransformerTeacherModel(IJitCompilable<T>, int, int)
Initializes a new instance of the TransformerTeacherModel class using a JIT-compilable model.
public TransformerTeacherModel(IJitCompilable<T> jitCompilableModel, int inputDimension, int outputDimension)
Parameters
jitCompilableModel (IJitCompilable<T>): A JIT-compilable model that performs the forward pass.
inputDimension (int): The number of input dimensions.
outputDimension (int): The number of output dimensions.
Remarks
JIT Support: This constructor enables JIT compilation when the underlying model supports it. Prefer it when inference performance matters.
Exceptions
- ArgumentNullException: Thrown when jitCompilableModel is null.
- ArgumentOutOfRangeException: Thrown when inputDimension or outputDimension is not positive.
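Both constructors document the same argument contract. The sketch below shows one way to express those guard clauses; TeacherArgumentGuardsSketch.Validate is a hypothetical helper, and the order of the checks (null source first, then dimensions) is an assumption rather than documented behavior.

```csharp
#nullable enable
using System;

// Hypothetical helper mirroring the documented exception contract of both constructors.
// The class name, method name and check order are illustrative assumptions.
public static class TeacherArgumentGuardsSketch
{
    public static void Validate(object? forwardSource, int inputDimension, int outputDimension, string sourceName)
    {
        // A null model or delegate -> ArgumentNullException naming that parameter.
        if (forwardSource is null)
            throw new ArgumentNullException(sourceName);

        // Zero or negative dimensions -> ArgumentOutOfRangeException.
        if (inputDimension <= 0)
            throw new ArgumentOutOfRangeException(nameof(inputDimension), inputDimension, "Input dimension must be positive.");
        if (outputDimension <= 0)
            throw new ArgumentOutOfRangeException(nameof(outputDimension), outputDimension, "Output dimension must be positive.");
    }
}
```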
TransformerTeacherModel(Func<Vector<T>, Vector<T>>, int, int)
Initializes a new instance of the TransformerTeacherModel class using a function delegate.
public TransformerTeacherModel(Func<Vector<T>, Vector<T>> forwardFunc, int inputDimension, int outputDimension)
Parameters
forwardFunc (Func<Vector<T>, Vector<T>>): A function that performs the forward pass and returns logits.
inputDimension (int): The number of input dimensions.
outputDimension (int): The number of output dimensions.
Remarks
Note: This constructor creates a teacher that is not JIT-compilable. For JIT support, use the constructor that accepts an IJitCompilable<T> model.
Exceptions
- ArgumentNullException: Thrown when forwardFunc is null.
- ArgumentOutOfRangeException: Thrown when inputDimension or outputDimension is not positive.
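In function delegate mode, any function that maps an input vector to a logit vector can serve as the forward pass. The sketch below builds such a function from a hypothetical linear projection (logits = Wx + b) standing in for a real transformer, with double[] in place of Vector<T>; LinearForwardSketch is an illustrative name only.

```csharp
using System;

// Hypothetical stand-in forward pass for delegate mode. A real transformer would go here;
// this linear projection (logits = W * x + b) only illustrates the Func<input, logits> shape.
public static class LinearForwardSketch
{
    public static Func<double[], double[]> Create(double[,] weights, double[] bias)
    {
        int outputDimension = weights.GetLength(0); // rows = number of logits
        int inputDimension = weights.GetLength(1);  // columns = input features
        if (bias.Length != outputDimension)
            throw new ArgumentException("Bias length must equal the output dimension.", nameof(bias));

        return input =>
        {
            if (input.Length != inputDimension)
                throw new ArgumentException($"Expected {inputDimension} inputs but got {input.Length}.", nameof(input));

            var logits = new double[outputDimension];
            for (int o = 0; o < outputDimension; o++)
            {
                double sum = bias[o];
                for (int i = 0; i < inputDimension; i++)
                    sum += weights[o, i] * input[i];
                logits[o] = sum; // raw, unnormalized score for output o
            }
            return logits;
        };
    }
}
```

A delegate like this, paired with inputDimension 3 and outputDimension 2 for a 2×3 weight matrix, is the kind of argument the function-delegate constructor accepts; the resulting teacher would not be JIT-compilable.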
Properties
OutputDimension
Gets the output dimension.
public override int OutputDimension { get; }
Property Value
- int
SupportsJitCompilation
Gets whether this teacher supports JIT compilation.
public override bool SupportsJitCompilation { get; }
Property Value
- bool
true if constructed with a JIT-compilable model that supports JIT compilation; otherwise, false.
Methods
ExportComputationGraph(List<ComputationNode<T>>)
Exports the computation graph for JIT compilation.
public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
Parameters
inputNodes (List<ComputationNode<T>>): The list to populate with input computation nodes.
Returns
- ComputationNode<T>
The output computation node.
Exceptions
- NotSupportedException: Thrown when the instance was constructed with a function delegate instead of a JIT-compilable model.
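The NotSupportedException follows from delegate mode: an opaque Func<> carries no graph that could be exported. A minimal sketch of that guard is below, assuming (this page does not confirm it) that model mode simply forwards the call to the wrapped model. GraphNodeSketch, IGraphExportableSketch, ExportingTeacherSketch and SingleNodeModelSketch are hypothetical stand-ins for ComputationNode<T> and the JIT-compilable model.

```csharp
#nullable enable
using System;
using System.Collections.Generic;

// Hypothetical minimal graph node; AiDotNet's ComputationNode<T> is richer.
public sealed class GraphNodeSketch
{
    public GraphNodeSketch(string name) => Name = name;
    public string Name { get; }
}

// Hypothetical stand-in for a JIT-compilable model that can export its graph.
public interface IGraphExportableSketch
{
    GraphNodeSketch ExportComputationGraph(List<GraphNodeSketch> inputNodes);
}

public sealed class ExportingTeacherSketch
{
    private readonly IGraphExportableSketch? _model; // null models function-delegate mode

    public ExportingTeacherSketch(IGraphExportableSketch? model) => _model = model;

    public GraphNodeSketch ExportComputationGraph(List<GraphNodeSketch> inputNodes)
    {
        // A delegate has no inspectable graph, so delegate mode refuses to export.
        if (_model is null)
            throw new NotSupportedException(
                "This teacher wraps a function delegate; construct it with a JIT-compilable model to export a graph.");

        // Assumed model-mode behavior: the model fills inputNodes and returns its output node.
        return _model.ExportComputationGraph(inputNodes);
    }
}

// Toy model whose graph is one input node feeding one output node.
public sealed class SingleNodeModelSketch : IGraphExportableSketch
{
    public GraphNodeSketch ExportComputationGraph(List<GraphNodeSketch> inputNodes)
    {
        inputNodes.Add(new GraphNodeSketch("input"));
        return new GraphNodeSketch("logits");
    }
}
```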
GetLogits(Vector<T>)
Gets logits from the transformer model.
public override Vector<T> GetLogits(Vector<T> input)
Parameters
input (Vector<T>): The input data.
Returns
- Vector<T>
Raw logits from the transformer.
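Distillation strategies commonly turn raw teacher logits into soft targets with a temperature-scaled softmax, p_i = exp(z_i / T) / Σ_j exp(z_j / T). The helper below is a sketch of that step, not a member of TransformerTeacherModel<T>; SoftTargetsSketch is a hypothetical name and double[] stands in for Vector<T>.

```csharp
using System;
using System.Linq;

// Illustrative helper (not part of AiDotNet): softens raw logits with temperature T.
public static class SoftTargetsSketch
{
    public static double[] Softmax(double[] logits, double temperature)
    {
        if (temperature <= 0)
            throw new ArgumentOutOfRangeException(nameof(temperature), "Temperature must be positive.");

        // Subtract the largest scaled logit for numerical stability; probabilities are unchanged.
        double max = logits.Max() / temperature;
        var exps = logits.Select(z => Math.Exp(z / temperature - max)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }
}
```

Raising T flattens the distribution, so a student sees more of the teacher's relative preferences among non-top classes than a hard argmax would reveal.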