Interface IDistillationStrategy<T>

Namespace: AiDotNet.Interfaces
Assembly: AiDotNet.dll

Defines a strategy for computing knowledge distillation loss between student and teacher models.

public interface IDistillationStrategy<T>

Type Parameters

T

The numeric type for calculations (e.g., double, float).

Remarks

For Beginners: A distillation strategy determines how to measure the difference between what the student predicts and what the teacher predicts. Different strategies can focus on different aspects:

  • Response-based: Compare final outputs
  • Feature-based: Compare intermediate layer features
  • Relation-based: Compare relationships between samples

The most common approach (Hinton et al., 2015) combines two losses:

  1. Hard loss: How well the student matches the true labels
  2. Soft loss: How well the student mimics the teacher's predictions

This combination allows the student to both get the right answers (hard loss) and learn the teacher's reasoning (soft loss).

Batch Processing: This interface operates on batches (Matrix<T>) for efficiency. Each row in the matrices represents one sample in the batch.

Interface Design Note: This interface uses a single type parameter <T> for numeric operations. All input/output types are Matrix<T> for batch processing. There is no second type parameter TOutput - the output type is always Matrix<T> for gradients and T for loss values.
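
The sketch below shows how a strategy implementing this interface might be consumed. Only the IDistillationStrategy<T> members documented on this page are taken from the library; the ResponseDistillationStrategy class name and the helper methods that produce the batches are hypothetical placeholders.

IDistillationStrategy<double> strategy = new ResponseDistillationStrategy(); // hypothetical implementing class
strategy.Alpha = 0.4;        // balance between hard loss (labels) and soft loss (teacher)
strategy.Temperature = 4.0;  // soften the probability distributions

Matrix<double> studentLogits = GetStudentBatchLogits(); // placeholder, shape [batch_size x num_classes]
Matrix<double> teacherLogits = GetTeacherBatchLogits(); // placeholder, shape [batch_size x num_classes]
Matrix<double> trueLabels    = GetOneHotLabelsBatch();  // placeholder, shape [batch_size x num_classes]

// Scalar loss for the whole batch (useful for logging and monitoring convergence).
double loss = strategy.ComputeLoss(studentLogits, teacherLogits, trueLabels);

// Gradient with respect to the student's outputs, same shape as studentLogits.
Matrix<double> gradient = strategy.ComputeGradient(studentLogits, teacherLogits, trueLabels);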

Properties

Alpha

Gets or sets the balance parameter (alpha) between hard loss and soft loss.

double Alpha { get; set; }

Property Value

double

Remarks

For Beginners: Alpha controls the trade-off between learning from true labels and learning from the teacher:

  • α = 0: Only learn from teacher (pure distillation)
  • α = 0.3-0.5: Balanced (recommended for most cases)
  • α = 1: Only learn from true labels (standard training, no distillation)

When true labels are noisy or scarce, lower alpha (more weight on teacher) helps. When labels are clean and abundant, higher alpha (more weight on labels) works better.
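
This guidance can be encoded directly when configuring a strategy. The values below are illustrative assumptions, not library defaults:

// Illustrative only: 0.2 and 0.5 encode the guidance above, they are not AiDotNet defaults.
bool labelsAreNoisyOrScarce = true;
strategy.Alpha = labelsAreNoisyOrScarce
    ? 0.2   // lean on the teacher when labels are unreliable or few
    : 0.5;  // lean on the true labels when they are clean and abundant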

Temperature

Gets or sets the temperature parameter for softening probability distributions.

double Temperature { get; set; }

Property Value

double

Remarks

For Beginners: Temperature controls how "soft" the predictions become:

  • T = 1: Normal predictions (standard softmax)
  • T = 2-10: Softer predictions that reveal more about class relationships
  • Higher T: Even softer, but gradients become smaller

Typical values: 3-5 for most applications, 2-3 for easier tasks, 5-10 for harder tasks.
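
The softening effect can be seen with a small stand-alone example (plain arrays rather than Matrix<T>, purely for illustration):

using System;
using System.Linq;

// Temperature-scaled softmax over one sample's logits.
static double[] Softmax(double[] logits, double temperature)
{
    double max = logits.Max(); // subtract the max for numerical stability
    double[] exp = logits.Select(z => Math.Exp((z - max) / temperature)).ToArray();
    double sum = exp.Sum();
    return exp.Select(e => e / sum).ToArray();
}

double[] logits = { 5.0, 2.0, 0.5 };
Console.WriteLine(string.Join(", ", Softmax(logits, 1.0).Select(p => p.ToString("F3")))); // ~0.943, 0.047, 0.010 (sharp)
Console.WriteLine(string.Join(", ", Softmax(logits, 4.0).Select(p => p.ToString("F3")))); // ~0.556, 0.263, 0.181 (softer)

With T = 1 the top class dominates; with T = 4 the ordering is preserved but the smaller probabilities become visible, which is the information about class relationships that the student learns from.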

Methods

ComputeGradient(Matrix<T>, Matrix<T>, Matrix<T>?)

Computes the gradient of the distillation loss for backpropagation.

Matrix<T> ComputeGradient(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)

Parameters

studentBatchOutput Matrix<T>

The student model's output logits for a batch. Shape: [batch_size x num_classes]

teacherBatchOutput Matrix<T>

The teacher model's output logits for a batch. Shape: [batch_size x num_classes]

trueLabelsBatch Matrix<T>?

Ground truth labels for the batch (optional). Shape: [batch_size x num_classes]

Returns

Matrix<T>

The gradient of the loss with respect to student outputs. Shape: [batch_size x num_classes]

Remarks

For Beginners: Gradients tell us how to adjust the student model's parameters to reduce the loss. They point in the direction of steepest increase in loss, so we move in the opposite direction during training.

The gradient combines information from both the teacher (soft targets) and the true labels (hard targets), helping the student learn from both sources.

Batch Processing: Returns a gradient matrix with the same shape as the input, one gradient row for each sample in the batch.
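
As a hedged sketch of where this gradient fits in a training step: the Forward, Backpropagate, and UpdateParameters calls below are placeholder names for whatever training API is in use, not members documented here; only the ComputeGradient call is taken from this page.

// Placeholder training step: Forward/Backpropagate/UpdateParameters and trainingBatches are
// illustrative names, not AiDotNet APIs documented on this page.
foreach (var (inputBatch, labelBatch) in trainingBatches)
{
    Matrix<double> studentLogits = studentModel.Forward(inputBatch);
    Matrix<double> teacherLogits = teacherModel.Forward(inputBatch);  // teacher stays frozen

    // One gradient row per sample, same shape as studentLogits.
    Matrix<double> grad = strategy.ComputeGradient(studentLogits, teacherLogits, labelBatch);

    // Move against the gradient: push it back through the student and update its parameters.
    studentModel.Backpropagate(grad);
    studentModel.UpdateParameters(learningRate);
}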

ComputeLoss(Matrix<T>, Matrix<T>, Matrix<T>?)

Computes the distillation loss between student and teacher batch outputs.

T ComputeLoss(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)

Parameters

studentBatchOutput Matrix<T>

The student model's output logits for a batch. Shape: [batch_size x num_classes]

teacherBatchOutput Matrix<T>

The teacher model's output logits for a batch. Shape: [batch_size x num_classes]

trueLabelsBatch Matrix<T>?

Ground truth labels for the batch (optional). Shape: [batch_size x num_classes]

Returns

T

The computed distillation loss value (scalar) for the batch.

Remarks

For Beginners: This calculates how different the student's predictions are from the teacher's predictions across an entire batch of samples. A lower loss means the student is learning well from the teacher.

The formula typically used is: Total Loss = α × Hard Loss + (1 - α) × Soft Loss

Where:

  • Hard Loss: Cross-entropy between student predictions and true labels
  • Soft Loss: KL divergence between student and teacher (with temperature scaling)
  • α (alpha): Balance parameter (typically 0.3-0.5)

Batch Processing: The loss is typically averaged over all samples in the batch.
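
As a concrete illustration of that formula, the sketch below computes the combined loss over plain arrays. It assumes one-hot true labels, the T²-scaled KL soft loss from Hinton et al. (2015), and averaging over the batch; a concrete AiDotNet strategy operates on Matrix<T> and may differ in these details.

using System;
using System.Linq;

// Temperature-scaled softmax over one sample's logits.
static double[] Softmax(double[] logits, double temperature)
{
    double max = logits.Max(); // subtract the max for numerical stability
    double[] exp = logits.Select(z => Math.Exp((z - max) / temperature)).ToArray();
    double sum = exp.Sum();
    return exp.Select(e => e / sum).ToArray();
}

// Illustrative combined loss over plain double[][] logits, not the actual AiDotNet implementation.
static double CombinedLoss(
    double[][] studentLogits,   // [batch_size][num_classes]
    double[][] teacherLogits,   // [batch_size][num_classes]
    double[][] trueLabels,      // one-hot, [batch_size][num_classes]
    double alpha,
    double temperature)
{
    double total = 0.0;
    for (int i = 0; i < studentLogits.Length; i++)
    {
        double[] studentProbs = Softmax(studentLogits[i], 1.0);
        double[] studentSoft  = Softmax(studentLogits[i], temperature);
        double[] teacherSoft  = Softmax(teacherLogits[i], temperature);

        // Hard loss: cross-entropy between student predictions and the true labels.
        double hard = 0.0;
        for (int j = 0; j < studentProbs.Length; j++)
            hard -= trueLabels[i][j] * Math.Log(studentProbs[j] + 1e-12);

        // Soft loss: KL divergence between teacher and student soft distributions,
        // scaled by T^2 (Hinton et al., 2015) to keep gradient magnitudes comparable.
        double soft = 0.0;
        for (int j = 0; j < studentSoft.Length; j++)
            soft += teacherSoft[j] * Math.Log((teacherSoft[j] + 1e-12) / (studentSoft[j] + 1e-12));
        soft *= temperature * temperature;

        // Total Loss = alpha * Hard Loss + (1 - alpha) * Soft Loss, averaged over the batch below.
        total += alpha * hard + (1 - alpha) * soft;
    }
    return total / studentLogits.Length;
}

double[][] student = { new[] { 2.0, 0.5, -1.0 } };
double[][] teacher = { new[] { 1.5, 0.8, -0.5 } };
double[][] labels  = { new[] { 1.0, 0.0,  0.0 } };
Console.WriteLine(CombinedLoss(student, teacher, labels, alpha: 0.4, temperature: 4.0).ToString("F4"));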