Class FakeQuantizationOp
- Namespace: AiDotNet.JitCompiler.IR.Operations
- Assembly: AiDotNet.dll
Represents a fake quantization operation with Straight-Through Estimator (STE).
public class FakeQuantizationOp : IROp
- Inheritance
  - IROp → FakeQuantizationOp
- Inherited Members
Remarks
Implements differentiable quantization using the Straight-Through Estimator (STE) approach. The forward pass applies quantization, while the backward pass passes gradients through unchanged. This enables quantization-aware training and JIT compilation of quantized inference.
The operation computes:
Forward: output = round(input / scale) * scale
Backward: ∂L/∂input = ∂L/∂output (gradient passes through)
For Beginners: Quantization reduces precision (e.g., from 32-bit to 8-bit) to make models smaller and faster. The challenge is that rounding isn't differentiable.
Fake quantization solves this by:
- Forward pass: Actually quantize the values (round to discrete levels)
- Backward pass: Pretend quantization didn't happen (let gradients flow through)
This trick (Straight-Through Estimator) lets us train models that will be quantized later.
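The forward/backward pair described above can be sketched numerically. The following is a minimal NumPy illustration of the STE math only, not the AiDotNet implementation; the function names and the scale value are made up for the example:

```python
import numpy as np

def fake_quantize(x, scale):
    # Forward pass: snap each value to the nearest multiple of `scale`.
    return np.round(x / scale) * scale

def fake_quantize_grad(grad_output):
    # Backward pass (STE): pretend quantization was the identity function,
    # so the upstream gradient flows through unchanged.
    return grad_output

x = np.array([0.12, -0.37, 0.91])
y = fake_quantize(x, scale=0.25)          # values snap to [0.0, -0.25, 1.0]
g = fake_quantize_grad(np.ones_like(x))   # gradient unchanged: [1.0, 1.0, 1.0]
```

Note that the rounding step has zero gradient almost everywhere, which is exactly why the backward pass must ignore it for training to make progress.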
Properties
NumBits
Gets or sets the number of quantization bits.
public int NumBits { get; set; }
Property Value
- int
Scale
Gets or sets the scale factor for quantization. If not specified, it will be computed from min/max values.
public double? Scale { get; set; }
Property Value
- double?
Symmetric
Gets or sets whether to use symmetric quantization.
public bool Symmetric { get; set; }
Property Value
- bool
ZeroPoint
Gets or sets the zero point for asymmetric quantization.
public double ZeroPoint { get; set; }
Property Value
- double
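Since Scale may be left unset and derived from observed min/max values, the sketch below shows one standard way such parameters are computed for symmetric versus asymmetric quantization. This is a common textbook formulation, not necessarily the exact formula AiDotNet uses, and `quant_params` is a hypothetical helper:

```python
def quant_params(x_min, x_max, num_bits=8, symmetric=True):
    """Derive (scale, zero_point) from an observed value range."""
    if symmetric:
        # Range is centered on zero, so the zero point is fixed at 0
        # and the scale covers the larger of the two magnitudes.
        scale = max(abs(x_min), abs(x_max)) / (2 ** (num_bits - 1) - 1)
        zero_point = 0
    else:
        # Asymmetric: map [x_min, x_max] onto the full [0, 2^bits - 1] range,
        # with the zero point marking where real 0.0 lands on the integer grid.
        scale = (x_max - x_min) / (2 ** num_bits - 1)
        zero_point = round(-x_min / scale)
    return scale, zero_point
```

Symmetric mode wastes part of the range when the data is skewed (e.g., post-ReLU activations are non-negative), which is when the asymmetric ZeroPoint becomes useful.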
Methods
ToString()
Gets a string representation of this operation for debugging.
public override string ToString()
Returns
- string
A string describing this operation.
Remarks
The string format is: "tOutput = OpType(tInput1, tInput2, ...) : Type [Shape]"
For Beginners: This creates a readable description of the operation.
Example outputs:
- "t2 = Add(t0, t1) : Float32 [3, 4]"
- "t5 = MatMul(t3, t4) : Float32 [128, 256]"
- "t8 = ReLU(t7) : Float32 [32, 128]"
This is super helpful for debugging - you can see exactly what each operation does and what shape tensors flow through the graph.
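The format described above is straightforward to reproduce. A small sketch that builds the same kind of string (`format_op` is a hypothetical helper for illustration, not part of the library):

```python
def format_op(output_id, op_type, input_ids, dtype, shape):
    # Build "tOutput = OpType(tInput1, tInput2, ...) : Type [Shape]".
    inputs = ", ".join(f"t{i}" for i in input_ids)
    dims = ", ".join(str(d) for d in shape)
    return f"t{output_id} = {op_type}({inputs}) : {dtype} [{dims}]"

line = format_op(2, "Add", [0, 1], "Float32", [3, 4])
# line == "t2 = Add(t0, t1) : Float32 [3, 4]"
```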
Validate()
Validates that this operation is correctly formed.
public override bool Validate()
Returns
- bool
True if valid, false otherwise.
Remarks
Basic validation checks that the operation has required information. Derived classes can override to add operation-specific validation.
For Beginners: This checks that the operation makes sense.
Basic checks:
- Output ID is valid (non-negative)
- Has the right number of inputs
- Shapes are compatible
Specific operations add their own checks:
- MatMul: inner dimensions must match
- Conv2D: kernel size must be valid
- Reshape: total elements must be preserved
If validation fails, the operation can't be compiled.
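As an illustration of the two layers of checks described above (shared base validation plus operation-specific rules), here is a minimal sketch; both helper names and their signatures are assumptions for the example, not the AiDotNet API:

```python
def validate_base(output_id, inputs, expected_inputs):
    # Base checks every operation shares:
    # a valid (non-negative) output ID and the right number of inputs.
    return output_id >= 0 and len(inputs) == expected_inputs

def validate_matmul(a_shape, b_shape):
    # MatMul-specific check: inner dimensions must agree
    # ([m, k] x [k, n] is valid only when the two k's match).
    return a_shape[-1] == b_shape[0]

ok = validate_base(2, [0, 1], 2) and validate_matmul([128, 256], [256, 64])
```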