Class VariationalDistillationStrategy<T>
- Namespace: AiDotNet.KnowledgeDistillation.Strategies
- Assembly: AiDotNet.dll
Variational distillation based on variational inference principles and information theory.
public class VariationalDistillationStrategy<T> : DistillationStrategyBase<T>, IDistillationStrategy<T>
Type Parameters
T: The numeric type for calculations (e.g., double, float).
- Inheritance: object → DistillationStrategyBase<T> → VariationalDistillationStrategy<T>
- Implements: IDistillationStrategy<T>
- Inherited Members
Remarks
For Production Use: This strategy applies variational inference to knowledge distillation, treating representations as distributions rather than point estimates. It implements concepts from Variational Information Bottleneck (VIB) and Variational Autoencoders (VAE) for distillation.
Key Concept: Instead of matching point predictions, we model representations as probability distributions (typically Gaussian) and match these distributions. This captures uncertainty and enables:
1. Latent space alignment: match distributions in hidden layers.
2. Information bottleneck: compress while preserving task-relevant information.
3. Uncertainty quantification: transfer confidence estimates.
Implementation: Three variational modes are provided:
- ELBO: match the Evidence Lower Bound (reconstruction + KL).
- InformationBottleneck: minimize I(Z;X) while maximizing I(Z;Y), where Z is the representation.
- LatentSpaceKL: match the KL divergence between teacher and student distributions in latent space.
Mathematical Foundation: For Gaussian distributions N(μ,σ²), the KL divergence is: KL(P||Q) = log(σ_q/σ_p) + (σ_p² + (μ_p - μ_q)²)/(2σ_q²) - 1/2
The VIB objective: min I(X;Z) - βI(Z;Y) where β controls the trade-off.
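To make the closed-form term concrete, the following is a minimal sketch (not part of the library API) that evaluates this KL divergence for diagonal Gaussians parameterized by mean and log variance, using plain double arrays purely to illustrate the arithmetic:

```csharp
using System;

static class GaussianKlSketch
{
    // Closed-form KL(P||Q) between diagonal Gaussians given means and log variances,
    // summed over dimensions. Matches the per-dimension formula above, using
    // log(σ_q/σ_p) = 0.5 * (logVarQ - logVarP).
    public static double Kl(double[] muP, double[] logVarP, double[] muQ, double[] logVarQ)
    {
        double kl = 0.0;
        for (int i = 0; i < muP.Length; i++)
        {
            double varP = Math.Exp(logVarP[i]);
            double varQ = Math.Exp(logVarQ[i]);
            kl += 0.5 * (logVarQ[i] - logVarP[i])
                + (varP + Math.Pow(muP[i] - muQ[i], 2.0)) / (2.0 * varQ)
                - 0.5;
        }
        return kl;
    }
}
```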
Research Basis: Based on:
- Deep Variational Information Bottleneck (Alemi et al., 2017)
- Variational Information Distillation for Knowledge Transfer (Ahn et al., 2019)
- Bayesian Dark Knowledge (Balan et al., 2015)
Constructors
VariationalDistillationStrategy(double, VariationalMode, double, double, double)
public VariationalDistillationStrategy(double variationalWeight = 0.5, VariationalMode mode = VariationalMode.LatentSpaceKL, double betaIB = 1, double temperature = 3, double alpha = 0.3)
Parameters
- variationalWeight (double): Weight of the variational loss term. Default: 0.5.
- mode (VariationalMode): The variational objective to use (ELBO, InformationBottleneck, or LatentSpaceKL). Default: LatentSpaceKL.
- betaIB (double): The β trade-off coefficient for the information bottleneck objective. Default: 1.
- temperature (double): Distillation temperature applied to the logits. Default: 3.
- alpha (double): Weight balancing the soft distillation loss against the hard-label loss. Default: 0.3.
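A hedged usage sketch follows: it constructs the strategy with the default latent-space KL mode and a second instance tuned toward the information-bottleneck objective. The VariationalMode member names are taken from the remarks above; where VariationalMode is declared (and therefore which using directive it needs) is an assumption.

```csharp
using AiDotNet.KnowledgeDistillation.Strategies;
// using directive for VariationalMode assumed; adjust to where the enum actually lives.

// Default configuration: LatentSpaceKL mode, variationalWeight = 0.5,
// temperature = 3, alpha = 0.3.
var latentKlStrategy = new VariationalDistillationStrategy<double>();

// Information-bottleneck configuration (illustrative values, not recommendations).
var ibStrategy = new VariationalDistillationStrategy<double>(
    variationalWeight: 0.7,                       // emphasize the variational term
    mode: VariationalMode.InformationBottleneck,  // mode name taken from the remarks above
    betaIB: 0.5,                                  // compression vs. task-relevance trade-off
    temperature: 4.0,                             // softer teacher targets
    alpha: 0.3);                                  // weight for the hard-label component
```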
Methods
ComputeGradient(Matrix<T>, Matrix<T>, Matrix<T>?)
Computes the gradient of the distillation loss for backpropagation.
public override Matrix<T> ComputeGradient(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)
Parameters
- studentBatchOutput (Matrix<T>): The student model's output logits for a batch. Shape: [batch_size x num_classes]
- teacherBatchOutput (Matrix<T>): The teacher model's output logits for a batch. Shape: [batch_size x num_classes]
- trueLabelsBatch (Matrix<T>?): Ground truth labels for the batch (optional). Shape: [batch_size x num_classes]
Returns
- Matrix<T>
The gradient of the loss with respect to student outputs. Shape: [batch_size x num_classes]
Remarks
For Implementers: Override this method to compute gradients for your strategy. The gradient should match the loss computation in ComputeLoss.
Batch Processing: Returns a gradient matrix with the same shape as the input, one gradient row for each sample in the batch.
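As a usage sketch (assuming the batch logit matrices are produced elsewhere by the student and teacher models), a pure-distillation gradient call with no ground-truth labels looks like this; the returned matrix has one gradient row per sample, the same shape as the logits:

```csharp
using AiDotNet.KnowledgeDistillation.Strategies;
// using directive for Matrix<T> assumed; adjust to the library's actual namespace.

static class GradientSketch
{
    // Gradient-only call without labels (pure distillation). How the gradient is
    // backpropagated through the student model is outside this strategy and assumed.
    public static Matrix<double> DistillationGradient(
        VariationalDistillationStrategy<double> strategy,
        Matrix<double> studentLogits,   // [batch_size x num_classes]
        Matrix<double> teacherLogits)   // [batch_size x num_classes]
    {
        return strategy.ComputeGradient(studentLogits, teacherLogits, trueLabelsBatch: null);
    }
}
```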
ComputeLoss(Matrix<T>, Matrix<T>, Matrix<T>?)
Computes the distillation loss between student and teacher batch outputs.
public override T ComputeLoss(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)
Parameters
- studentBatchOutput (Matrix<T>): The student model's output logits for a batch. Shape: [batch_size x num_classes]
- teacherBatchOutput (Matrix<T>): The teacher model's output logits for a batch. Shape: [batch_size x num_classes]
- trueLabelsBatch (Matrix<T>?): Ground truth labels for the batch (optional). Shape: [batch_size x num_classes]
Returns
- T
The computed distillation loss value (scalar) for the batch.
Remarks
For Implementers: Override this method to define your strategy's loss computation. The base class handles temperature and alpha; you focus on the loss calculation logic.
Batch Processing: The loss should be computed over all samples in the batch and typically averaged. Each row in the input matrices represents one sample.
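The sketch below shows how the loss and gradient calls are typically paired in a single distillation step. Only the two strategy calls are part of the documented API; the forward passes, the one-hot labels, and the optimizer update are assumed to exist elsewhere in the training loop:

```csharp
using AiDotNet.KnowledgeDistillation.Strategies;
// using directive for Matrix<T> assumed; adjust to the library's actual namespace.

static class DistillationStepSketch
{
    // One distillation step over a batch whose logits were computed elsewhere.
    public static double Step(
        VariationalDistillationStrategy<double> strategy,
        Matrix<double> studentLogits,   // [batch_size x num_classes]
        Matrix<double> teacherLogits,   // [batch_size x num_classes]
        Matrix<double>? labels)         // optional one-hot ground truth
    {
        // Scalar loss, useful for logging or early stopping.
        double loss = strategy.ComputeLoss(studentLogits, teacherLogits, labels);

        // Gradient w.r.t. the student logits; feed it into the student's
        // backward pass and optimizer update (assumed, not shown).
        Matrix<double> grad = strategy.ComputeGradient(studentLogits, teacherLogits, labels);

        return loss;
    }
}
```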
ComputeVariationalLoss(Vector<T>, Vector<T>, Vector<T>, Vector<T>)
Computes variational loss using latent representations with mean and variance.
public T ComputeVariationalLoss(Vector<T> studentMean, Vector<T> studentLogVar, Vector<T> teacherMean, Vector<T> teacherLogVar)
Parameters
- studentMean (Vector<T>): Student's mean vector in latent space.
- studentLogVar (Vector<T>): Student's log variance vector in latent space.
- teacherMean (Vector<T>): Teacher's mean vector in latent space.
- teacherLogVar (Vector<T>): Teacher's log variance vector in latent space.
Returns
- T
Variational loss based on selected mode.
Remarks
Representations should be parameterized as Gaussian distributions with mean and log variance. Log variance is used for numerical stability (variance must be positive).
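A sketch of accumulating this term over several matched latent pairs (for example, one pair per aligned hidden layer). The Gaussian heads that produce the mean and log-variance vectors are assumed to be attached to the teacher and student elsewhere:

```csharp
using AiDotNet.KnowledgeDistillation.Strategies;
// using directive for Vector<T> assumed; adjust to the library's actual namespace.

static class LatentLossSketch
{
    // Sum the variational loss over index-aligned (mean, log variance) pairs
    // from the student and teacher. Each pair describes one latent Gaussian.
    public static double Total(
        VariationalDistillationStrategy<double> strategy,
        (Vector<double> Mean, Vector<double> LogVar)[] studentLatents,
        (Vector<double> Mean, Vector<double> LogVar)[] teacherLatents)
    {
        double total = 0.0;
        for (int i = 0; i < studentLatents.Length; i++)
        {
            total += strategy.ComputeVariationalLoss(
                studentLatents[i].Mean, studentLatents[i].LogVar,
                teacherLatents[i].Mean, teacherLatents[i].LogVar);
        }
        return total;
    }
}
```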
ComputeVariationalLossBatch(Vector<T>[], Vector<T>[], Vector<T>[], Vector<T>[])
Computes variational loss for batch of latent representations.
public T ComputeVariationalLossBatch(Vector<T>[] studentMeans, Vector<T>[] studentLogVars, Vector<T>[] teacherMeans, Vector<T>[] teacherLogVars)
Parameters
- studentMeans (Vector<T>[]): Per-sample mean vectors of the student's latent distributions.
- studentLogVars (Vector<T>[]): Per-sample log variance vectors of the student's latent distributions.
- teacherMeans (Vector<T>[]): Per-sample mean vectors of the teacher's latent distributions.
- teacherLogVars (Vector<T>[]): Per-sample log variance vectors of the teacher's latent distributions.
Returns
- T
The variational loss computed over the batch, based on the selected mode.
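A minimal call sketch for the batch variant; the four arrays are expected to be index-aligned (element i of each array belongs to the same sample) and are assumed to be built from per-sample encoder outputs:

```csharp
using AiDotNet.KnowledgeDistillation.Strategies;
// using directive for Vector<T> assumed; adjust to the library's actual namespace.

static class BatchLatentLossSketch
{
    // The four arrays must be index-aligned: one Vector<double> per sample in each.
    public static double Compute(
        VariationalDistillationStrategy<double> strategy,
        Vector<double>[] studentMeans, Vector<double>[] studentLogVars,
        Vector<double>[] teacherMeans, Vector<double>[] teacherLogVars)
    {
        return strategy.ComputeVariationalLossBatch(
            studentMeans, studentLogVars, teacherMeans, teacherLogVars);
    }
}
```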
Reparameterize(Vector<T>, Vector<T>, Vector<T>)
Reparameterization trick for sampling from latent distribution during training.
public Vector<T> Reparameterize(Vector<T> mean, Vector<T> logVar, Vector<T> epsilon)
Parameters
- mean (Vector<T>): Mean of the distribution.
- logVar (Vector<T>): Log variance of the distribution.
- epsilon (Vector<T>): Random noise sampled from the standard normal N(0, 1).
Returns
- Vector<T>
The sample z = μ + σ · ε, where σ = exp(logVar / 2).
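A sketch of drawing ε from N(0, 1) with a Box-Muller transform and applying the trick, so the sample stays differentiable with respect to μ and log σ². The Vector<double>(double[]) constructor used to wrap the noise is an assumption; adjust it to however the library actually builds a Vector<T>:

```csharp
using System;
using AiDotNet.KnowledgeDistillation.Strategies;
// using directive for Vector<T> assumed; adjust to the library's actual namespace.

static class ReparameterizeSketch
{
    public static Vector<double> SampleLatent(
        VariationalDistillationStrategy<double> strategy,
        Vector<double> mean, Vector<double> logVar, int dimension, Random rng)
    {
        var noise = new double[dimension];
        for (int i = 0; i < dimension; i++)
        {
            // Box-Muller transform: two uniform draws -> one standard normal draw.
            double u1 = 1.0 - rng.NextDouble();   // avoid log(0)
            double u2 = rng.NextDouble();
            noise[i] = Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Cos(2.0 * Math.PI * u2);
        }

        var epsilon = new Vector<double>(noise);  // assumed constructor Vector<double>(double[])
        return strategy.Reparameterize(mean, logVar, epsilon);  // z = μ + exp(logVar / 2) · ε
    }
}
```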