Class TeacherModelBase<TInput, TOutput, T>

Namespace
AiDotNet.KnowledgeDistillation
Assembly
AiDotNet.dll

Abstract base class for teacher models used in knowledge distillation. Provides common functionality and utilities for teacher model implementations.

public abstract class TeacherModelBase<TInput, TOutput, T> : ITeacherModel<TInput, TOutput>, IJitCompilable<T>

Type Parameters

TInput

The input data type (e.g., Vector, Matrix, Tensor).

TOutput

The output data type (typically logits as Vector or Matrix).

T

The numeric type for calculations (e.g., double, float).

Inheritance
TeacherModelBase<TInput, TOutput, T>
Implements
ITeacherModel<TInput, TOutput>
IJitCompilable<T>

Remarks

For Beginners: This base class provides common functionality that all teacher models need, such as numeric operations and input validation. It's a lightweight foundation that derived classes build upon.

Why use a base class?

- **Code Reuse**: Common utilities such as numeric operations are available to all implementations.
- **Consistency**: All teachers have access to the same helper methods.
- **Extensibility**: New teacher types inherit core functionality automatically.
- **Maintainability**: Updates to common utilities benefit all implementations.

Architecture Note: This base class is intentionally minimal. Complex operations like temperature scaling are handled by distillation strategies, not teachers. Teachers are responsible only for providing raw logits.

Constructors

TeacherModelBase()

Initializes the base teacher model and sets up numeric operations.

protected TeacherModelBase()

Fields

NumOps

Numeric operations for the specified type T.

protected readonly INumericOperations<T> NumOps

Field Value

INumericOperations<T>

Remarks

Provides type-specific arithmetic operations (add, multiply, etc.) that work with any numeric type (double, float, decimal, etc.).
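The following is a minimal sketch of the general pattern behind a numeric-operations field: generic code is written once against an abstraction and works for any numeric type that has an implementation. The names INumOpsSketch, DoubleOpsSketch, and NumOpsDemo are hypothetical stand-ins for illustration; the actual members of INumericOperations<T> are not listed on this page and may differ.

```csharp
using System;

// Hypothetical stand-in for a numeric-operations abstraction (not the AiDotNet interface).
public interface INumOpsSketch<T>
{
    T Zero { get; }
    T Add(T a, T b);
    T Multiply(T a, T b);
}

// One small implementation per numeric type.
public sealed class DoubleOpsSketch : INumOpsSketch<double>
{
    public double Zero => 0.0;
    public double Add(double a, double b) => a + b;
    public double Multiply(double a, double b) => a * b;
}

public static class NumOpsDemo
{
    // Generic dot product written once against the abstraction.
    public static T Dot<T>(INumOpsSketch<T> ops, T[] x, T[] y)
    {
        if (x.Length != y.Length)
            throw new ArgumentException("Vectors must have the same length.");

        T sum = ops.Zero;
        for (int i = 0; i < x.Length; i++)
            sum = ops.Add(sum, ops.Multiply(x[i], y[i]));
        return sum;
    }
}
```

Supporting another numeric type in this sketch only requires another implementation of the interface; the generic Dot method stays unchanged.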

Properties

OutputDimension

Gets the number of output dimensions (e.g., number of classes for classification).

public abstract int OutputDimension { get; }

Property Value

int

Remarks

For Implementers: Return the size of the output vector. For example, a 10-class classifier should return 10.

SupportsJitCompilation

Gets whether this teacher model supports JIT compilation.

public abstract bool SupportsJitCompilation { get; }

Property Value

bool

true if the teacher model can be JIT compiled; otherwise, false.

Remarks

Teacher models that wrap other models should delegate to the wrapped model's JIT support. Teacher models using function delegates or cached predictions may not support JIT.

For Implementers: Return true if your teacher model can export its computation as a graph. Models wrapping IJitCompilable implementations should return the wrapped model's SupportsJitCompilation value.

Methods

ApplyTemperatureSoftmax(TOutput, double)

Applies temperature-scaled softmax to logits. The method is virtual, so subclasses override it to suit their output type (Vector, Matrix, etc.).

protected virtual TOutput ApplyTemperatureSoftmax(TOutput logits, double temperature)

Parameters

logits TOutput

Raw model outputs.

temperature double

Temperature for scaling.

Returns

TOutput

Probability distribution.
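For reference, temperature-scaled softmax is conventionally defined as below, where z_i are the logits and T is the temperature. This is the standard formulation and is assumed here; this page does not spell out the exact computation the method performs.

```latex
p_i = \frac{\exp(z_i / T)}{\sum_{j} \exp(z_j / T)}, \qquad T > 0
```

With T = 1 this reduces to ordinary softmax. Larger values of T flatten the distribution, and smaller values sharpen it.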

CheckWrappedModelJitSupport(ITeacherModel<TInput, TOutput>)

Checks if a wrapped teacher model supports JIT compilation.

protected static bool CheckWrappedModelJitSupport(ITeacherModel<TInput, TOutput> wrappedModel)

Parameters

wrappedModel ITeacherModel<TInput, TOutput>

The wrapped teacher model to check.

Returns

bool

true if the wrapped model implements IJitCompilable and supports JIT; otherwise, false.

Remarks

Use this helper method in derived classes that wrap another ITeacherModel to implement the SupportsJitCompilation property.

Example:

public override bool SupportsJitCompilation => CheckWrappedModelJitSupport(_baseTeacher);

DelegateJitExport(ITeacherModel<TInput, TOutput>, List<ComputationNode<T>>, string)

Delegates JIT compilation export to a wrapped teacher model.

protected static ComputationNode<T> DelegateJitExport(ITeacherModel<TInput, TOutput> wrappedModel, List<ComputationNode<T>> inputNodes, string wrapperTypeName)

Parameters

wrappedModel ITeacherModel<TInput, TOutput>

The wrapped teacher model to delegate to.

inputNodes List<ComputationNode<T>>

List to populate with input computation nodes.

wrapperTypeName string

Name of the wrapper type (for error messages).

Returns

ComputationNode<T>

The output computation node from the wrapped model.

Remarks

Use this helper method in derived classes that wrap another ITeacherModel to implement the ExportComputationGraph method.

Example:

public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
    => DelegateJitExport(_baseTeacher, inputNodes, nameof(AdaptiveTeacherModel<T>));

Exceptions

NotSupportedException

Thrown when the wrapped model does not implement IJitCompilable or does not support JIT.
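The delegation pattern described by CheckWrappedModelJitSupport and DelegateJitExport can be sketched as follows. This is an illustration under stated assumptions, not the library's implementation: NodeSketch, ITeacherSketch, IJitCompilableSketch, and the two sample teachers are hypothetical stand-ins, and the exception message text is invented for the example.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-ins; the real AiDotNet types have richer members.
public sealed class NodeSketch
{
    public string Name { get; }
    public NodeSketch(string name) => Name = name;
}

public interface ITeacherSketch { }

public interface IJitCompilableSketch
{
    bool SupportsJitCompilation { get; }
    NodeSketch ExportComputationGraph(List<NodeSketch> inputNodes);
}

public static class JitDelegationSketch
{
    // True only when the wrapped model is JIT-capable and reports support.
    public static bool CheckWrappedModelJitSupport(ITeacherSketch wrapped)
        => wrapped is IJitCompilableSketch jit && jit.SupportsJitCompilation;

    // Forwards the export to the wrapped model, or throws if it cannot be exported.
    public static NodeSketch DelegateJitExport(
        ITeacherSketch wrapped, List<NodeSketch> inputNodes, string wrapperTypeName)
    {
        if (wrapped is IJitCompilableSketch jit && jit.SupportsJitCompilation)
            return jit.ExportComputationGraph(inputNodes);

        throw new NotSupportedException(
            $"{wrapperTypeName} cannot be JIT compiled because the wrapped teacher does not support it.");
    }
}

// A wrapped teacher that supports export: it registers one input node and returns an output node.
public sealed class JitTeacherSketch : ITeacherSketch, IJitCompilableSketch
{
    public bool SupportsJitCompilation => true;

    public NodeSketch ExportComputationGraph(List<NodeSketch> inputNodes)
    {
        inputNodes.Add(new NodeSketch("input"));
        return new NodeSketch("logits");
    }
}

// A wrapped teacher with no JIT support at all.
public sealed class PlainTeacherSketch : ITeacherSketch { }
```

In this sketch the wrapper never builds a graph itself. It either forwards the call or fails with a message that names the wrapper type, which is the role the wrapperTypeName parameter plays above.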

ExportComputationGraph(List<ComputationNode<T>>)

Exports the teacher model's computation graph for JIT compilation.

public abstract ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)

Parameters

inputNodes List<ComputationNode<T>>

List to populate with input computation nodes.

Returns

ComputationNode<T>

The output computation node representing the teacher's logits.

Remarks

For teacher models that wrap other models, this should delegate to the wrapped model's ExportComputationGraph method. For models using function delegates, this may not be supported and should throw NotSupportedException.

For Implementers: If your teacher wraps a model implementing IJitCompilable, delegate to that model's ExportComputationGraph. Otherwise, implement the computation graph directly or throw NotSupportedException with a clear explanation.

Exceptions

NotSupportedException

Thrown when the teacher model does not support JIT compilation.

GetLogits(TInput)

Gets the teacher's raw logits (pre-softmax outputs) for the given input.

public abstract TOutput GetLogits(TInput input)

Parameters

input TInput

The input data to process.

Returns

TOutput

Raw logits before applying softmax.

Remarks

For Implementers: Override this method to extract logits from your specific model type. Ensure you return pre-activation outputs, not probabilities.

Important: Temperature scaling and softmax conversion are handled by the distillation strategy, not by the teacher. Just return raw logits here.
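As an illustration of "raw logits only", here is a minimal sketch of a teacher-like class whose GetLogits returns one unnormalized score per class. LinearTeacherSketch is hypothetical and uses plain double[] arrays in place of the library's Vector type. It does not derive from TeacherModelBase, so that the example stays self-contained.

```csharp
using System;
using System.Linq;

// Hypothetical teacher: a fixed linear scorer with one weight row per class.
public sealed class LinearTeacherSketch
{
    private readonly double[][] _weights;

    public LinearTeacherSketch(double[][] weights)
        => _weights = weights ?? throw new ArgumentNullException(nameof(weights));

    // Number of classes, which is also the length of every logit vector returned below.
    public int OutputDimension => _weights.Length;

    // Raw logits: one dot product per class, with no softmax and no temperature applied.
    public double[] GetLogits(double[] input)
    {
        if (input == null) throw new ArgumentNullException(nameof(input));
        return _weights
            .Select(row => row.Zip(input, (w, x) => w * x).Sum())
            .ToArray();
    }
}
```

The returned values are unbounded scores rather than probabilities. Converting them into a softened distribution is left to whatever consumes the teacher.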

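Softmax(Vector<T>, double)

Computes the softmax of a vector of logits, with optional temperature scaling.

protected virtual Vector<T> Softmax(Vector<T> logits, double temperature = 1)

Parameters

logits Vector<T>

Raw model outputs.

temperature double

Temperature for scaling. Defaults to 1.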

Returns

Vector<T>
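A numerically stable way to compute a temperature-scaled softmax is sketched below. It assumes plain double[] arrays rather than Vector<T>, and the argument checks and max-subtraction step are choices made for this example. Whether the library's method does the same is not stated on this page.

```csharp
using System;
using System.Linq;

public static class SoftmaxSketch
{
    // Temperature-scaled softmax over a plain double[] (stand-in for Vector<T>).
    public static double[] Softmax(double[] logits, double temperature = 1.0)
    {
        if (logits == null)
            throw new ArgumentNullException(nameof(logits));
        if (logits.Length == 0)
            throw new ArgumentException("Logits must not be empty.", nameof(logits));
        if (temperature <= 0)
            throw new ArgumentOutOfRangeException(nameof(temperature), "Temperature must be positive.");

        // Divide by the temperature, then subtract the maximum so exp() cannot overflow.
        double[] scaled = logits.Select(z => z / temperature).ToArray();
        double max = scaled.Max();
        double[] exps = scaled.Select(z => Math.Exp(z - max)).ToArray();

        // Normalize so the outputs sum to 1.
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }
}
```

Subtracting the maximum does not change the result, because the common factor cancels between numerator and denominator. It only keeps the intermediate exponentials in a safe range.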

ThrowJitNotSupported(string, string)

Throws a standardized NotSupportedException for teacher models that cannot support JIT compilation.

protected static ComputationNode<T> ThrowJitNotSupported(string teacherTypeName, string reason)

Parameters

teacherTypeName string

Name of the teacher type.

reason string

Reason why JIT is not supported.

Returns

ComputationNode<T>

Never returns (always throws).

Remarks

Use this helper method in derived classes that cannot support JIT compilation.

Example:

public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
    => ThrowJitNotSupported(nameof(PretrainedTeacherModel<T>),
        "it uses a function delegate which cannot be exported as a computation graph");

Exceptions

NotSupportedException

Always thrown.
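The reason a helper that always throws still declares a return type can be seen in the sketch below: the declared type lets the call serve as the body of an expression-bodied override. GraphNodeSketch, TeacherSketchBase, and DelegateTeacherSketch are hypothetical stand-ins, and the message format is an assumption made for this example.

```csharp
using System;

// Hypothetical stand-in for ComputationNode<T>.
public sealed class GraphNodeSketch { }

public abstract class TeacherSketchBase
{
    public abstract GraphNodeSketch ExportComputationGraph();

    // Has a return type so callers can write "=> ThrowJitNotSupported(...)",
    // even though control never returns normally.
    protected static GraphNodeSketch ThrowJitNotSupported(string teacherTypeName, string reason)
        => throw new NotSupportedException(
            $"{teacherTypeName} does not support JIT compilation because {reason}.");
}

public sealed class DelegateTeacherSketch : TeacherSketchBase
{
    // The override satisfies the compiler's return-type check and always throws at run time.
    public override GraphNodeSketch ExportComputationGraph()
        => ThrowJitNotSupported(nameof(DelegateTeacherSketch), "it uses a function delegate");
}
```

Centralizing the throw in one helper also keeps the error message format uniform across every teacher type that opts out of JIT.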

ValidateInput(TInput?, string)

Validates that the input is not null.

protected void ValidateInput(TInput? input, string paramName = "input")

Parameters

input TInput

Input to validate.

paramName string

Parameter name for exception message.

Remarks

Helper method for derived classes to validate inputs before processing.
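A null check of this kind might look like the sketch below. This page does not document which exception ValidateInput throws, so the use of ArgumentNullException is an assumption, and ValidationSketch is a hypothetical stand-in rather than the library's code.

```csharp
using System;

public static class ValidationSketch
{
    // Assumption: a null input is reported with ArgumentNullException carrying the parameter name.
    public static void ValidateInput<TInput>(TInput input, string paramName = "input")
    {
        if (input == null)
            throw new ArgumentNullException(paramName);
    }
}
```

A derived class would typically call such a helper at the top of GetLogits, so that a null input fails with the caller's parameter name instead of a later NullReferenceException.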