
Class HybridDistillationStrategy<T>

Namespace
AiDotNet.KnowledgeDistillation.Strategies
Assembly
AiDotNet.dll

Hybrid distillation strategy that combines multiple distillation strategies with configurable weights.

public class HybridDistillationStrategy<T> : DistillationStrategyBase<T>, IDistillationStrategy<T>

Type Parameters

T

The numeric type for calculations (e.g., double, float).

Inheritance
DistillationStrategyBase<T>
HybridDistillationStrategy<T>
Implements
IDistillationStrategy<T>

Remarks

For Production Use: This strategy allows you to combine multiple distillation approaches (response-based, feature-based, attention-based, etc.) in a single training run. Each strategy contributes to the total loss based on its configured weight.

Example Use Case: For transformer distillation, combine:
- 40% response-based (output matching)
- 30% attention-based (attention pattern matching)
- 30% feature-based (intermediate layer matching)
This gives you knowledge transfer at several levels of the network.
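Assuming the combination is a plain weighted sum (the remarks say only that each strategy contributes according to its weight), the total loss and the 40/30/30 split above work out as follows. The per-strategy loss values 2.0, 1.0 and 0.5 are made up purely for illustration.

```latex
\mathcal{L}_{\text{hybrid}} = \sum_{i=1}^{n} w_i \, \mathcal{L}_i, \qquad \sum_{i=1}^{n} w_i \approx 1

% Hypothetical values: L_resp = 2.0, L_attn = 1.0, L_feat = 0.5
\mathcal{L}_{\text{hybrid}} = 0.4(2.0) + 0.3(1.0) + 0.3(0.5) = 0.80 + 0.30 + 0.15 = 1.25
```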

Benefits:
- Uses several knowledge transfer mechanisms at once
- Weights can be tuned based on validation performance
- Often more robust than relying on a single distillation strategy
- Multi-objective distillation of this kind is used by compact transformer models such as TinyBERT and MobileBERT

Constructors

HybridDistillationStrategy((IDistillationStrategy<T> Strategy, double Weight)[], double, double)

Initializes a new instance of the HybridDistillationStrategy class.

public HybridDistillationStrategy((IDistillationStrategy<T> Strategy, double Weight)[] strategies, double temperature = 3, double alpha = 0.3)

Parameters

strategies (IDistillationStrategy<T> Strategy, double Weight)[]

Array of (strategy, weight) tuples. The weights must sum to approximately 1.0.

temperature double

Temperature for strategies that don't specify their own.

alpha double

Alpha for strategies that don't specify their own.
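As background for the temperature parameter: in standard Hinton-style distillation, logits are divided by the temperature before the softmax, so higher temperatures produce a flatter ("softer") distribution. The sketch below assumes that standard formulation; TemperatureSoftmax is a hypothetical helper, not an AiDotNet type, and it does not claim to reproduce how individual AiDotNet strategies apply the value.

```csharp
using System;
using System.Linq;

// Hypothetical helper, not part of AiDotNet: standard temperature-scaled softmax
// as used in Hinton-style knowledge distillation.
public static class TemperatureSoftmax
{
    // Returns softmax(logits / temperature); higher temperatures flatten the distribution.
    public static double[] Compute(double[] logits, double temperature)
    {
        if (temperature <= 0)
            throw new ArgumentOutOfRangeException(nameof(temperature), "Temperature must be positive.");

        // Scale by 1/T, then subtract the max scaled logit for numerical stability.
        double[] scaled = logits.Select(z => z / temperature).ToArray();
        double max = scaled.Max();
        double[] exps = scaled.Select(z => Math.Exp(z - max)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }
}
```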

Remarks

Example:

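// The array type is written out explicitly: an implicitly typed new[] cannot infer
// a common element type from tuples holding different concrete strategy classes.
var hybrid = new HybridDistillationStrategy<double>(
    new (IDistillationStrategy<double> Strategy, double Weight)[]
    {
        (new DistillationLoss<double>(3.0, 0.3), 0.4),          // 40% response-based
        (new AttentionDistillationStrategy<double>(...), 0.3),  // 30% attention-based (constructor arguments elided)
        (new FeatureDistillationStrategy<double>(...), 0.3)     // 30% feature-based (constructor arguments elided)
    }
);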

Exceptions

ArgumentException

Thrown if weights don't sum to approximately 1.0 or strategies is empty.
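The documented checks (non-empty array, weights summing to approximately 1.0) might look like the sketch below. This is an assumption, not the AiDotNet source: the page does not state the tolerance, so AssumedTolerance is a hypothetical value, and object stands in for IDistillationStrategy<T> to keep the example self-contained.

```csharp
using System;

// Hypothetical sketch of the documented constructor checks; not the AiDotNet source.
public static class HybridWeightValidation
{
    // Assumed tolerance: the docs only say "approximately 1.0".
    public const double AssumedTolerance = 1e-6;

    // 'object' stands in for IDistillationStrategy<T>.
    public static void Validate((object Strategy, double Weight)[] strategies)
    {
        if (strategies == null || strategies.Length == 0)
            throw new ArgumentException("At least one strategy is required.", nameof(strategies));

        double sum = 0.0;
        foreach (var (_, weight) in strategies)
            sum += weight;

        // Compare with a tolerance because e.g. 0.4 + 0.3 + 0.3 may not be exactly 1.0 in floating point.
        if (Math.Abs(sum - 1.0) > AssumedTolerance)
            throw new ArgumentException($"Strategy weights must sum to 1.0 (got {sum}).", nameof(strategies));
    }
}
```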

Methods

ComputeGradient(Matrix<T>, Matrix<T>, Matrix<T>?)

Computes the combined gradient from all strategies.

public override Matrix<T> ComputeGradient(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)

Parameters

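studentBatchOutput Matrix<T>

The student model's outputs for the batch.

teacherBatchOutput Matrix<T>

The teacher model's outputs for the same batch.

trueLabelsBatch Matrix<T>?

Optional ground-truth labels for the batch (defaults to null).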

Returns

Matrix<T>
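If the hybrid loss is a weighted sum of the component losses, linearity of differentiation means the combined gradient is the same weighted sum of the per-strategy gradients. The sketch below shows that under the stated assumption; it is not the AiDotNet implementation, HybridGradientSketch is a hypothetical name, and double[,] stands in for Matrix<T>.

```csharp
using System;

// Hypothetical sketch (not the AiDotNet source): weighted sum of per-strategy gradients.
// double[,] stands in for Matrix<T>.
public static class HybridGradientSketch
{
    public static double[,] Combine((double[,] Gradient, double Weight)[] parts)
    {
        if (parts.Length == 0)
            throw new ArgumentException("At least one gradient is required.", nameof(parts));

        int rows = parts[0].Gradient.GetLength(0);
        int cols = parts[0].Gradient.GetLength(1);
        var result = new double[rows, cols];

        foreach (var (gradient, weight) in parts)
        {
            // Every strategy must produce a gradient shaped like the student output.
            if (gradient.GetLength(0) != rows || gradient.GetLength(1) != cols)
                throw new ArgumentException("All gradients must have the same shape.", nameof(parts));

            for (int r = 0; r < rows; r++)
                for (int c = 0; c < cols; c++)
                    result[r, c] += weight * gradient[r, c];
        }
        return result;
    }
}
```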

ComputeLoss(Matrix<T>, Matrix<T>, Matrix<T>?)

Computes the combined loss from all strategies.

public override T ComputeLoss(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)

Parameters

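studentBatchOutput Matrix<T>

The student model's outputs for the batch.

teacherBatchOutput Matrix<T>

The teacher model's outputs for the same batch.

trueLabelsBatch Matrix<T>?

Optional ground-truth labels for the batch (defaults to null).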

Returns

T
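A toy model of the combination, assuming the weighted sum implied by the class remarks: each component strategy computes its own loss and the hybrid scales each by its weight before adding them. IToyStrategy, ToyMseStrategy, ToyMaeStrategy and ToyHybridStrategy are hypothetical stand-ins rather than AiDotNet types, and a single double[] per batch replaces Matrix<T> for brevity.

```csharp
using System;
using System.Linq;

// Hypothetical stand-in for IDistillationStrategy<T>.ComputeLoss (not an AiDotNet type).
public interface IToyStrategy
{
    double ComputeLoss(double[] student, double[] teacher);
}

// Mean squared error between student and teacher outputs.
public sealed class ToyMseStrategy : IToyStrategy
{
    public double ComputeLoss(double[] s, double[] t) =>
        s.Zip(t, (a, b) => (a - b) * (a - b)).Average();
}

// Mean absolute error between student and teacher outputs.
public sealed class ToyMaeStrategy : IToyStrategy
{
    public double ComputeLoss(double[] s, double[] t) =>
        s.Zip(t, (a, b) => Math.Abs(a - b)).Average();
}

// Combines component losses as a weighted sum.
public sealed class ToyHybridStrategy : IToyStrategy
{
    private readonly (IToyStrategy Strategy, double Weight)[] _strategies;

    public ToyHybridStrategy((IToyStrategy Strategy, double Weight)[] strategies) =>
        _strategies = strategies;

    public double ComputeLoss(double[] s, double[] t) =>
        _strategies.Sum(p => p.Weight * p.Strategy.ComputeLoss(s, t));
}
```

With student outputs {1, 2} and teacher outputs {0, 0}, the MSE is 2.5 and the MAE is 1.5, so an equal 0.5/0.5 weighting gives a combined loss of 2.0.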

GetStrategies()

Gets the individual strategies and their weights.

public (IDistillationStrategy<T> Strategy, double Weight)[] GetStrategies()

Returns

(IDistillationStrategy<T> Strategy, double Weight)[]
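Because the constructor rejects weights that do not sum to approximately 1.0, one way to retune the mix returned by GetStrategies() is to adjust the weights and rescale them before building a new instance. StrategyWeights.Normalize below is a hypothetical helper, not part of AiDotNet; it is generic in the strategy type, so it should accept the (IDistillationStrategy<T> Strategy, double Weight)[] array as well as the string placeholders used in the check.

```csharp
using System;
using System.Linq;

// Hypothetical helper (not part of AiDotNet): rescales weights so they sum to 1.0.
public static class StrategyWeights
{
    public static (TStrategy Strategy, double Weight)[] Normalize<TStrategy>(
        (TStrategy Strategy, double Weight)[] strategies)
    {
        double total = strategies.Sum(p => p.Weight);
        if (total <= 0)
            throw new ArgumentException("Weights must have a positive sum.", nameof(strategies));

        // Keep each strategy, dividing its weight by the total.
        return strategies
            .Select(p => (Strategy: p.Strategy, Weight: p.Weight / total))
            .ToArray();
    }
}
```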