Class DistillationLoss<T>
- Namespace: AiDotNet.KnowledgeDistillation
- Assembly: AiDotNet.dll
public class DistillationLoss<T> : DistillationStrategyBase<T>, IDistillationStrategy<T>
Type Parameters
T
- Inheritance
- DistillationStrategyBase<T> ← DistillationLoss<T>
- Implements
- IDistillationStrategy<T>
Constructors
DistillationLoss(double, double)
Initializes a new instance of the DistillationLoss class.
public DistillationLoss(double temperature = 3, double alpha = 0.3)
Parameters
temperature (double): Softmax temperature for distillation (default: 3.0). Higher values (2-10) produce softer probability distributions that transfer more knowledge.
alpha (double): Balance between hard loss and soft loss (default: 0.3). Lower values give more weight to the teacher's knowledge.
Remarks
For Beginners: The default values (temperature=3.0, alpha=0.3) work well for most classification tasks. You may want to adjust them based on your specific problem:
- Increase temperature if the teacher's uncertainty is important (complex tasks)
- Decrease alpha if you have noisy labels or want to rely more on the teacher
- Increase alpha if you have very clean labels and a smaller capacity gap
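Example (illustrative only): the sketch below shows how the constructor might be called, assuming the class is instantiated with double as the numeric type parameter T.

using AiDotNet.KnowledgeDistillation;

// Default configuration: temperature = 3.0, alpha = 0.3
var distillation = new DistillationLoss<double>();

// Leans more on the teacher's soft targets (higher temperature, lower alpha),
// e.g. for a complex task where the teacher's uncertainty carries useful information
var teacherHeavy = new DistillationLoss<double>(temperature: 5.0, alpha: 0.2);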
Methods
ComputeGradient(Matrix<T>, Matrix<T>, Matrix<T>?)
Computes the gradient of the distillation loss for backpropagation.
public override Matrix<T> ComputeGradient(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)
Parameters
studentBatchOutput (Matrix<T>): The student model's raw outputs (logits).
teacherBatchOutput (Matrix<T>): The teacher model's raw outputs (logits).
trueLabelsBatch (Matrix<T>?): Ground truth labels (optional).
Returns
- Matrix<T>
Gradient matrix with respect to student logits.
Remarks
For Beginners: The gradient tells us how to adjust the student's parameters to reduce the loss. It points in the direction that increases loss, so we subtract it (gradient descent) to improve the model.
The soft gradient: (student_soft - teacher_soft) × T²
The hard gradient: (student_probs - true_labels)
Combined gradient: α × hard_grad + (1 - α) × soft_grad
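As a rough illustration of these formulas (not the library's internal implementation), the sketch below computes the combined gradient for a single sample using plain double arrays; all names here are invented for the example.

using System;

static class DistillationGradientSketch
{
    // Softmax over one row of logits at the given temperature,
    // with the usual max-subtraction for numerical stability.
    static double[] Softmax(double[] logits, double temperature)
    {
        var probs = new double[logits.Length];
        double max = double.NegativeInfinity;
        foreach (var z in logits) max = Math.Max(max, z / temperature);
        double sum = 0.0;
        for (int i = 0; i < logits.Length; i++)
        {
            probs[i] = Math.Exp(logits[i] / temperature - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.Length; i++) probs[i] /= sum;
        return probs;
    }

    // Combined gradient for one sample: alpha × hard_grad + (1 - alpha) × soft_grad.
    // trueLabels may be null, in which case the hard term is skipped.
    public static double[] Gradient(
        double[] studentLogits, double[] teacherLogits, double[] trueLabels,
        double temperature, double alpha)
    {
        var studentSoft = Softmax(studentLogits, temperature);
        var teacherSoft = Softmax(teacherLogits, temperature);
        var grad = new double[studentLogits.Length];

        // Soft gradient: (student_soft - teacher_soft) × T²
        for (int i = 0; i < grad.Length; i++)
            grad[i] = (1 - alpha) * (studentSoft[i] - teacherSoft[i]) * temperature * temperature;

        if (trueLabels != null)
        {
            // Hard gradient: (student_probs - true_labels), probabilities at temperature 1
            var studentProbs = Softmax(studentLogits, 1.0);
            for (int i = 0; i < grad.Length; i++)
                grad[i] += alpha * (studentProbs[i] - trueLabels[i]);
        }
        return grad;
    }
}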
ComputeLoss(Matrix<T>, Matrix<T>, Matrix<T>?)
Computes the combined distillation loss (soft loss from teacher + hard loss from true labels).
public override T ComputeLoss(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)
Parameters
studentBatchOutput (Matrix<T>): The student model's raw outputs (logits) before softmax.
teacherBatchOutput (Matrix<T>): The teacher model's raw outputs (logits) before softmax.
trueLabelsBatch (Matrix<T>?): Ground truth labels as one-hot vectors (optional). If null, only soft loss is computed.
Returns
- T
The total distillation loss combining soft and hard components.
Remarks
For Beginners: This is the main loss function that guides student training. It combines two objectives:
1. Match the teacher's soft predictions (learn the teacher's knowledge)
2. Match the true labels (learn to be correct)
The soft loss uses KL divergence, which measures how different two probability distributions are. When the student's soft predictions match the teacher's, KL divergence approaches zero.
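As a rough sketch of that description (not the library's implementation), the example below computes the combined loss for a single sample with plain double arrays: a KL-divergence soft loss against the teacher's temperature-softened predictions plus a cross-entropy hard loss against one-hot labels, blended with alpha. The library's actual soft term may include an additional temperature-dependent scaling factor, which is omitted here.

using System;

static class DistillationLossSketch
{
    // Temperature-scaled softmax over one row of logits (max subtracted for stability).
    static double[] Softmax(double[] logits, double temperature)
    {
        var probs = new double[logits.Length];
        double max = double.NegativeInfinity;
        foreach (var z in logits) max = Math.Max(max, z / temperature);
        double sum = 0.0;
        for (int i = 0; i < logits.Length; i++)
        {
            probs[i] = Math.Exp(logits[i] / temperature - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.Length; i++) probs[i] /= sum;
        return probs;
    }

    // Combined loss for one sample: alpha × hard + (1 - alpha) × soft.
    // trueLabels may be null, in which case only the soft loss is returned.
    public static double Loss(
        double[] studentLogits, double[] teacherLogits, double[] trueLabels,
        double temperature, double alpha)
    {
        var studentSoft = Softmax(studentLogits, temperature);
        var teacherSoft = Softmax(teacherLogits, temperature);

        // Soft loss: KL(teacher_soft || student_soft); approaches zero as the
        // student's softened predictions match the teacher's.
        double softLoss = 0.0;
        for (int i = 0; i < teacherSoft.Length; i++)
            softLoss += teacherSoft[i] * Math.Log(teacherSoft[i] / studentSoft[i]);

        if (trueLabels == null)
            return softLoss;

        // Hard loss: cross-entropy between student probabilities (temperature 1)
        // and the one-hot labels.
        var studentProbs = Softmax(studentLogits, 1.0);
        double hardLoss = 0.0;
        for (int i = 0; i < trueLabels.Length; i++)
        {
            if (trueLabels[i] > 0)
                hardLoss += -trueLabels[i] * Math.Log(studentProbs[i]);
        }

        return alpha * hardLoss + (1 - alpha) * softLoss;
    }
}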