
Class FactorTransferDistillationStrategy<T>

Namespace: AiDotNet.KnowledgeDistillation.Strategies
Assembly: AiDotNet.dll
public class FactorTransferDistillationStrategy<T> : DistillationStrategyBase<T>, IDistillationStrategy<T>

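Type Parameters

T

The numeric type of the feature vectors, logits, and loss values (for example, double).

Inheritance
DistillationStrategyBase<T>
FactorTransferDistillationStrategy<T>
Implements
IDistillationStrategy<T>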

Constructors

FactorTransferDistillationStrategy(double, FactorMode, int, bool, double, double)

public FactorTransferDistillationStrategy(double factorWeight = 0.5, FactorMode mode = FactorMode.LowRankApproximation, int numFactors = 32, bool normalizeFactors = true, double temperature = 3, double alpha = 0.3)

Parameters

factorWeight double
mode FactorMode
numFactors int
normalizeFactors bool
temperature double
alpha double

Methods

ComputeFactorLoss(Vector<T>[], Vector<T>[])

Computes the factor transfer loss by matching the factorized student and teacher representations.

public T ComputeFactorLoss(Vector<T>[] studentFeatures, Vector<T>[] teacherFeatures)

Parameters

studentFeatures Vector<T>[]

Student features for one batch: one Vector<T> per sample (batchSize vectors, each of length featureDim).

teacherFeatures Vector<T>[]

Teacher features for the same batch: one Vector<T> per sample (batchSize vectors, each of length featureDim).

Returns

T

The factor transfer loss as a scalar.

Remarks

Pass intermediate-layer activations collected over a batch. The method factorizes both the student and the teacher representations (using the configured FactorMode) and penalizes the mismatch between the resulting factors.
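The page does not say how the factors are computed or compared. As a hedged illustration of the idea only, the C# sketch below assumes a single low-rank factor per representation: the leading left singular vector of the [batchSize x featureDim] feature matrix, found by power iteration on the batch-by-batch Gram matrix. Because that factor lives in sample space, the sketch also works when student and teacher feature widths differ. The names `FactorLossSketch`, `LeadingSampleFactor`, and `FactorLoss`, and the 1 - cos^2 comparison, are assumptions for illustration, not AiDotNet's API or formula.

```csharp
using System;

// Hypothetical sketch of low-rank factor matching; NOT AiDotNet's implementation.
// Each representation is reduced to ONE factor: its leading left singular vector
// (length batchSize), computed by power iteration on the Gram matrix G = X * X^T.
public static class FactorLossSketch
{
    public static double Dot(double[] a, double[] b)
    {
        double sum = 0;
        for (int i = 0; i < a.Length; i++) sum += a[i] * b[i];
        return sum;
    }

    public static double[] LeadingSampleFactor(double[][] features, int iterations = 100)
    {
        int n = features.Length;

        // Gram matrix between samples: G[i, j] = <x_i, x_j>.
        var gram = new double[n, n];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                gram[i, j] = Dot(features[i], features[j]);

        // Power iteration from a fixed unit start vector (a real implementation
        // might randomize the start vector or call a proper SVD routine instead).
        var u = new double[n];
        for (int i = 0; i < n; i++) u[i] = 1.0 / Math.Sqrt(n);

        for (int it = 0; it < iterations; it++)
        {
            var next = new double[n];
            for (int i = 0; i < n; i++)
                for (int j = 0; j < n; j++)
                    next[i] += gram[i, j] * u[j];

            double norm = Math.Sqrt(Dot(next, next));
            if (norm == 0) break; // all-zero features: keep the current vector
            for (int i = 0; i < n; i++) u[i] = next[i] / norm; // unit-length factor
        }
        return u;
    }

    // Loss = 1 - cos^2(u_student, u_teacher), which lies in [0, 1].
    public static double FactorLoss(double[][] studentFeatures, double[][] teacherFeatures)
    {
        if (studentFeatures.Length != teacherFeatures.Length)
            throw new ArgumentException("Student and teacher batch sizes must match.");

        double[] us = LeadingSampleFactor(studentFeatures);
        double[] ut = LeadingSampleFactor(teacherFeatures);
        double cos = Dot(us, ut); // both factors already have unit length
        return 1.0 - cos * cos;
    }
}
```

Squaring the cosine makes the loss insensitive to the arbitrary sign of a singular vector. A multi-factor variant (closer in spirit to the numFactors parameter) could extract further factors by deflation or from a full SVD and sum the per-factor terms.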

ComputeGradient(Matrix<T>, Matrix<T>, Matrix<T>?)

Computes the gradient of the distillation loss for backpropagation.

public override Matrix<T> ComputeGradient(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)

Parameters

studentBatchOutput Matrix<T>

The student model's output logits for a batch. Shape: [batch_size x num_classes]

teacherBatchOutput Matrix<T>

The teacher model's output logits for a batch. Shape: [batch_size x num_classes]

trueLabelsBatch Matrix<T>

Ground truth labels for the batch (optional). Shape: [batch_size x num_classes]

Returns

Matrix<T>

The gradient of the loss with respect to student outputs. Shape: [batch_size x num_classes]

Remarks

For Implementers: Override this method to compute gradients for your strategy. The gradient must be the derivative, with respect to the student outputs, of the loss returned by ComputeLoss.

Batch Processing: Returns a gradient matrix with the same shape as the input, one gradient row for each sample in the batch.
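The remarks require the gradient to agree with ComputeLoss but do not give a formula. Assuming a common Hinton-style objective, in which alpha weights the hard-label cross-entropy and (1 - alpha) weights T^2 times the KL divergence between temperature-softened teacher and student distributions (this split is an assumption; the page only says the base class handles temperature and alpha), the gradient for each student logit is (1 - alpha) * T * (p_s^T - p_t^T) + alpha * (p_s - y), divided by the batch size when the loss is a batch mean. A minimal, self-contained C# sketch on plain double arrays follows; `DistillationGradientSketch`, `Softmax`, and `Gradient` are hypothetical names, not AiDotNet members.

```csharp
#nullable enable
using System;

// Hypothetical sketch, NOT AiDotNet's implementation. Assumed per-sample loss:
//   L = alpha * CE(y, softmax(s)) + (1 - alpha) * T^2 * KL(softmax(t / T) || softmax(s / T)),
// averaged over the batch. With no labels, only the soft term is used (weight 1).
public static class DistillationGradientSketch
{
    // Numerically stable softmax of z / temperature.
    public static double[] Softmax(double[] z, double temperature)
    {
        double max = double.NegativeInfinity;
        foreach (double v in z) max = Math.Max(max, v / temperature);

        var p = new double[z.Length];
        double sum = 0;
        for (int j = 0; j < z.Length; j++)
        {
            p[j] = Math.Exp(z[j] / temperature - max);
            sum += p[j];
        }
        for (int j = 0; j < z.Length; j++) p[j] /= sum;
        return p;
    }

    // Returns dL/ds with the same shape as the student logits: one row per sample.
    public static double[][] Gradient(double[][] student, double[][] teacher,
        double[][]? labels, double temperature, double alpha)
    {
        int batch = student.Length;
        double softWeight = labels == null ? 1.0 : 1.0 - alpha;
        var grad = new double[batch][];

        for (int i = 0; i < batch; i++)
        {
            double[] ps = Softmax(student[i], temperature); // softened student
            double[] pt = Softmax(teacher[i], temperature); // softened teacher
            double[] p = Softmax(student[i], 1.0);          // unsoftened student
            grad[i] = new double[ps.Length];

            for (int j = 0; j < ps.Length; j++)
            {
                // d/ds_j [T^2 * KL(pt || ps)] = T * (ps_j - pt_j)
                double g = softWeight * temperature * (ps[j] - pt[j]);
                // d/ds_j [CE(y, softmax(s))] = p_j - y_j
                if (labels != null) g += alpha * (p[j] - labels[i][j]);
                grad[i][j] = g / batch; // the loss is a batch mean
            }
        }
        return grad;
    }
}
```

Each returned row sums to zero when the labels are one-hot, because both softmax differences sum to zero; this is a quick sanity check for any implementation of this objective.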

ComputeLoss(Matrix<T>, Matrix<T>, Matrix<T>?)

Computes the distillation loss between student and teacher batch outputs.

public override T ComputeLoss(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)

Parameters

studentBatchOutput Matrix<T>

The student model's output logits for a batch. Shape: [batch_size x num_classes]

teacherBatchOutput Matrix<T>

The teacher model's output logits for a batch. Shape: [batch_size x num_classes]

trueLabelsBatch Matrix<T>

Ground truth labels for the batch (optional). Shape: [batch_size x num_classes]

Returns

T

The computed distillation loss value (scalar) for the batch.

Remarks

For Implementers: Override this method to define your strategy's loss computation. The base class handles temperature and alpha; you focus on the loss calculation logic.

Batch Processing: The loss should be computed over all samples in the batch and typically averaged. Each row in the input matrices represents one sample.
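To make the batch-processing remark concrete, here is a self-contained C# sketch of a batch-averaged distillation loss on plain double arrays, one row per sample. It assumes a Hinton-style objective: alpha weights the hard-label cross-entropy, and (1 - alpha) weights T^2 times the KL divergence from the teacher's to the student's temperature-softened distribution; when no labels are passed, only the soft term is used. These choices and the names `DistillationLossSketch` and `Loss` are assumptions, not AiDotNet's API.

```csharp
#nullable enable
using System;

// Hypothetical sketch, NOT AiDotNet's implementation. Assumed per-sample loss:
//   L = alpha * CE(y, softmax(s)) + (1 - alpha) * T^2 * KL(softmax(t / T) || softmax(s / T))
public static class DistillationLossSketch
{
    // Numerically stable softmax of z / temperature.
    static double[] Softmax(double[] z, double temperature)
    {
        double max = double.NegativeInfinity;
        foreach (double v in z) max = Math.Max(max, v / temperature);

        var p = new double[z.Length];
        double sum = 0;
        for (int j = 0; j < z.Length; j++)
        {
            p[j] = Math.Exp(z[j] / temperature - max);
            sum += p[j];
        }
        for (int j = 0; j < z.Length; j++) p[j] /= sum;
        return p;
    }

    // Returns the mean loss over the batch (each row is one sample).
    public static double Loss(double[][] student, double[][] teacher,
        double[][]? labels, double temperature, double alpha)
    {
        if (student.Length != teacher.Length)
            throw new ArgumentException("Student and teacher batch sizes differ.");

        double softWeight = labels == null ? 1.0 : 1.0 - alpha;
        double total = 0;

        for (int i = 0; i < student.Length; i++)
        {
            double[] ps = Softmax(student[i], temperature);
            double[] pt = Softmax(teacher[i], temperature);

            // KL(pt || ps); terms with pt_j = 0 contribute nothing.
            double kl = 0;
            for (int j = 0; j < ps.Length; j++)
                if (pt[j] > 0) kl += pt[j] * Math.Log(pt[j] / ps[j]);
            double sampleLoss = softWeight * temperature * temperature * kl;

            if (labels != null)
            {
                // Hard-label cross-entropy on the unsoftened student distribution.
                double[] p = Softmax(student[i], 1.0);
                for (int j = 0; j < p.Length; j++)
                    if (labels[i][j] > 0) sampleLoss -= alpha * labels[i][j] * Math.Log(p[j]);
            }
            total += sampleLoss;
        }
        return total / student.Length;
    }
}
```

Scaling the KL term by T^2 keeps its gradient magnitude roughly independent of the temperature, which is why that factor commonly appears in this objective.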