Class DistillationStrategyBase<T>
- Namespace
- AiDotNet.KnowledgeDistillation
- Assembly
- AiDotNet.dll
Abstract base class for knowledge distillation strategies. Provides common functionality for computing losses and gradients in student-teacher training.
public abstract class DistillationStrategyBase<T> : IDistillationStrategy<T>
Type Parameters
T: The numeric type for calculations (e.g., double, float).
- Inheritance
  - object
  - DistillationStrategyBase<T>
- Implements
  - IDistillationStrategy<T>
- Derived
- Inherited Members
Remarks
For Beginners: A distillation strategy defines how to measure the difference between student and teacher predictions. This base class provides common functionality that all strategies need, like temperature and alpha parameters.
Design Philosophy: Different distillation strategies focus on different aspects:
- **Response-based**: Match final outputs (logits/probabilities)
- **Feature-based**: Match intermediate layer representations
- **Relation-based**: Match relationships between samples
- **Attention-based**: Match attention patterns (for transformers)
This base class ensures all strategies handle temperature and alpha consistently, while allowing flexibility in how loss is computed.
Batch Processing: All strategies now operate on batches (Matrix<T>) for efficiency. Each row in the matrices represents one sample in the batch.
Template Method Pattern: The base class defines the structure (properties, validation), and subclasses implement the specifics (loss computation logic).
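To make the template method pattern concrete, here is a minimal sketch of a derived strategy. The class name SketchStrategy, the placeholder bodies, and the location of Matrix<T> are assumptions made for illustration; only the constructor and method signatures come from this page.

```csharp
using System;
using AiDotNet.KnowledgeDistillation;
// plus a using for the namespace that defines Matrix<T> (not shown here)

// Hypothetical response-based strategy. The base class supplies Temperature,
// Alpha, Epsilon and NumOps; the subclass supplies the loss and gradient math.
public class SketchStrategy<T> : DistillationStrategyBase<T>
{
    public SketchStrategy(double temperature = 3, double alpha = 0.3)
        : base(temperature, alpha)
    {
    }

    public override T ComputeLoss(
        Matrix<T> studentBatchOutput,
        Matrix<T> teacherBatchOutput,
        Matrix<T>? trueLabelsBatch = null)
    {
        // Shape checks provided by the base class.
        ValidateOutputDimensions(studentBatchOutput, teacherBatchOutput);
        ValidateLabelDimensions(studentBatchOutput, trueLabelsBatch);

        // ...soften both outputs with Temperature, compare them, and (when
        // labels are present) blend hard and soft losses using Alpha...
        throw new NotImplementedException("Loss computation goes here.");
    }

    public override Matrix<T> ComputeGradient(
        Matrix<T> studentBatchOutput,
        Matrix<T> teacherBatchOutput,
        Matrix<T>? trueLabelsBatch = null)
    {
        ValidateOutputDimensions(studentBatchOutput, teacherBatchOutput);
        ValidateLabelDimensions(studentBatchOutput, trueLabelsBatch);

        // ...gradient of the same loss with respect to studentBatchOutput,
        // one row per sample in the batch...
        throw new NotImplementedException("Gradient computation goes here.");
    }
}
```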
Constructors
DistillationStrategyBase(double, double)
Initializes a new instance of the distillation strategy base class.
protected DistillationStrategyBase(double temperature = 3, double alpha = 0.3)
Parameters
temperature (double): Softmax temperature (default 3.0).
alpha (double): Balance between hard and soft loss (default 0.3).
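A brief usage sketch, reusing the hypothetical SketchStrategy subclass from the remarks above; the argument values are arbitrary examples.

```csharp
// Defaults from the base constructor: temperature = 3.0, alpha = 0.3
var defaults = new SketchStrategy<double>();

// Softer teacher distributions, stronger weight on the true labels
var custom = new SketchStrategy<double>(temperature: 5.0, alpha: 0.5);
```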
Fields
Epsilon
Gets the epsilon value for numerical stability (to avoid log(0), division by zero, etc.).
protected const double Epsilon = 1E-10
Field Value
- double
NumOps
Numeric operations for the specified type T.
protected readonly INumericOperations<T> NumOps
Field Value
- INumericOperations<T>
Properties
Alpha
Gets or sets the balance parameter between hard loss and soft loss.
public double Alpha { get; set; }
Property Value
- double
Remarks
For Beginners: Alpha controls the trade-off between learning from true labels (hard loss) and learning from the teacher (soft loss).
Alpha values:
- α = 0: Only learn from teacher (pure distillation)
- α = 0.3-0.5: Balanced (recommended)
- α = 1: Only learn from labels (no distillation)
Validation: Must be between 0 and 1. Setting invalid values throws ArgumentException.
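How the two loss terms are combined is up to each derived strategy, but the weighting described above corresponds to the conventional blend sketched below (standalone code in plain doubles, not the library's API).

```csharp
static class AlphaBlendDemo
{
    // alpha weights the hard (label) loss; (1 - alpha) weights the soft (teacher) loss.
    public static double BlendLosses(double hardLoss, double softLoss, double alpha)
        => alpha * hardLoss + (1 - alpha) * softLoss;
}

// alpha = 0.0 -> pure distillation (teacher only)
// alpha = 0.3 -> default: mostly teacher, some ground truth
// alpha = 1.0 -> labels only (no distillation)
```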
Temperature
Gets or sets the temperature parameter for softening probability distributions.
public double Temperature { get; set; }
Property Value
- double
Remarks
For Beginners: Higher temperature makes predictions "softer", revealing more about the model's uncertainty and class relationships.
Temperature effects:
- T = 1: Standard predictions (sharp)
- T = 2-5: Softer predictions (recommended for distillation)
- T > 10: Very soft (may be too smooth)
Validation: Must be positive (> 0). Setting invalid values throws ArgumentException.
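The effect of temperature can be seen in a standalone temperature-scaled softmax (plain doubles, not the library's Matrix<T> type): logits are divided by T before exponentiation, which spreads probability mass across classes.

```csharp
using System;
using System.Linq;

static class TemperatureDemo
{
    public static double[] SoftmaxWithTemperature(double[] logits, double temperature)
    {
        // Divide logits by T, subtract the max for numerical stability,
        // exponentiate, then normalize to a probability distribution.
        double[] scaled = logits.Select(z => z / temperature).ToArray();
        double max = scaled.Max();
        double[] exps = scaled.Select(z => Math.Exp(z - max)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }
}

// logits = { 5.0, 2.0, 1.0 }
// T = 1 -> ~{ 0.94, 0.05, 0.02 }  (sharp, near one-hot)
// T = 3 -> ~{ 0.61, 0.23, 0.16 }  (softer; class similarities become visible)
```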
Methods
ComputeGradient(Matrix<T>, Matrix<T>, Matrix<T>?)
Computes the gradient of the distillation loss for backpropagation.
public abstract Matrix<T> ComputeGradient(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)
Parameters
studentBatchOutput (Matrix<T>): The student model's output logits for a batch. Shape: [batch_size x num_classes]
teacherBatchOutput (Matrix<T>): The teacher model's output logits for a batch. Shape: [batch_size x num_classes]
trueLabelsBatch (Matrix<T>): Ground truth labels for the batch (optional). Shape: [batch_size x num_classes]
Returns
- Matrix<T>
The gradient of the loss with respect to student outputs. Shape: [batch_size x num_classes]
Remarks
For Implementers: Override this method to compute gradients for your strategy. The gradient should match the loss computation in ComputeLoss.
Batch Processing: Returns a gradient matrix with the same shape as the input, one gradient row for each sample in the batch.
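For orientation only: in the common response-based case where the soft loss is the KL divergence between temperature-softened distributions, the gradient with respect to each student logit reduces to the difference of the softened probabilities divided by T (and by the batch size when the loss is batch-averaged). The sketch below uses plain doubles and reuses the SoftmaxWithTemperature helper from the Temperature remarks; actual strategies may use a different loss and therefore a different gradient.

```csharp
static class SoftGradientDemo
{
    // Per-sample gradient of KL(p_teacher || p_student) w.r.t. the student logits,
    // with both distributions computed by temperature-scaled softmax:
    //   grad_i = (ps_i - pt_i) / T, divided by batchSize for a batch-averaged loss.
    // If the loss additionally carries the conventional T^2 factor, multiply by T^2.
    public static double[] SoftLossGradient(
        double[] studentLogits, double[] teacherLogits, double temperature, int batchSize)
    {
        double[] ps = TemperatureDemo.SoftmaxWithTemperature(studentLogits, temperature);
        double[] pt = TemperatureDemo.SoftmaxWithTemperature(teacherLogits, temperature);

        var grad = new double[ps.Length];
        for (int i = 0; i < grad.Length; i++)
            grad[i] = (ps[i] - pt[i]) / (temperature * batchSize);
        return grad;
    }
}
```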
ComputeLoss(Matrix<T>, Matrix<T>, Matrix<T>?)
Computes the distillation loss between student and teacher batch outputs.
public abstract T ComputeLoss(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)
Parameters
studentBatchOutput (Matrix<T>): The student model's output logits for a batch. Shape: [batch_size x num_classes]
teacherBatchOutput (Matrix<T>): The teacher model's output logits for a batch. Shape: [batch_size x num_classes]
trueLabelsBatch (Matrix<T>): Ground truth labels for the batch (optional). Shape: [batch_size x num_classes]
Returns
- T
The computed distillation loss value (scalar) for the batch.
Remarks
For Implementers: Override this method to define your strategy's loss computation. The base class handles temperature and alpha; you focus on the loss calculation logic.
Batch Processing: The loss should be computed over all samples in the batch and typically averaged. Each row in the input matrices represents one sample.
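As a concrete example of one possible loss (again plain doubles, not the Matrix<T> API): the per-sample soft term of a response-based strategy is often the KL divergence between the temperature-softened teacher and student distributions, which a full ComputeLoss would average over the batch and blend with a hard loss via Alpha. The epsilon guard mirrors the Epsilon constant documented above.

```csharp
using System;

static class SoftLossDemo
{
    private const double Epsilon = 1e-10; // mirrors DistillationStrategyBase<T>.Epsilon

    // Per-sample soft loss: KL(p_teacher || p_student) over temperature-softened
    // distributions. Some formulations also scale this term by T^2 so its
    // gradients stay comparable in size to the hard loss.
    public static double SoftLoss(double[] studentLogits, double[] teacherLogits, double temperature)
    {
        double[] ps = TemperatureDemo.SoftmaxWithTemperature(studentLogits, temperature);
        double[] pt = TemperatureDemo.SoftmaxWithTemperature(teacherLogits, temperature);

        double kl = 0.0;
        for (int i = 0; i < ps.Length; i++)
            kl += pt[i] * Math.Log((pt[i] + Epsilon) / (ps[i] + Epsilon));
        return kl;
    }
}
```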
ValidateLabelDimensions(Matrix<T>, Matrix<T>?)
Validates that batch outputs and labels have matching dimensions (if labels are provided).
protected void ValidateLabelDimensions(Matrix<T> batchOutput, Matrix<T>? labelsBatch)
Parameters
batchOutput (Matrix<T>): Batch output to validate.
labelsBatch (Matrix<T>): Labels batch to validate (can be null).
Remarks
If labels are null, validation is skipped (for pure soft distillation without labels).
Exceptions
- ArgumentException
Thrown when dimensions don't match.
ValidateOutputDimensions(Matrix<T>, Matrix<T>)
Validates that student and teacher batch outputs have matching dimensions.
protected void ValidateOutputDimensions(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput)
Parameters
studentBatchOutput (Matrix<T>): Student batch output to validate.
teacherBatchOutput (Matrix<T>): Teacher batch output to validate.
Remarks
Checks both batch size (rows) and output dimension (columns) match between student and teacher.
Exceptions
- ArgumentNullException
Thrown when outputs are null.
- ArgumentException
Thrown when dimensions don't match.
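Putting the two helpers together, derived strategies typically call them at the top of ComputeLoss and ComputeGradient, as the skeleton near the top of this page does. The fragment below is an excerpt from inside such an override (not a complete program) and summarizes the documented behavior.

```csharp
// Inside a derived strategy's ComputeLoss/ComputeGradient override:
ValidateOutputDimensions(studentBatchOutput, teacherBatchOutput);
// - ArgumentNullException if either output is null
// - ArgumentException if batch size (rows) or output dimension (columns) differ

ValidateLabelDimensions(studentBatchOutput, trueLabelsBatch);
// - skipped when trueLabelsBatch is null (pure soft distillation)
// - ArgumentException if the label dimensions don't match the batch output
```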