Class QuantizedTeacherModel<T>
- Namespace: AiDotNet.KnowledgeDistillation.Teachers
- Assembly: AiDotNet.dll
Quantized teacher model with reduced precision for efficient deployment.
public class QuantizedTeacherModel<T> : TeacherModelBase<Vector<T>, Vector<T>, T>, ITeacherModel<Vector<T>, Vector<T>>, IJitCompilable<T>
Type Parameters
T: The numeric element type used for the input and output vectors, the scale, and the zero point.
- Inheritance: TeacherModelBase<Vector<T>, Vector<T>, T> → QuantizedTeacherModel<T>
- Implements: ITeacherModel<Vector<T>, Vector<T>>, IJitCompilable<T>
Remarks
For Beginners: Quantization reduces the numerical precision of model weights and activations so they use fewer bits (e.g., 8-bit integers instead of 32-bit floating-point values). This enables:
- Smaller model size
- Faster inference on hardware with integer support
- Reduced memory bandwidth requirements
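JIT Support: When constructed with an IJitCompilable base model, this teacher supports JIT compilation using FakeQuantization with a Straight-Through Estimator (STE). FakeQuantization simulates quantization in the forward pass by rounding values onto the quantized grid and mapping them back to real values; the STE treats that rounding as the identity in the backward pass, so gradients can flow through the quantized model during training.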
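The FakeQuantization and STE idea can be sketched in plain C#. This is a minimal illustration under stated assumptions, not AiDotNet's implementation: `FakeQuantSketch`, `Forward`, and `Backward` are hypothetical names, `double` stands in for the generic `T`, and the integer bounds (±(2^(bits-1) - 1) for symmetric, [0, 2^bits - 1] for asymmetric) are assumed rather than taken from the library.

```csharp
using System;

// Hypothetical helper (not part of AiDotNet): fake quantization with a
// straight-through estimator, using double in place of the generic T.
public static class FakeQuantSketch
{
    // Assumed integer grid: symmetric uses [-(2^(bits-1) - 1), 2^(bits-1) - 1];
    // asymmetric uses [0, 2^bits - 1].
    public static (double Min, double Max) QuantRange(int bits, bool symmetric)
    {
        if (symmetric)
        {
            double max = Math.Pow(2, bits - 1) - 1;
            return (-max, max);
        }
        return (0.0, Math.Pow(2, bits) - 1);
    }

    // Forward pass: round onto the integer grid, clamp, then map back to real
    // values, so downstream code still sees doubles but with quantization error.
    public static double Forward(double x, double scale, double zeroPoint, int bits, bool symmetric)
    {
        var (qMin, qMax) = QuantRange(bits, symmetric);
        double q = Math.Clamp(Math.Round(x / scale + zeroPoint), qMin, qMax);
        return (q - zeroPoint) * scale;
    }

    // Backward pass (STE): treat Round() as the identity, so the upstream gradient
    // passes through unchanged inside the representable range and is zero where
    // the forward pass clamped.
    public static double Backward(double x, double upstreamGrad, double scale, double zeroPoint, int bits, bool symmetric)
    {
        var (qMin, qMax) = QuantRange(bits, symmetric);
        double lo = (qMin - zeroPoint) * scale;
        double hi = (qMax - zeroPoint) * scale;
        return x >= lo && x <= hi ? upstreamGrad : 0.0;
    }
}
```

With the documented default scale of 1/128 for 8 bits, this sketch represents values in [-0.9921875, 0.9921875]: 0.3 is rounded to 0.296875 in the forward pass while its gradient passes through unchanged, and 2.0 is clamped to 0.9921875 and receives a zero gradient.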
Constructors
QuantizedTeacherModel(IJitCompilable<T>, int, int, T?, T?, bool)
Initializes a new instance of QuantizedTeacherModel wrapping a JIT-compilable model.
public QuantizedTeacherModel(IJitCompilable<T> jitCompilableBase, int outputDimension, int quantizationBits = 8, T? scale = default, T? zeroPoint = default, bool symmetric = true)
Parameters
jitCompilableBase (IJitCompilable<T>): The JIT-compilable base model to quantize.
outputDimension (int): Output dimension of the model.
quantizationBits (int): Number of bits for quantization (1-32).
scale (T): Scale factor for quantization. If default, uses 1/(2^(bits-1)), e.g. 1/128 for 8 bits.
zeroPoint (T): Zero point for asymmetric quantization. Default is 0.
symmetric (bool): Whether to use symmetric quantization (centered at 0).
Remarks
JIT Support: This constructor enables JIT compilation using FakeQuantization with a Straight-Through Estimator (STE). Because the scale and zero point are fixed at construction time, the graph can be compiled statically.
Symmetric vs Asymmetric:
- Symmetric: Range is [-max, max], zero point is 0. Good for weights.
- Asymmetric: Range is [min, max], zero point may be non-zero. Good for activations that are not centered at zero (for example, non-negative ReLU outputs).
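One common way to derive the parameters for each mode from an observed range is sketched below. Whether QuantizedTeacherModel computes them this way is not stated here, so treat the formulas as an assumption; `QuantParamsSketch` and its methods are hypothetical names, and `double` replaces `T`. Only `DefaultScale` reflects a documented value (the 1/(2^(bits-1)) default of the `scale` parameter).

```csharp
using System;

// Hypothetical helpers (not part of AiDotNet) that derive a scale and zero point
// from an observed range [min, max], using double in place of T.
public static class QuantParamsSketch
{
    // Default scale documented for the IJitCompilable constructor: 1 / 2^(bits-1).
    public static double DefaultScale(int bits) => 1.0 / Math.Pow(2, bits - 1);

    // Symmetric: map [-maxAbs, maxAbs] onto [-(2^(bits-1) - 1), 2^(bits-1) - 1];
    // the zero point is always 0.
    public static (double Scale, double ZeroPoint) Symmetric(double min, double max, int bits)
    {
        double maxAbs = Math.Max(Math.Abs(min), Math.Abs(max));
        double qMax = Math.Pow(2, bits - 1) - 1;
        if (maxAbs == 0.0 || qMax == 0.0) return (1.0, 0.0); // degenerate range
        return (maxAbs / qMax, 0.0);
    }

    // Asymmetric: map [min, max] (widened to include 0) onto [0, 2^bits - 1];
    // the zero point is the integer code that represents real 0.0.
    public static (double Scale, double ZeroPoint) Asymmetric(double min, double max, int bits)
    {
        min = Math.Min(min, 0.0);
        max = Math.Max(max, 0.0);
        double qMax = Math.Pow(2, bits) - 1;
        if (max == min) return (1.0, 0.0); // all-zero range
        double scale = (max - min) / qMax;
        double zeroPoint = Math.Clamp(Math.Round(-min / scale), 0.0, qMax);
        return (scale, zeroPoint);
    }
}
```

For example, `Asymmetric(-1.0, 1.55, 8)` yields a scale of about 0.01 and a zero point of 100, so real 0.0 maps exactly to integer code 100; the symmetric mode always keeps the zero point at 0 and spends part of its range on negative values that may never occur.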
QuantizedTeacherModel(ITeacherModel<Vector<T>, Vector<T>>, int)
Initializes a new instance of QuantizedTeacherModel wrapping a teacher interface.
public QuantizedTeacherModel(ITeacherModel<Vector<T>, Vector<T>> baseTeacher, int quantizationBits = 8)
Parameters
baseTeacher (ITeacherModel<Vector<T>, Vector<T>>): The base teacher model to quantize.
quantizationBits (int): Number of bits for quantization (1-32).
Remarks
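This constructor uses dynamic quantization, which finds the minimum and maximum of each batch at runtime and derives the quantization parameters from them. Because those parameters are not known ahead of time, this mode does not support JIT compilation. For JIT support, use the constructor that takes an IJitCompilable<T>.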
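The sketch below illustrates why per-batch min/max finding rules out static compilation: the scale and zero point are recomputed from every input. It assumes an asymmetric mapping onto [0, 2^bits - 1] and uses the hypothetical name `DynamicQuantSketch`; the library's dynamic mode may differ in its details.

```csharp
using System;
using System.Linq;

// Hypothetical sketch (not AiDotNet's implementation) of dynamic quantization:
// scale and zero point are recomputed from each batch's min/max at runtime.
// An asymmetric mapping onto [0, 2^bits - 1] is assumed.
public static class DynamicQuantSketch
{
    public static double[] QuantizeDequantize(double[] batch, int bits)
    {
        if (batch.Length == 0) return Array.Empty<double>();

        // Per-batch range, widened to include 0 so 0.0 is exactly representable.
        double min = Math.Min(batch.Min(), 0.0);
        double max = Math.Max(batch.Max(), 0.0);
        if (max == min) return (double[])batch.Clone(); // batch is all zeros

        double qMax = Math.Pow(2, bits) - 1;
        double scale = (max - min) / qMax;
        double zeroPoint = Math.Clamp(Math.Round(-min / scale), 0.0, qMax);

        // scale and zeroPoint depend on this batch's data, so a graph compiled
        // ahead of time cannot treat them as constants.
        return batch
            .Select(v => (Math.Clamp(Math.Round(v / scale + zeroPoint), 0.0, qMax) - zeroPoint) * scale)
            .ToArray();
    }
}
```

Two batches with different ranges get different step sizes: a batch spanning [-1, 1] at 8 bits uses a step of 2/255, while a batch spanning [-10, 10] uses a step ten times larger.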
Properties
OutputDimension
Gets the output dimension of the teacher model.
public override int OutputDimension { get; }
Property Value
- int
SupportsJitCompilation
Gets whether this teacher supports JIT compilation.
public override bool SupportsJitCompilation { get; }
Property Value
- bool
true if constructed with an IJitCompilable model that supports JIT; false if using dynamic quantization with runtime min/max finding.
Methods
ExportComputationGraph(List<ComputationNode<T>>)
Exports the computation graph for JIT compilation with FakeQuantization.
public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
Parameters
inputNodes (List<ComputationNode<T>>): List to populate with input nodes.
Returns
- ComputationNode<T>
The output computation node with quantization applied.
Remarks
When constructed with an IJitCompilable model, this method exports the base model's computation graph and wraps its output in a FakeQuantization operation. That operation uses a Straight-Through Estimator (STE) for gradients, which allows backpropagation through the quantization step.
When using dynamic quantization (per-batch min/max), JIT compilation is not supported because the quantization parameters are computed at runtime.
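To make this contract concrete, the following hypothetical analogue replaces ComputationNode<T> graphs with plain delegates. It assumes symmetric quantization with a zero point of 0 and uses `double` for `T`; `QuantizedExportSketch` and `ExportGraph` are illustrative names, not AiDotNet APIs. A constant scale lets the quantize step be composed into the exported function, while the dynamic mode throws NotSupportedException, matching the exception listed for this method.

```csharp
using System;
using System.Linq;

// Hypothetical analogue of the documented ExportComputationGraph contract.
// Plain delegates stand in for AiDotNet's ComputationNode<T> graph; symmetric
// quantization with zero point 0 and double values are assumed.
public sealed class QuantizedExportSketch
{
    private readonly Func<double[], double[]> _baseModel;
    private readonly int _bits;
    private readonly double? _scale; // null stands in for dynamic (per-batch) mode

    public QuantizedExportSketch(Func<double[], double[]> baseModel, int bits, double? scale)
    {
        _baseModel = baseModel ?? throw new ArgumentNullException(nameof(baseModel));
        _bits = bits;
        _scale = scale;
    }

    public bool SupportsJitCompilation => _scale.HasValue;

    // Static mode: compose the base model with a fake-quantize step whose scale is
    // a constant captured here. Dynamic mode: refuse, since the range would only
    // be known once a batch arrives.
    public Func<double[], double[]> ExportGraph()
    {
        if (!_scale.HasValue)
            throw new NotSupportedException("Dynamic quantization computes its parameters at runtime.");

        double scale = _scale.Value;
        double qMax = Math.Pow(2, _bits - 1) - 1;
        return input => _baseModel(input)
            .Select(v => Math.Clamp(Math.Round(v / scale), -qMax, qMax) * scale)
            .ToArray();
    }
}
```

Capturing the scale once is what makes the composed function fixed; the dynamic variant has no constant to capture, so refusing to export is the only consistent choice.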
Exceptions
- NotSupportedException
Thrown when using dynamic quantization mode.
GetLogits(Vector<T>)
Gets quantized logits from the teacher model.
public override Vector<T> GetLogits(Vector<T> input)
Parameters
input (Vector<T>): Input to the model.
Returns
- Vector<T>
Quantized logits.