Class FactorTransferDistillationStrategy<T>
Namespace: AiDotNet.KnowledgeDistillation.Strategies
Assembly: AiDotNet.dll
public class FactorTransferDistillationStrategy<T> : DistillationStrategyBase<T>, IDistillationStrategy<T>
Type Parameters
T
Inheritance
object → DistillationStrategyBase<T> → FactorTransferDistillationStrategy<T>
Implements
IDistillationStrategy<T>
Constructors
FactorTransferDistillationStrategy(double, FactorMode, int, bool, double, double)
public FactorTransferDistillationStrategy(double factorWeight = 0.5, FactorMode mode = FactorMode.LowRankApproximation, int numFactors = 32, bool normalizeFactors = true, double temperature = 3, double alpha = 0.3)
Parameters
factorWeight (double)
mode (FactorMode)
numFactors (int)
normalizeFactors (bool)
temperature (double)
alpha (double)
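A minimal construction sketch using the default argument values from the signature above. The inline comments reflect standard knowledge-distillation conventions, not confirmed AiDotNet documentation:

```csharp
using AiDotNet.KnowledgeDistillation.Strategies;

// All values below are the constructor defaults; override only what you need.
var strategy = new FactorTransferDistillationStrategy<double>(
    factorWeight: 0.5,                      // weight of the factor transfer loss term
    mode: FactorMode.LowRankApproximation,  // how intermediate features are factorized
    numFactors: 32,                         // number of factors to extract
    normalizeFactors: true,                 // normalize factors before matching
    temperature: 3,                         // softening temperature for logits
    alpha: 0.3);                            // balance between distillation and hard-label loss
```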
Methods
ComputeFactorLoss(Vector<T>[], Vector<T>[])
Computes factor transfer loss by matching factorized representations.
public T ComputeFactorLoss(Vector<T>[] studentFeatures, Vector<T>[] teacherFeatures)
Parameters
studentFeatures (Vector<T>[]): Student feature matrix [batchSize x featureDim].
teacherFeatures (Vector<T>[]): Teacher feature matrix [batchSize x featureDim].
Returns
T: The factor transfer loss.
Remarks
Features should be intermediate layer activations collected over a batch. The method will factorize both representations and match the factors.
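A usage sketch continuing from the construction example above. CollectActivations is a hypothetical helper standing in for whatever hook you use to capture intermediate-layer outputs; it is not part of AiDotNet:

```csharp
// One Vector<double> per sample in the batch; the student and teacher
// arrays must have the same batch size.
Vector<double>[] studentFeatures = CollectActivations(studentModel, batch);  // hypothetical helper
Vector<double>[] teacherFeatures = CollectActivations(teacherModel, batch);  // hypothetical helper

// Factorizes both feature sets and returns the loss between the factors.
double factorLoss = strategy.ComputeFactorLoss(studentFeatures, teacherFeatures);
```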
ComputeGradient(Matrix<T>, Matrix<T>, Matrix<T>?)
Computes the gradient of the distillation loss for backpropagation.
public override Matrix<T> ComputeGradient(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)
Parameters
studentBatchOutput (Matrix<T>): The student model's output logits for a batch. Shape: [batch_size x num_classes].
teacherBatchOutput (Matrix<T>): The teacher model's output logits for a batch. Shape: [batch_size x num_classes].
trueLabelsBatch (Matrix<T>?): Ground truth labels for the batch (optional). Shape: [batch_size x num_classes].
Returns
Matrix<T>: The gradient of the loss with respect to student outputs. Shape: [batch_size x num_classes].
Remarks
For Implementers: Override this method to compute gradients for your strategy. The gradient should match the loss computation in ComputeLoss.
Batch Processing: Returns a gradient matrix with the same shape as the input, one gradient row for each sample in the batch.
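A training-step sketch, assuming batch logits are available as Matrix<double> values of shape [batch_size x num_classes]. The Forward and Backward calls are hypothetical placeholders for your model API, not AiDotNet methods:

```csharp
Matrix<double> studentLogits = studentModel.Forward(inputs);  // hypothetical forward pass
Matrix<double> teacherLogits = teacherModel.Forward(inputs);  // hypothetical forward pass

// Gradient w.r.t. the student outputs; same shape as studentLogits.
Matrix<double> grad = strategy.ComputeGradient(studentLogits, teacherLogits, trueLabels);

// Feed the gradient back through the student model (hypothetical call).
studentModel.Backward(grad);
```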
ComputeLoss(Matrix<T>, Matrix<T>, Matrix<T>?)
Computes the distillation loss between student and teacher batch outputs.
public override T ComputeLoss(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)
Parameters
studentBatchOutput (Matrix<T>): The student model's output logits for a batch. Shape: [batch_size x num_classes].
teacherBatchOutput (Matrix<T>): The teacher model's output logits for a batch. Shape: [batch_size x num_classes].
trueLabelsBatch (Matrix<T>?): Ground truth labels for the batch (optional). Shape: [batch_size x num_classes].
Returns
T: The computed distillation loss value (scalar) for the batch.
Remarks
For Implementers: Override this method to define your strategy's loss computation. The base class handles temperature and alpha; you focus on the loss calculation logic.
Batch Processing: The loss should be computed over all samples in the batch and typically averaged. Each row in the input matrices represents one sample.
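A sketch for monitoring the loss during training, continuing from the gradient example above. Passing null for trueLabelsBatch exercises the optional-label path shown in the signature:

```csharp
// Scalar loss over the whole batch; trueLabelsBatch may be null when
// no ground-truth labels are available.
double loss = strategy.ComputeLoss(studentLogits, teacherLogits, trueLabelsBatch: null);
Console.WriteLine($"Distillation loss: {loss}");
```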