Class TransformerTeacherModel<T>

Namespace
AiDotNet.KnowledgeDistillation.Teachers
Assembly
AiDotNet.dll

Transformer-based teacher model that provides logits from transformer architectures.

public class TransformerTeacherModel<T> : TeacherModelBase<Vector<T>, Vector<T>, T>, ITeacherModel<Vector<T>, Vector<T>>, IJitCompilable<T>

Type Parameters

T

The numeric type for calculations (e.g., double, float).

Inheritance
TeacherModelBase<Vector<T>, Vector<T>, T>
TransformerTeacherModel<T>
Implements
ITeacherModel<Vector<T>, Vector<T>>
IJitCompilable<T>

Remarks

Architecture Note: This class supports two construction modes:

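  • Function delegate mode: Uses a Func<Vector<T>, Vector<T>> for the forward pass. The resulting teacher is not JIT-compilable.
  • IJitCompilable mode: Wraps an IJitCompilable<T> model for the forward pass. The resulting teacher is JIT-compilable when the wrapped model supports it.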

For attention-based distillation strategies that need attention weights, implement a custom IDistillationStrategy that can extract attention from the underlying model.

Constructors

TransformerTeacherModel(IJitCompilable<T>, int, int)

Initializes a new instance of the TransformerTeacherModel class using a JIT-compilable model.

public TransformerTeacherModel(IJitCompilable<T> jitCompilableModel, int inputDimension, int outputDimension)

Parameters

jitCompilableModel IJitCompilable<T>

A JIT-compilable model that performs the forward pass.

inputDimension int

The number of input dimensions.

outputDimension int

The number of output dimensions.

Remarks

JIT Support: This constructor enables JIT compilation when the underlying model supports it. Prefer it over the function delegate constructor when inference performance matters.

Exceptions

ArgumentNullException

Thrown when jitCompilableModel is null.

ArgumentOutOfRangeException

Thrown when inputDimension or outputDimension is not positive.
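
Example

A minimal sketch of wrapping an existing JIT-compilable model, assuming the AiDotNet package is referenced and that IJitCompilable<T> lives in the AiDotNet.Interfaces namespace (an assumption, since this page does not state it). TeacherFactory and FromJitModel are hypothetical names used only for illustration.

```csharp
using System;
using AiDotNet.Interfaces; // assumed namespace for IJitCompilable<T>
using AiDotNet.KnowledgeDistillation.Teachers;

public static class TeacherFactory
{
    // Wraps an already-built JIT-compilable model as a teacher.
    // The constructor is documented to throw ArgumentNullException for a null
    // model and ArgumentOutOfRangeException for non-positive dimensions.
    public static TransformerTeacherModel<double> FromJitModel(
        IJitCompilable<double> model, int inputDimension, int outputDimension)
    {
        var teacher = new TransformerTeacherModel<double>(model, inputDimension, outputDimension);

        // JIT support still depends on the wrapped model, so check before relying on it.
        if (!teacher.SupportsJitCompilation)
        {
            Console.WriteLine("Wrapped model does not support JIT compilation.");
        }

        return teacher;
    }
}
```

Checking SupportsJitCompilation after construction matters because this constructor only makes JIT compilation possible; it does not guarantee it.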

TransformerTeacherModel(Func<Vector<T>, Vector<T>>, int, int)

Initializes a new instance of the TransformerTeacherModel class using a function delegate.

public TransformerTeacherModel(Func<Vector<T>, Vector<T>> forwardFunc, int inputDimension, int outputDimension)

Parameters

forwardFunc Func<Vector<T>, Vector<T>>

Function that performs the forward pass and returns logits.

inputDimension int

The number of input dimensions.

outputDimension int

The number of output dimensions.

Remarks

Note: This constructor creates a non-JIT-compilable teacher. For JIT support, use the constructor that accepts an IJitCompilable model.

Exceptions

ArgumentNullException

Thrown when forwardFunc is null.

ArgumentOutOfRangeException

Thrown when inputDimension or outputDimension is not positive.
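
Example

A minimal sketch of delegate-mode construction, assuming the AiDotNet package is referenced and that Vector<T> lives in the AiDotNet.LinearAlgebra namespace and can be built from an array (both are assumptions, since this page does not state them; adjust to your installed version). The forward function is a hypothetical stand-in that returns fixed logits; a real one would call into a transformer.

```csharp
using System;
using AiDotNet.KnowledgeDistillation.Teachers;
using AiDotNet.LinearAlgebra; // assumed namespace for Vector<T>

// Hypothetical forward pass: ignores its input and returns three fixed logits.
Func<Vector<double>, Vector<double>> forward =
    input => new Vector<double>(new[] { 2.0, 0.5, -1.0 });

var teacher = new TransformerTeacherModel<double>(
    forward, inputDimension: 4, outputDimension: 3);

Console.WriteLine(teacher.OutputDimension);        // the output dimension of this teacher
Console.WriteLine(teacher.SupportsJitCompilation); // documented to be false in delegate mode

// Non-positive dimensions are documented to throw ArgumentOutOfRangeException.
try
{
    _ = new TransformerTeacherModel<double>(forward, inputDimension: 0, outputDimension: 3);
}
catch (ArgumentOutOfRangeException ex)
{
    Console.WriteLine($"Rejected argument: {ex.ParamName}");
}
```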

Properties

OutputDimension

Gets the output dimension.

public override int OutputDimension { get; }

Property Value

int

SupportsJitCompilation

Gets whether this teacher supports JIT compilation.

public override bool SupportsJitCompilation { get; }

Property Value

bool

true if constructed with a JIT-compilable model that supports JIT; otherwise, false.

Methods

ExportComputationGraph(List<ComputationNode<T>>)

Exports the computation graph for JIT compilation.

public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)

Parameters

inputNodes List<ComputationNode<T>>

List to populate with input computation nodes.

Returns

ComputationNode<T>

The output computation node.

Exceptions

NotSupportedException

Thrown when constructed with a function delegate instead of a JIT-compilable model.
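
Example

Because delegate-mode teachers throw NotSupportedException, callers may want to check SupportsJitCompilation before exporting. The following is a sketch under assumptions: ComputationNode<T> is taken to live in the AiDotNet.Autodiff namespace, and TeacherGraphExport and TryExport are hypothetical helper names, not part of the library.

```csharp
using System.Collections.Generic;
using AiDotNet.Autodiff; // assumed namespace for ComputationNode<T>
using AiDotNet.KnowledgeDistillation.Teachers;

public static class TeacherGraphExport
{
    // Attempts to export the teacher's computation graph.
    // Returns false, without calling ExportComputationGraph, when the teacher
    // reports that it does not support JIT compilation.
    public static bool TryExport<T>(
        TransformerTeacherModel<T> teacher,
        List<ComputationNode<T>> inputNodes,
        out ComputationNode<T> outputNode)
    {
        if (!teacher.SupportsJitCompilation)
        {
            outputNode = null!; // not meaningful when the method returns false
            return false;
        }

        // inputNodes is populated with the graph's input nodes by this call.
        outputNode = teacher.ExportComputationGraph(inputNodes);
        return true;
    }
}
```

The guard avoids using exceptions for control flow; catching NotSupportedException at the call site is an equally valid alternative.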

GetLogits(Vector<T>)

Gets logits from the transformer model.

public override Vector<T> GetLogits(Vector<T> input)

Parameters

input Vector<T>

The input data.

Returns

Vector<T>

Raw logits from the transformer.
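
Example

Raw logits are unnormalized scores, so they usually need a softmax before they can be read as probabilities. The sketch below shows one way to do that. It assumes that Vector<T> lives in AiDotNet.LinearAlgebra and exposes a Length property and an integer indexer; LogitUtils and Softmax are hypothetical names, not part of the library.

```csharp
using System;
using AiDotNet.LinearAlgebra; // assumed namespace for Vector<T>

public static class LogitUtils
{
    // Numerically stable softmax: subtracts the maximum logit before
    // exponentiating so that large logits do not overflow.
    public static double[] Softmax(Vector<double> logits)
    {
        int n = logits.Length; // assumes Vector<T> exposes Length and an indexer

        double max = double.NegativeInfinity;
        for (int i = 0; i < n; i++)
        {
            max = Math.Max(max, logits[i]);
        }

        var probabilities = new double[n];
        double sum = 0.0;
        for (int i = 0; i < n; i++)
        {
            probabilities[i] = Math.Exp(logits[i] - max);
            sum += probabilities[i];
        }

        for (int i = 0; i < n; i++)
        {
            probabilities[i] /= sum;
        }

        return probabilities;
    }
}
```

Given a teacher and an input vector, a call would look like LogitUtils.Softmax(teacher.GetLogits(input)).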