Class VariationalDistillationStrategy<T>

Namespace
AiDotNet.KnowledgeDistillation.Strategies
Assembly
AiDotNet.dll

Variational distillation based on variational inference principles and information theory.

public class VariationalDistillationStrategy<T> : DistillationStrategyBase<T>, IDistillationStrategy<T>

Type Parameters

T

The numeric type for calculations (e.g., double, float).

Inheritance
DistillationStrategyBase<T> → VariationalDistillationStrategy<T>

Implements
IDistillationStrategy<T>

Remarks

For Production Use: This strategy applies variational inference to knowledge distillation, treating representations as distributions rather than point estimates. It implements concepts from Variational Information Bottleneck (VIB) and Variational Autoencoders (VAE) for distillation.

Key Concept: Instead of matching point predictions, we model representations as probability distributions (typically Gaussian) and match these distributions. This captures uncertainty and enables:

1. Latent space alignment - match distributions in hidden layers
2. Information bottleneck - compress while preserving task-relevant information
3. Uncertainty quantification - transfer confidence estimates

Implementation: Three variational modes are provided:

- ELBO - match the Evidence Lower Bound (reconstruction + KL)
- InformationBottleneck - minimize I(Z;X) while maximizing I(Z;Y), where Z is the representation
- LatentSpaceKL - match the KL divergence between teacher and student distributions in latent space

Mathematical Foundation: For Gaussian distributions N(μ, σ²), the KL divergence has the closed form:

KL(P||Q) = log(σ_q/σ_p) + (σ_p² + (μ_p - μ_q)²)/(2σ_q²) - 1/2

The VIB objective: min I(X;Z) - βI(Z;Y) where β controls the trade-off.
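To make the closed form concrete, here is a small illustrative helper in plain C# (not part of the AiDotNet API) that evaluates the expression above for a pair of univariate Gaussians:

using System;

public static class GaussianKlExample
{
    // KL(P||Q) for P = N(muP, sigmaP^2) and Q = N(muQ, sigmaQ^2),
    // exactly as written above; returns 0 when P and Q coincide.
    public static double KlDivergence(double muP, double sigmaP, double muQ, double sigmaQ)
    {
        return Math.Log(sigmaQ / sigmaP)
             + (sigmaP * sigmaP + (muP - muQ) * (muP - muQ)) / (2.0 * sigmaQ * sigmaQ)
             - 0.5;
    }
}

For example, KlDivergence(0, 1, 0, 1) returns 0, and pulling the means apart adds (μ_p - μ_q)²/(2σ_q²) to the result.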

Research Basis:

- Deep Variational Information Bottleneck (Alemi et al., 2017)
- Variational Information Distillation for Knowledge Transfer (Ahn et al., 2019)
- Bayesian Dark Knowledge (Balan et al., 2015)

Constructors

VariationalDistillationStrategy(double, VariationalMode, double, double, double)

public VariationalDistillationStrategy(double variationalWeight = 0.5, VariationalMode mode = VariationalMode.LatentSpaceKL, double betaIB = 1, double temperature = 3, double alpha = 0.3)

Parameters

variationalWeight double

Weight applied to the variational loss term (default 0.5).

mode VariationalMode

Selects the variational objective: ELBO, InformationBottleneck, or LatentSpaceKL (default LatentSpaceKL).

betaIB double

The β coefficient in the information-bottleneck objective min I(X;Z) - βI(Z;Y) (default 1).

temperature double

Softmax temperature applied to logits during distillation (default 3).

alpha double

Weight balancing the soft-target distillation loss against the hard-label loss (default 0.3).
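A minimal construction sketch using the named parameters above. The numeric values shown are the documented defaults; the mode is switched to InformationBottleneck simply to illustrate selecting one of the three modes, and importing the enum from the same namespace is an assumption:

using AiDotNet.KnowledgeDistillation.Strategies;

var strategy = new VariationalDistillationStrategy<double>(
    variationalWeight: 0.5,                      // weight of the variational term
    mode: VariationalMode.InformationBottleneck, // or ELBO / LatentSpaceKL
    betaIB: 1.0,                                 // β in min I(X;Z) - βI(Z;Y)
    temperature: 3.0,                            // softmax temperature for logits
    alpha: 0.3);                                 // soft-target vs. hard-label weight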

Methods

ComputeGradient(Matrix<T>, Matrix<T>, Matrix<T>?)

Computes the gradient of the distillation loss for backpropagation.

public override Matrix<T> ComputeGradient(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)

Parameters

studentBatchOutput Matrix<T>

The student model's output logits for a batch. Shape: [batch_size x num_classes]

teacherBatchOutput Matrix<T>

The teacher model's output logits for a batch. Shape: [batch_size x num_classes]

trueLabelsBatch Matrix<T>?

Ground truth labels for the batch (optional). Shape: [batch_size x num_classes]

Returns

Matrix<T>

The gradient of the loss with respect to student outputs. Shape: [batch_size x num_classes]

Remarks

For Implementers: Override this method to compute gradients for your strategy. The gradient should match the loss computation in ComputeLoss.

Batch Processing: Returns a gradient matrix with the same shape as the input, one gradient row for each sample in the batch.
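A gradient sketch, assuming Matrix<double> exposes a (rows, columns) constructor and a [row, column] indexer (assumptions about the AiDotNet matrix API; adapt to the real one):

var strategy = new VariationalDistillationStrategy<double>(); // documented defaults

// Hypothetical batch: 2 samples, 3 classes. Fill with real logits in practice.
var studentLogits = new Matrix<double>(2, 3);
var teacherLogits = new Matrix<double>(2, 3);
studentLogits[0, 0] = 1.2; // ... set the remaining logits similarly ...
teacherLogits[0, 0] = 2.0;

// The gradient has the same [2 x 3] shape as the inputs, one row per sample.
Matrix<double> grad = strategy.ComputeGradient(studentLogits, teacherLogits);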

ComputeLoss(Matrix<T>, Matrix<T>, Matrix<T>?)

Computes the distillation loss between student and teacher batch outputs.

public override T ComputeLoss(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)

Parameters

studentBatchOutput Matrix<T>

The student model's output logits for a batch. Shape: [batch_size x num_classes]

teacherBatchOutput Matrix<T>

The teacher model's output logits for a batch. Shape: [batch_size x num_classes]

trueLabelsBatch Matrix<T>?

Ground truth labels for the batch (optional). Shape: [batch_size x num_classes]

Returns

T

The computed distillation loss value (scalar) for the batch.

Remarks

For Implementers: Override this method to define your strategy's loss computation. The base class handles temperature and alpha; you focus on the loss calculation logic.

Batch Processing: The loss should be computed over all samples in the batch and typically averaged. Each row in the input matrices represents one sample.
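Reusing the matrices from the ComputeGradient sketch above, the matching scalar loss is obtained the same way; the optional trueLabelsBatch can be passed as a third argument when ground-truth labels are available:

// Scalar loss over the batch, computed from teacher soft targets alone here.
double loss = strategy.ComputeLoss(studentLogits, teacherLogits);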

ComputeVariationalLoss(Vector<T>, Vector<T>, Vector<T>, Vector<T>)

Computes variational loss using latent representations parameterized by mean and log variance.

public T ComputeVariationalLoss(Vector<T> studentMean, Vector<T> studentLogVar, Vector<T> teacherMean, Vector<T> teacherLogVar)

Parameters

studentMean Vector<T>

Student's mean vector in latent space.

studentLogVar Vector<T>

Student's log variance vector in latent space.

teacherMean Vector<T>

Teacher's mean vector in latent space.

teacherLogVar Vector<T>

Teacher's log variance vector in latent space.

Returns

T

Variational loss based on the selected mode.

Remarks

Representations should be parameterized as Gaussian distributions with mean and log variance. Log variance is used for numerical stability (variance must be positive).
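A per-sample sketch; building Vector<double> from a double[] is an assumption about the AiDotNet vector API. Note that the second and fourth arguments hold log σ², per the remark above:

var studentMean   = new Vector<double>(new[] { 0.10, -0.30 });
var studentLogVar = new Vector<double>(new[] { -1.00, -0.50 }); // log variance, so variance ~ 0.37, 0.61
var teacherMean   = new Vector<double>(new[] { 0.05, -0.25 });
var teacherLogVar = new Vector<double>(new[] { -1.20, -0.60 });

double vLoss = strategy.ComputeVariationalLoss(
    studentMean, studentLogVar, teacherMean, teacherLogVar);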

ComputeVariationalLossBatch(Vector<T>[], Vector<T>[], Vector<T>[], Vector<T>[])

Computes the variational loss for a batch of latent representations.

public T ComputeVariationalLossBatch(Vector<T>[] studentMeans, Vector<T>[] studentLogVars, Vector<T>[] teacherMeans, Vector<T>[] teacherLogVars)

Parameters

studentMeans Vector<T>[]

Per-sample student mean vectors in latent space.

studentLogVars Vector<T>[]

Per-sample student log variance vectors in latent space.

teacherMeans Vector<T>[]

Per-sample teacher mean vectors in latent space.

teacherLogVars Vector<T>[]

Per-sample teacher log variance vectors in latent space.

Returns

T

The variational loss computed over all samples in the batch.
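The batch overload takes one array entry per sample; a sketch reusing the strategy and vectors from the previous example:

double batchLoss = strategy.ComputeVariationalLossBatch(
    new[] { studentMean },   // one entry per sample in the batch
    new[] { studentLogVar },
    new[] { teacherMean },
    new[] { teacherLogVar });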

Reparameterize(Vector<T>, Vector<T>, Vector<T>)

Reparameterization trick for sampling from the latent distribution during training.

public Vector<T> Reparameterize(Vector<T> mean, Vector<T> logVar, Vector<T> epsilon)

Parameters

mean Vector<T>

Mean of the distribution.

logVar Vector<T>

Log variance of the distribution.

epsilon Vector<T>

Random noise from standard normal N(0,1).

Returns

Vector<T>

A sample z = μ + σ * ε, where σ = exp(logVar / 2).
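A sampling sketch, continuing from the vectors in the ComputeVariationalLoss example: ε is drawn from N(0,1) per dimension (here via a Box-Muller transform), and the method combines it with μ and log σ². That Vector<T> exposes Length and a double[] constructor is again an assumption about the AiDotNet API:

var rng = new Random(42);
var eps = new double[studentMean.Length];
for (int i = 0; i < eps.Length; i++)
{
    // Box-Muller: two uniforms -> one standard-normal sample.
    double u1 = 1.0 - rng.NextDouble(); // in (0, 1], avoids log(0)
    double u2 = rng.NextDouble();
    eps[i] = Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Cos(2.0 * Math.PI * u2);
}

// z = μ + exp(0.5 * logVar) * ε, elementwise.
Vector<double> z = strategy.Reparameterize(studentMean, studentLogVar, new Vector<double>(eps));

Because the randomness is confined to ε, gradients can flow through μ and logVar, which is the point of the trick.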