
Class ProbabilisticDistillationStrategy<T>

Namespace
AiDotNet.KnowledgeDistillation.Strategies
Assembly
AiDotNet.dll

Probabilistic distillation that transfers distributional knowledge by matching statistical properties.

public class ProbabilisticDistillationStrategy<T> : DistillationStrategyBase<T>, IDistillationStrategy<T>

Type Parameters

T

The numeric type for calculations (e.g., double, float).

Inheritance
object
DistillationStrategyBase<T>
ProbabilisticDistillationStrategy<T>
Implements
IDistillationStrategy<T>

Remarks

For Production Use: This strategy treats model outputs as samples from probability distributions and transfers knowledge about the entire distribution, not just point predictions. It matches statistical moments (mean, variance, higher moments) and can use measures like Maximum Mean Discrepancy (MMD).

Key Concept: Standard distillation matches individual predictions, but a neural network can also be viewed as a probabilistic model. This strategy captures uncertainty and distribution shape by matching:

1. First moment (mean): expected predictions
2. Second moment (variance): prediction uncertainty
3. Distribution distance (MMD, Wasserstein): overall shape
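The first two moments can be written out concretely. This is a sketch under stated assumptions: a batch of B post-softmax vectors p_i over C classes, superscripts S and T for student and teacher, population (1/B) variance, and an unweighted sum of the mean and variance terms. The exact weighting this strategy applies is not documented here.

```latex
\mu_c = \frac{1}{B}\sum_{i=1}^{B} p_{i,c}, \qquad
\sigma_c^2 = \frac{1}{B}\sum_{i=1}^{B}\left(p_{i,c} - \mu_c\right)^2

\mathcal{L}_{\text{moment}} = \sum_{c=1}^{C}\left[\left(\mu_c^{S} - \mu_c^{T}\right)^2
  + \left(\left(\sigma_c^{S}\right)^2 - \left(\sigma_c^{T}\right)^2\right)^2\right]
```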

Implementation: Three modes are provided:

- MomentMatching: matches the mean and variance of predictions across the batch
- MaximumMeanDiscrepancy: uses MMD with an RBF kernel to match distributions
- EntropyTransfer: matches prediction entropy (uncertainty calibration)
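The three modes can be illustrated on plain double[][] batches (one row per sample, one column per class). This is a hypothetical sketch, not AiDotNet's implementation: the class ProbabilisticLossSketch and its helpers are illustrative names, and the formulas are assumptions. Moment matching sums squared differences of per-class means and variances. MMD uses the biased MMD² estimate with the kernel k(x, y) = exp(-||x - y||² / (2h²)), where h stands in for mmdBandwidth. Entropy transfer is taken as the mean squared gap between per-sample Shannon entropies.

```csharp
using System;

// Hypothetical sketch of the three modes; formulas and weights are assumptions,
// not the actual ProbabilisticDistillationStrategy<T> code.
public static class ProbabilisticLossSketch
{
    // MomentMatching: squared gaps between per-class batch means and batch variances.
    public static double MomentMatching(double[][] student, double[][] teacher)
    {
        int numClasses = student[0].Length;
        double loss = 0.0;
        for (int c = 0; c < numClasses; c++)
        {
            (double muS, double varS) = MeanAndVariance(student, c);
            (double muT, double varT) = MeanAndVariance(teacher, c);
            loss += (muS - muT) * (muS - muT) + (varS - varT) * (varS - varT);
        }
        return loss;
    }

    // MaximumMeanDiscrepancy: biased MMD^2 estimate with an (assumed) RBF kernel
    // k(x, y) = exp(-||x - y||^2 / (2 h^2)); h plays the role of mmdBandwidth.
    public static double Mmd(double[][] student, double[][] teacher, double bandwidth)
    {
        return MeanKernel(student, student, bandwidth)
             + MeanKernel(teacher, teacher, bandwidth)
             - 2.0 * MeanKernel(student, teacher, bandwidth);
    }

    // EntropyTransfer: mean squared difference of per-sample Shannon entropies.
    public static double EntropyGap(double[][] student, double[][] teacher)
    {
        double sum = 0.0;
        for (int i = 0; i < student.Length; i++)
        {
            double diff = Entropy(student[i]) - Entropy(teacher[i]);
            sum += diff * diff;
        }
        return sum / student.Length;
    }

    // Population mean and variance of column c across the batch.
    private static (double Mean, double Variance) MeanAndVariance(double[][] batch, int c)
    {
        double mean = 0.0;
        foreach (var row in batch) mean += row[c];
        mean /= batch.Length;
        double variance = 0.0;
        foreach (var row in batch) variance += (row[c] - mean) * (row[c] - mean);
        return (mean, variance / batch.Length);
    }

    // Average kernel value over all pairs (a from x, b from y).
    private static double MeanKernel(double[][] x, double[][] y, double h)
    {
        double sum = 0.0;
        foreach (var a in x)
            foreach (var b in y)
            {
                double sq = 0.0;
                for (int k = 0; k < a.Length; k++) sq += (a[k] - b[k]) * (a[k] - b[k]);
                sum += Math.Exp(-sq / (2.0 * h * h));
            }
        return sum / (x.Length * y.Length);
    }

    // Shannon entropy in nats; 0 * log 0 is treated as 0.
    private static double Entropy(double[] p)
    {
        double h = 0.0;
        foreach (double pk in p)
            if (pk > 0.0) h -= pk * Math.Log(pk);
        return h;
    }
}
```

All three quantities are zero when the student batch equals the teacher batch and grow as the batches diverge. MMD compares whole sets of samples through the kernel rather than summary statistics.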

Research Basis: Based on probabilistic knowledge distillation and Bayesian neural networks. Particularly useful for uncertainty quantification and ensemble distillation.

Constructors

ProbabilisticDistillationStrategy(double, ProbabilisticMode, double, double, double)

public ProbabilisticDistillationStrategy(double distributionWeight = 0.5, ProbabilisticMode mode = ProbabilisticMode.MomentMatching, double mmdBandwidth = 1, double temperature = 3, double alpha = 0.3)

Parameters

distributionWeight double
mode ProbabilisticMode
mmdBandwidth double
temperature double
alpha double

Methods

ComputeDistributionalLoss(Vector<T>[], Vector<T>[])

Computes distributional loss by matching statistical properties across a batch.

public T ComputeDistributionalLoss(Vector<T>[] studentPredictions, Vector<T>[] teacherPredictions)

Parameters

studentPredictions Vector<T>[]

Student probability distributions for a batch.

teacherPredictions Vector<T>[]

Teacher probability distributions for a batch.

Returns

T

Distributional matching loss.

Remarks

Call this method with post-softmax predictions for an entire batch. It computes distributional statistics across the batch and matches them between student and teacher, so it is not meant to be called one sample at a time.
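Because this method expects post-softmax probabilities rather than logits, each row of logits has to be converted first. Below is a hypothetical helper (SoftmaxSketch is an illustrative name, not part of AiDotNet) that applies a numerically stable softmax with an optional temperature. Whether the strategy itself applies its temperature parameter before this method is an assumption, not something stated above.

```csharp
using System;

// Hypothetical helper: converts rows of logits into probability vectors.
// Dividing by a temperature > 1 softens the distribution.
public static class SoftmaxSketch
{
    public static double[] Softmax(double[] logits, double temperature = 1.0)
    {
        double max = double.NegativeInfinity;
        foreach (double z in logits) max = Math.Max(max, z / temperature);

        var probs = new double[logits.Length];
        double sum = 0.0;
        for (int k = 0; k < logits.Length; k++)
        {
            // Subtracting the max keeps Math.Exp from overflowing.
            probs[k] = Math.Exp(logits[k] / temperature - max);
            sum += probs[k];
        }
        for (int k = 0; k < probs.Length; k++) probs[k] /= sum;
        return probs;
    }

    // Applies Softmax to every row so the whole batch can be passed at once.
    public static double[][] SoftmaxBatch(double[][] logitsBatch, double temperature = 1.0)
    {
        var result = new double[logitsBatch.Length][];
        for (int i = 0; i < logitsBatch.Length; i++)
            result[i] = Softmax(logitsBatch[i], temperature);
        return result;
    }
}
```

With the real API the batch would be passed as Vector<T>[] rather than double[][]. Conversion between the two is not shown here.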

ComputeGradient(Matrix<T>, Matrix<T>, Matrix<T>?)

Computes the gradient of the distillation loss for backpropagation.

public override Matrix<T> ComputeGradient(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)

Parameters

studentBatchOutput Matrix<T>

The student model's output logits for a batch. Shape: [batch_size x num_classes]

teacherBatchOutput Matrix<T>

The teacher model's output logits for a batch. Shape: [batch_size x num_classes]

trueLabelsBatch Matrix<T>

Ground truth labels for the batch (optional). Shape: [batch_size x num_classes]

Returns

Matrix<T>

The gradient of the loss with respect to student outputs. Shape: [batch_size x num_classes]

Remarks

For Implementers: Override this method to compute gradients for your strategy. The gradient should match the loss computation in ComputeLoss.

Batch Processing: Returns a gradient matrix with the same shape as the input, one gradient row for each sample in the batch.
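One practical way to confirm that a gradient matches the loss in ComputeLoss is a finite-difference check. This hypothetical helper (GradientCheck is an illustrative name) perturbs each entry of the student output by ±ε and compares the central difference of a loss function with the analytic gradient matrix. It returns the largest absolute discrepancy. Plain double[][] arrays stand in for Matrix<T>.

```csharp
using System;

// Hypothetical gradient-check helper: compares an analytic gradient (such as the
// matrix ComputeGradient returns) with central finite differences of a loss.
public static class GradientCheck
{
    public static double MaxAbsError(
        Func<double[][], double> loss,
        double[][] studentOutput,
        double[][] analyticGradient,
        double epsilon = 1e-6)
    {
        double maxError = 0.0;
        for (int i = 0; i < studentOutput.Length; i++)
        {
            for (int j = 0; j < studentOutput[i].Length; j++)
            {
                double original = studentOutput[i][j];

                studentOutput[i][j] = original + epsilon;
                double lossPlus = loss(studentOutput);
                studentOutput[i][j] = original - epsilon;
                double lossMinus = loss(studentOutput);
                studentOutput[i][j] = original; // restore the entry

                double numeric = (lossPlus - lossMinus) / (2.0 * epsilon);
                maxError = Math.Max(maxError, Math.Abs(numeric - analyticGradient[i][j]));
            }
        }
        return maxError;
    }
}
```

A small result (well below the gradient's typical magnitude) suggests the two agree. A common bug this catches is forgetting the 1/batch_size factor in the gradient when the loss is averaged over the batch.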

ComputeLoss(Matrix<T>, Matrix<T>, Matrix<T>?)

Computes the distillation loss between student and teacher batch outputs.

public override T ComputeLoss(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)

Parameters

studentBatchOutput Matrix<T>

The student model's output logits for a batch. Shape: [batch_size x num_classes]

teacherBatchOutput Matrix<T>

The teacher model's output logits for a batch. Shape: [batch_size x num_classes]

trueLabelsBatch Matrix<T>

Ground truth labels for the batch (optional). Shape: [batch_size x num_classes]

Returns

T

The computed distillation loss value (scalar) for the batch.

Remarks

For Implementers: Override this method to define your strategy's loss computation. The base class handles temperature and alpha; you focus on the loss calculation logic.

Batch Processing: The loss should be computed over all samples in the batch and typically averaged. Each row in the input matrices represents one sample.
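The "compute per row, then average" pattern can be illustrated with the standard Hinton-style soft-target term, T² · KL(teacher ‖ student) on temperature-softened softmax outputs. This is a swapped-in, well-known distillation loss used only to show batch handling. It is not the loss this strategy computes, and BatchLossSketch is a hypothetical name.

```csharp
using System;

// Hypothetical sketch: per-row soft-target KL, averaged over the batch.
// Not ProbabilisticDistillationStrategy<T>'s actual loss.
public static class BatchLossSketch
{
    public static double SoftTargetLoss(double[][] studentLogits, double[][] teacherLogits, double temperature)
    {
        double total = 0.0;
        for (int i = 0; i < studentLogits.Length; i++) // one row = one sample
        {
            double[] ps = Softmax(studentLogits[i], temperature);
            double[] pt = Softmax(teacherLogits[i], temperature);
            double kl = 0.0;
            for (int k = 0; k < pt.Length; k++)
                if (pt[k] > 0.0) kl += pt[k] * Math.Log(pt[k] / ps[k]);
            total += kl;
        }
        // Average over the batch; the T^2 factor keeps gradient magnitudes
        // comparable across temperatures.
        return temperature * temperature * total / studentLogits.Length;
    }

    // Numerically stable softmax of logits / temperature.
    private static double[] Softmax(double[] logits, double temperature)
    {
        double max = double.NegativeInfinity;
        foreach (double z in logits) max = Math.Max(max, z / temperature);
        var p = new double[logits.Length];
        double sum = 0.0;
        for (int k = 0; k < logits.Length; k++) { p[k] = Math.Exp(logits[k] / temperature - max); sum += p[k]; }
        for (int k = 0; k < p.Length; k++) p[k] /= sum;
        return p;
    }
}
```

Because the per-row losses are averaged, duplicating every sample in a batch leaves the loss unchanged. This keeps the scale independent of batch size.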