Class ProbabilisticDistillationStrategy<T>
Namespace: AiDotNet.KnowledgeDistillation.Strategies
Assembly: AiDotNet.dll
Probabilistic distillation that transfers distributional knowledge by matching statistical properties.
public class ProbabilisticDistillationStrategy<T> : DistillationStrategyBase<T>, IDistillationStrategy<T>
Type Parameters
T: The numeric type for calculations (e.g., double, float).
Inheritance: object → DistillationStrategyBase<T> → ProbabilisticDistillationStrategy<T>
Implements: IDistillationStrategy<T>
Remarks
For Production Use: This strategy treats model outputs as samples from probability distributions and transfers knowledge about the entire distribution, not just point predictions. It matches statistical moments (mean, variance, higher moments) and can use measures like Maximum Mean Discrepancy (MMD).
Key Concept: Standard distillation matches individual predictions, but neural networks can be viewed as probabilistic models. This strategy captures uncertainty and distribution shape by matching:
1. First moment (mean) - expected predictions
2. Second moment (variance) - prediction uncertainty
3. Distribution distance (MMD, Wasserstein) - overall shape
Implementation: We provide three modes:
- MomentMatching: match the mean and variance of predictions across the batch (see the sketch below)
- MaximumMeanDiscrepancy: use MMD with an RBF kernel to match distributions
- EntropyTransfer: match prediction entropy (uncertainty calibration)
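To make MomentMatching concrete, here is a minimal sketch of the idea, not the library's internal implementation. It uses plain double[][] arrays in place of the library's Vector<T>/Matrix<T> types and matches per-class mean and variance across a batch:

```csharp
// Minimal sketch of moment matching: compare per-class batch statistics.
// student/teacher: [batch][classes] post-softmax predictions.
static double MomentMatchingLoss(double[][] student, double[][] teacher)
{
    int batch = student.Length, classes = student[0].Length;
    double loss = 0.0;
    for (int c = 0; c < classes; c++)
    {
        double sMean = 0, tMean = 0;
        for (int i = 0; i < batch; i++) { sMean += student[i][c]; tMean += teacher[i][c]; }
        sMean /= batch; tMean /= batch;

        double sVar = 0, tVar = 0;
        for (int i = 0; i < batch; i++)
        {
            sVar += (student[i][c] - sMean) * (student[i][c] - sMean);
            tVar += (teacher[i][c] - tMean) * (teacher[i][c] - tMean);
        }
        sVar /= batch; tVar /= batch;

        // Squared difference of first and second moments, summed over classes.
        loss += (sMean - tMean) * (sMean - tMean) + (sVar - tVar) * (sVar - tVar);
    }
    return loss / classes;
}
```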
Research Basis: Based on probabilistic knowledge distillation and Bayesian neural networks. Particularly useful for uncertainty quantification and ensemble distillation.
Constructors
ProbabilisticDistillationStrategy(double, ProbabilisticMode, double, double, double)
public ProbabilisticDistillationStrategy(double distributionWeight = 0.5, ProbabilisticMode mode = ProbabilisticMode.MomentMatching, double mmdBandwidth = 1, double temperature = 3, double alpha = 0.3)
Parameters
distributionWeight (double): Weight of the distributional matching term. Default: 0.5.
mode (ProbabilisticMode): The matching mode to use. Default: ProbabilisticMode.MomentMatching.
mmdBandwidth (double): Bandwidth of the RBF kernel used in MaximumMeanDiscrepancy mode. Default: 1.
temperature (double): Softmax temperature for distillation. Default: 3.
alpha (double): Weight balancing the hard-label loss against the distillation loss. Default: 0.3.
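For example, constructing the strategy with MMD-based matching (parameter meanings as described above):

```csharp
// Opt into MMD matching; all other values shown are the documented defaults.
var strategy = new ProbabilisticDistillationStrategy<double>(
    distributionWeight: 0.5,                        // weight of the distributional term
    mode: ProbabilisticMode.MaximumMeanDiscrepancy, // or MomentMatching / EntropyTransfer
    mmdBandwidth: 1.0,                              // RBF kernel bandwidth
    temperature: 3.0,                               // softmax temperature
    alpha: 0.3);                                    // hard-label weighting
```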
Methods
ComputeDistributionalLoss(Vector<T>[], Vector<T>[])
Computes distributional loss by matching statistical properties across a batch.
public T ComputeDistributionalLoss(Vector<T>[] studentPredictions, Vector<T>[] teacherPredictions)
Parameters
studentPredictions (Vector<T>[]): Student probability distributions for a batch.
teacherPredictions (Vector<T>[]): Teacher probability distributions for a batch.
Returns
- T
Distributional matching loss.
Remarks
This should be called with predictions (post-softmax) for an entire batch. The method will compute distributional statistics and match them between student and teacher.
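For the MaximumMeanDiscrepancy mode, the quantity being matched is MMD² with an RBF kernel. The following is an illustrative sketch of that computation, again using double[][] in place of Vector<T>[]; it mirrors the concept, not the library's exact code:

```csharp
using System;

// RBF (Gaussian) kernel between two prediction vectors.
static double Rbf(double[] x, double[] y, double bandwidth)
{
    double sq = 0.0;
    for (int k = 0; k < x.Length; k++) sq += (x[k] - y[k]) * (x[k] - y[k]);
    return Math.Exp(-sq / (2.0 * bandwidth * bandwidth));
}

// MMD^2 = E[k(s,s')] + E[k(t,t')] - 2 E[k(s,t)], estimated over the batch.
static double MmdSquared(double[][] s, double[][] t, double bandwidth)
{
    int n = s.Length, m = t.Length;
    double ss = 0, tt = 0, st = 0;
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++) ss += Rbf(s[i], s[j], bandwidth);
    for (int i = 0; i < m; i++)
        for (int j = 0; j < m; j++) tt += Rbf(t[i], t[j], bandwidth);
    for (int i = 0; i < n; i++)
        for (int j = 0; j < m; j++) st += Rbf(s[i], t[j], bandwidth);
    return ss / (n * n) + tt / (m * m) - 2.0 * st / (n * m);
}
```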
ComputeGradient(Matrix<T>, Matrix<T>, Matrix<T>?)
Computes the gradient of the distillation loss for backpropagation.
public override Matrix<T> ComputeGradient(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)
Parameters
studentBatchOutput (Matrix<T>): The student model's output logits for a batch. Shape: [batch_size x num_classes].
teacherBatchOutput (Matrix<T>): The teacher model's output logits for a batch. Shape: [batch_size x num_classes].
trueLabelsBatch (Matrix<T>?): Ground truth labels for the batch (optional). Shape: [batch_size x num_classes].
Returns
- Matrix<T>
The gradient of the loss with respect to student outputs. Shape: [batch_size x num_classes]
Remarks
For Implementers: Override this method to compute gradients for your strategy. The gradient should match the loss computation in ComputeLoss.
Batch Processing: Returns a gradient matrix with the same shape as the input, one gradient row for each sample in the batch.
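As a reference point, the gradient of the classic temperature-scaled KL term with respect to the student's logits has a simple closed form. The sketch below shows that textbook soft-target gradient only; the gradient of this strategy's distributional term would be added on top, depending on the chosen mode:

```csharp
using System;

// d/dz of T^2 * KL(pt || ps) with ps = softmax(z / T) is T * (ps - pt),
// averaged over the batch.
static double[][] SoftTargetGradient(double[][] studentLogits, double[][] teacherLogits, double temperature)
{
    int batch = studentLogits.Length, classes = studentLogits[0].Length;
    var grad = new double[batch][];
    for (int i = 0; i < batch; i++)
    {
        double[] ps = Softmax(studentLogits[i], temperature);
        double[] pt = Softmax(teacherLogits[i], temperature);
        grad[i] = new double[classes];
        for (int c = 0; c < classes; c++)
            grad[i][c] = temperature * (ps[c] - pt[c]) / batch;
    }
    return grad;
}

// Numerically stable temperature softmax.
static double[] Softmax(double[] logits, double temperature)
{
    double max = double.NegativeInfinity;
    foreach (var z in logits) max = Math.Max(max, z);
    var exp = new double[logits.Length];
    double sum = 0.0;
    for (int c = 0; c < logits.Length; c++)
    {
        exp[c] = Math.Exp((logits[c] - max) / temperature);
        sum += exp[c];
    }
    for (int c = 0; c < logits.Length; c++) exp[c] /= sum;
    return exp;
}
```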
ComputeLoss(Matrix<T>, Matrix<T>, Matrix<T>?)
Computes the distillation loss between student and teacher batch outputs.
public override T ComputeLoss(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)
Parameters
studentBatchOutput (Matrix<T>): The student model's output logits for a batch. Shape: [batch_size x num_classes].
teacherBatchOutput (Matrix<T>): The teacher model's output logits for a batch. Shape: [batch_size x num_classes].
trueLabelsBatch (Matrix<T>?): Ground truth labels for the batch (optional). Shape: [batch_size x num_classes].
Returns
- T
The computed distillation loss value (scalar) for the batch.
Remarks
For Implementers: Override this method to define your strategy's loss computation. The base class handles temperature and alpha; you focus on the loss calculation logic.
Batch Processing: The loss should be computed over all samples in the batch and typically averaged. Each row in the input matrices represents one sample.
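A sketch of the batch-averaged soft-target component, reusing the Softmax helper from the gradient sketch above. How the strategy blends this with the distributional term (via distributionWeight) and with hard labels (via alpha) is handled by the base class, per the remarks:

```csharp
// T^2-scaled KL(teacher || student) on temperature-softened distributions,
// averaged over the batch to yield one scalar.
static double SoftTargetLoss(double[][] studentLogits, double[][] teacherLogits, double temperature)
{
    double total = 0.0;
    for (int i = 0; i < studentLogits.Length; i++)
    {
        double[] ps = Softmax(studentLogits[i], temperature);
        double[] pt = Softmax(teacherLogits[i], temperature);
        double kl = 0.0;
        for (int c = 0; c < ps.Length; c++)
            if (pt[c] > 0) kl += pt[c] * Math.Log(pt[c] / ps[c]);
        total += temperature * temperature * kl;
    }
    return total / studentLogits.Length; // batch-averaged scalar
}
```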