Class DistillationStrategyBase<T>

Namespace
AiDotNet.KnowledgeDistillation
Assembly
AiDotNet.dll

Abstract base class for knowledge distillation strategies. Provides common functionality for computing losses and gradients in student-teacher training.

public abstract class DistillationStrategyBase<T> : IDistillationStrategy<T>

Type Parameters

T

The numeric type for calculations (e.g., double, float).

Inheritance
object → DistillationStrategyBase<T>

Implements
IDistillationStrategy<T>

Remarks

For Beginners: A distillation strategy defines how to measure the difference between student and teacher predictions. This base class provides common functionality that all strategies need, like temperature and alpha parameters.

Design Philosophy: Different distillation strategies focus on different aspects:
- **Response-based**: Match final outputs (logits/probabilities)
- **Feature-based**: Match intermediate layer representations
- **Relation-based**: Match relationships between samples
- **Attention-based**: Match attention patterns (for transformers)

This base class ensures all strategies handle temperature and alpha consistently, while allowing flexibility in how loss is computed.

Batch Processing: All strategies operate on batches (Matrix<T>) for efficiency. Each row in the matrices represents one sample in the batch.

Template Method Pattern: The base class defines the structure (properties, validation), and subclasses implement the specifics (loss computation logic).
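As a sketch of that pattern, a hypothetical subclass might look like the following. It is illustration only, not a type shipped by AiDotNet: the squared-error loss is chosen for simplicity, and the Matrix<T> members (Rows, Columns, indexer, constructor) and INumericOperations<T> members (Zero, Add, Subtract, Multiply, Divide, FromDouble) are assumptions based on how they are used on this page.

```csharp
// Hypothetical strategy that matches student and teacher logits with a
// mean squared error, averaged over the batch.
public class SquaredErrorDistillationStrategy<T> : DistillationStrategyBase<T>
{
    public SquaredErrorDistillationStrategy(double temperature = 3, double alpha = 0.3)
        : base(temperature, alpha)
    {
    }

    public override T ComputeLoss(
        Matrix<T> studentBatchOutput,
        Matrix<T> teacherBatchOutput,
        Matrix<T>? trueLabelsBatch = null)
    {
        ValidateOutputDimensions(studentBatchOutput, teacherBatchOutput);

        // Sum squared differences over every entry, then average per sample.
        T sum = NumOps.Zero;
        for (int i = 0; i < studentBatchOutput.Rows; i++)
        {
            for (int j = 0; j < studentBatchOutput.Columns; j++)
            {
                T diff = NumOps.Subtract(studentBatchOutput[i, j], teacherBatchOutput[i, j]);
                sum = NumOps.Add(sum, NumOps.Multiply(diff, diff));
            }
        }
        return NumOps.Divide(sum, NumOps.FromDouble(studentBatchOutput.Rows));
    }

    public override Matrix<T> ComputeGradient(
        Matrix<T> studentBatchOutput,
        Matrix<T> teacherBatchOutput,
        Matrix<T>? trueLabelsBatch = null)
    {
        ValidateOutputDimensions(studentBatchOutput, teacherBatchOutput);

        // d/ds of mean_i sum_j (s_ij - t_ij)^2 is 2 (s_ij - t_ij) / batchSize,
        // so this gradient matches ComputeLoss exactly.
        var grad = new Matrix<T>(studentBatchOutput.Rows, studentBatchOutput.Columns);
        T scale = NumOps.Divide(NumOps.FromDouble(2.0),
                                NumOps.FromDouble(studentBatchOutput.Rows));
        for (int i = 0; i < studentBatchOutput.Rows; i++)
        {
            for (int j = 0; j < studentBatchOutput.Columns; j++)
            {
                T diff = NumOps.Subtract(studentBatchOutput[i, j], teacherBatchOutput[i, j]);
                grad[i, j] = NumOps.Multiply(scale, diff);
            }
        }
        return grad;
    }
}
```

Note that the base constructor handles temperature and alpha validation, so the subclass only supplies the loss and its matching gradient.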

Constructors

DistillationStrategyBase(double, double)

Initializes a new instance of the distillation strategy base class.

protected DistillationStrategyBase(double temperature = 3, double alpha = 0.3)

Parameters

temperature double

Softmax temperature (default 3.0).

alpha double

Balance between hard and soft loss (default 0.3).

Fields

Epsilon

A small constant used for numerical stability (to avoid log(0), division by zero, etc.).

protected const double Epsilon = 1E-10

Field Value

double

NumOps

Numeric operations for the specified type T.

protected readonly INumericOperations<T> NumOps

Field Value

INumericOperations<T>

Properties

Alpha

Gets or sets the balance parameter between hard loss and soft loss.

public double Alpha { get; set; }

Property Value

double

Remarks

For Beginners: Alpha controls the trade-off between learning from true labels (hard loss) and learning from the teacher (soft loss).

Alpha values:
- α = 0: Only learn from teacher (pure distillation)
- α = 0.3–0.5: Balanced (recommended)
- α = 1: Only learn from labels (no distillation)

Validation: Must be between 0 and 1. Setting invalid values throws ArgumentException.
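In arithmetic terms, the blend typically works as shown below. The variable names are illustrative, not AiDotNet's:

```csharp
double alpha = 0.3;
double hardLoss = 1.2; // e.g. cross-entropy against the true labels
double softLoss = 0.4; // e.g. divergence against the teacher's soft targets

// Alpha weights the hard loss; (1 - alpha) weights the soft loss.
double totalLoss = alpha * hardLoss + (1 - alpha) * softLoss;
// 0.3 * 1.2 + 0.7 * 0.4 = 0.36 + 0.28 = 0.64
```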

Temperature

Gets or sets the temperature parameter for softening probability distributions.

public double Temperature { get; set; }

Property Value

double

Remarks

For Beginners: Higher temperature makes predictions "softer", revealing more about the model's uncertainty and class relationships.

Temperature effects:
- T = 1: Standard predictions (sharp)
- T = 2–5: Softer predictions (recommended for distillation)
- T > 10: Very soft (may be too smooth)

Validation: Must be positive (> 0). Setting invalid values throws ArgumentException.
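The effect is easy to see with a small self-contained example, written here for plain double arrays rather than the library's Matrix<T>:

```csharp
using System;
using System.Linq;

static double[] Softmax(double[] logits, double temperature)
{
    // Dividing logits by T before exponentiating flattens the distribution;
    // subtracting the max first keeps Math.Exp from overflowing.
    double max = logits.Max();
    double[] exps = logits.Select(z => Math.Exp((z - max) / temperature)).ToArray();
    double sum = exps.Sum();
    return exps.Select(e => e / sum).ToArray();
}

double[] logits = { 5.0, 2.0, 1.0 };
// Softmax(logits, 1.0) ≈ [0.94, 0.05, 0.02]  (sharp)
// Softmax(logits, 3.0) ≈ [0.61, 0.23, 0.16]  (softer: class relationships visible)
```

At T = 1 nearly all probability sits on the top class; at T = 3 the relative ordering is preserved but the runner-up classes carry visible mass, which is the "dark knowledge" the student learns from.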

Methods

ComputeGradient(Matrix<T>, Matrix<T>, Matrix<T>?)

Computes the gradient of the distillation loss for backpropagation.

public abstract Matrix<T> ComputeGradient(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)

Parameters

studentBatchOutput Matrix<T>

The student model's output logits for a batch. Shape: [batch_size x num_classes]

teacherBatchOutput Matrix<T>

The teacher model's output logits for a batch. Shape: [batch_size x num_classes]

trueLabelsBatch Matrix<T>

Ground truth labels for the batch (optional). Shape: [batch_size x num_classes]

Returns

Matrix<T>

The gradient of the loss with respect to student outputs. Shape: [batch_size x num_classes]

Remarks

For Implementers: Override this method to compute gradients for your strategy. The gradient should match the loss computation in ComputeLoss.

Batch Processing: Returns a gradient matrix with the same shape as the input, one gradient row for each sample in the batch.
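Because the gradient must stay consistent with ComputeLoss, a finite-difference spot check is a useful habit when implementing a strategy. This sketch assumes T = double and that Matrix<double> exposes a writable indexer; `strategy`, `student`, and `teacher` are placeholders for your own instances:

```csharp
// Hypothetical sanity check: perturb one student logit and compare the
// resulting loss change against the reported gradient entry.
double eps = 1e-5;
double baseLoss = strategy.ComputeLoss(student, teacher);
Matrix<double> grad = strategy.ComputeGradient(student, teacher);

student[0, 0] += eps;
double bumpedLoss = strategy.ComputeLoss(student, teacher);
student[0, 0] -= eps; // restore the original logit

double numericGrad = (bumpedLoss - baseLoss) / eps;
// numericGrad should agree with grad[0, 0] up to finite-difference error.
```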

ComputeLoss(Matrix<T>, Matrix<T>, Matrix<T>?)

Computes the distillation loss between student and teacher batch outputs.

public abstract T ComputeLoss(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)

Parameters

studentBatchOutput Matrix<T>

The student model's output logits for a batch. Shape: [batch_size x num_classes]

teacherBatchOutput Matrix<T>

The teacher model's output logits for a batch. Shape: [batch_size x num_classes]

trueLabelsBatch Matrix<T>

Ground truth labels for the batch (optional). Shape: [batch_size x num_classes]

Returns

T

The computed distillation loss value (scalar) for the batch.

Remarks

For Implementers: Override this method to define your strategy's loss computation. The base class handles temperature and alpha; you focus on the loss calculation logic.

Batch Processing: The loss should be computed over all samples in the batch and typically averaged. Each row in the input matrices represents one sample.
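As a concrete (hypothetical) example of such a computation, the classic response-based soft loss averages a temperature-softened KL divergence over the batch. This sketch uses plain double arrays rather than the library's Matrix<T>, and is a generic illustration rather than AiDotNet's actual implementation:

```csharp
using System;
using System.Linq;

static double[] Softmax(double[] logits, double temperature)
{
    double max = logits.Max();
    double[] exps = logits.Select(z => Math.Exp((z - max) / temperature)).ToArray();
    double sum = exps.Sum();
    return exps.Select(e => e / sum).ToArray();
}

static double SoftTargetLoss(double[][] studentLogits, double[][] teacherLogits,
                             double temperature)
{
    const double epsilon = 1e-10; // mirrors the Epsilon field: avoids log(0)
    double total = 0.0;
    for (int i = 0; i < studentLogits.Length; i++)
    {
        double[] p = Softmax(teacherLogits[i], temperature); // teacher targets
        double[] q = Softmax(studentLogits[i], temperature); // student predictions
        for (int j = 0; j < p.Length; j++)
            total += p[j] * Math.Log((p[j] + epsilon) / (q[j] + epsilon)); // KL term
    }
    // The T^2 factor keeps gradient magnitudes comparable across temperatures;
    // dividing by the batch size averages the per-sample losses.
    return total * temperature * temperature / studentLogits.Length;
}
```

Whatever loss a subclass chooses here, its ComputeGradient override must return the analytic gradient of that same expression.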

ValidateLabelDimensions(Matrix<T>, Matrix<T>?)

Validates that batch outputs and labels have matching dimensions (if labels are provided).

protected void ValidateLabelDimensions(Matrix<T> batchOutput, Matrix<T>? labelsBatch)

Parameters

batchOutput Matrix<T>

Batch output to validate.

labelsBatch Matrix<T>

Labels batch to validate (can be null).

Remarks

If labels are null, validation is skipped (for pure soft distillation without labels).

Exceptions

ArgumentException

Thrown when dimensions don't match.

ValidateOutputDimensions(Matrix<T>, Matrix<T>)

Validates that student and teacher batch outputs have matching dimensions.

protected void ValidateOutputDimensions(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput)

Parameters

studentBatchOutput Matrix<T>

Student batch output to validate.

teacherBatchOutput Matrix<T>

Teacher batch output to validate.

Remarks

Checks both batch size (rows) and output dimension (columns) match between student and teacher.

Exceptions

ArgumentNullException

Thrown when outputs are null.

ArgumentException

Thrown when dimensions don't match.