Class HybridDistillationStrategy<T>
- Namespace
- AiDotNet.KnowledgeDistillation.Strategies
- Assembly
- AiDotNet.dll
Hybrid distillation strategy that combines multiple distillation strategies with configurable weights.
public class HybridDistillationStrategy<T> : DistillationStrategyBase<T>, IDistillationStrategy<T>
Type Parameters
T
The numeric type for calculations (e.g., double, float).
- Inheritance
- DistillationStrategyBase<T>
- HybridDistillationStrategy<T>
- Implements
- IDistillationStrategy<T>
Remarks
For Production Use: This strategy lets you combine multiple distillation approaches (response-based, feature-based, attention-based, etc.) in a single training run. Each strategy contributes to the total loss in proportion to its configured weight.
Example Use Case: For transformer distillation, combine:
- 40% response-based (output matching)
- 30% attention-based (attention-pattern matching)
- 30% feature-based (intermediate-layer matching)
This transfers knowledge at several levels of the network: final outputs, attention patterns, and intermediate representations.
Benefits:
- Uses several knowledge-transfer mechanisms at the same time.
- Weights can be tuned based on validation performance.
- Can be more robust than relying on a single distillation strategy.
- Multi-level distillation of this kind is used in compact models such as TinyBERT and MobileBERT.
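The weighted combination described above can be sketched in plain C#. This is a minimal sketch, not the AiDotNet implementation: it assumes the total loss is the weighted sum of the individual strategy losses, and it replaces IDistillationStrategy<T> and Matrix<T> with a hypothetical ISimpleDistillationLoss interface over double[] arrays, with two toy strategies (MseLoss and MaeLoss) standing in for real response, attention, or feature strategies.

```csharp
using System;
using System.Linq;

// Hypothetical stand-in for a single distillation strategy. It only produces a
// scalar loss for one student/teacher output pair; it is NOT AiDotNet's
// IDistillationStrategy<T> interface.
public interface ISimpleDistillationLoss
{
    double ComputeLoss(double[] student, double[] teacher);
}

// Toy strategy 1: mean squared error between student and teacher outputs.
public sealed class MseLoss : ISimpleDistillationLoss
{
    public double ComputeLoss(double[] student, double[] teacher) =>
        student.Zip(teacher, (s, t) => (s - t) * (s - t)).Average();
}

// Toy strategy 2: mean absolute error, so the hybrid has two different terms.
public sealed class MaeLoss : ISimpleDistillationLoss
{
    public double ComputeLoss(double[] student, double[] teacher) =>
        student.Zip(teacher, (s, t) => Math.Abs(s - t)).Average();
}

// Hypothetical hybrid: total = w1 * L1 + w2 * L2 + ... over all (strategy, weight) pairs.
public sealed class WeightedHybridLoss
{
    private readonly (ISimpleDistillationLoss Strategy, double Weight)[] _parts;

    public WeightedHybridLoss((ISimpleDistillationLoss Strategy, double Weight)[] parts) =>
        _parts = parts;

    public double ComputeLoss(double[] student, double[] teacher) =>
        _parts.Sum(p => p.Weight * p.Strategy.ComputeLoss(student, teacher));
}
```

With weights 0.4 (MSE) and 0.6 (MAE), student outputs {1, 2} and teacher outputs {0, 0}, the MSE term is 2.5 and the MAE term is 1.5, so the combined loss is 0.4 × 2.5 + 0.6 × 1.5 = 1.9.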
Constructors
HybridDistillationStrategy((IDistillationStrategy<T> Strategy, double Weight)[], double, double)
Initializes a new instance of the HybridDistillationStrategy class.
public HybridDistillationStrategy((IDistillationStrategy<T> Strategy, double Weight)[] strategies, double temperature = 3, double alpha = 0.3)
Parameters
strategies (IDistillationStrategy<T> Strategy, double Weight)[]
Array of (strategy, weight) tuples. Weights must sum to approximately 1.0.
temperature double
Temperature for strategies that don't specify their own. Defaults to 3.
alpha double
Alpha for strategies that don't specify their own. Defaults to 0.3.
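The temperature parameter follows the usual knowledge-distillation idea of softening probability distributions before they are compared. Assuming the standard convention (logits divided by T before the softmax), which this page does not state explicitly, the effect of the default temperature of 3 can be sketched as follows. TemperatureSoftmax is a hypothetical helper for illustration, not part of AiDotNet.

```csharp
using System;
using System.Linq;

public static class TemperatureSoftmax
{
    // Computes softmax(z / T). T = 1 is the ordinary softmax; T > 1 flattens
    // the distribution so lower-ranked classes receive more probability mass.
    public static double[] Softmax(double[] logits, double temperature)
    {
        if (temperature <= 0)
            throw new ArgumentOutOfRangeException(nameof(temperature), "Temperature must be positive.");

        double[] scaled = logits.Select(z => z / temperature).ToArray();
        double max = scaled.Max(); // subtract the max for numerical stability
        double[] exps = scaled.Select(z => Math.Exp(z - max)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }
}
```

For logits {2, 1, 0}, the largest probability drops from about 0.67 at T = 1 to about 0.45 at T = 3, so the softened targets carry more information about how the teacher ranks the non-top classes.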
Remarks
Example:
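    // An explicitly typed tuple array is needed: new[] { ... } cannot infer a
    // common element type from tuples holding different strategy classes.
    var hybrid = new HybridDistillationStrategy<double>(
        new (IDistillationStrategy<double> Strategy, double Weight)[]
        {
            (new DistillationLoss<double>(3.0, 0.3), 0.4),                // 40% response
            (new AttentionDistillationStrategy<double>(/* ... */), 0.3),  // 30% attention
            (new FeatureDistillationStrategy<double>(/* ... */), 0.3)     // 30% features
        });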
Exceptions
- ArgumentException
Thrown if weights don't sum to approximately 1.0 or strategies is empty.
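The documented exception conditions can be illustrated with a small validation sketch. The page does not say what tolerance "approximately 1.0" means, so this sketch assumes 1e-6; HybridWeightValidation is a hypothetical helper name, and the generic TStrategy parameter stands in for IDistillationStrategy<T>.

```csharp
using System;
using System.Linq;

public static class HybridWeightValidation
{
    // Assumed tolerance; the actual value used by AiDotNet is not documented here.
    public const double Tolerance = 1e-6;

    // Throws ArgumentException if the array is empty or the weights
    // do not sum to approximately 1.0.
    public static void Validate<TStrategy>((TStrategy Strategy, double Weight)[] strategies)
    {
        if (strategies == null)
            throw new ArgumentNullException(nameof(strategies));
        if (strategies.Length == 0)
            throw new ArgumentException("At least one strategy is required.", nameof(strategies));

        double total = strategies.Sum(s => s.Weight);
        if (Math.Abs(total - 1.0) > Tolerance)
            throw new ArgumentException($"Strategy weights must sum to 1.0 (got {total}).", nameof(strategies));
    }
}
```

With weights 0.4, 0.3 and 0.3 the check passes; with 0.5 and 0.3 (sum 0.8), or with an empty array, it throws ArgumentException.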
Methods
ComputeGradient(Matrix<T>, Matrix<T>, Matrix<T>?)
Computes combined gradient from all strategies.
public override Matrix<T> ComputeGradient(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)
Parameters
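studentBatchOutput Matrix<T>
The student model's outputs for the batch.
teacherBatchOutput Matrix<T>
The teacher model's outputs for the same batch.
trueLabelsBatch Matrix<T>
Optional ground-truth labels for the batch. Defaults to null.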
Returns
- Matrix<T>
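ComputeGradient is documented only as combining the gradients of all strategies. If the combined loss is the weighted sum L = w1·L1 + w2·L2 + ..., linearity of differentiation gives ∇L = w1·∇L1 + w2·∇L2 + ..., so the combined gradient is the same weighted sum of the per-strategy gradients. The sketch below assumes exactly that, uses double[,] in place of Matrix<T>, and introduces a hypothetical HybridGradient helper; it is not the AiDotNet implementation.

```csharp
using System;
using System.Collections.Generic;

public static class HybridGradient
{
    // Weighted element-wise sum of per-strategy gradients: G = sum_i w_i * G_i.
    // All gradients must share the shape of the student output batch.
    public static double[,] Combine(IReadOnlyList<(double[,] Gradient, double Weight)> parts)
    {
        if (parts.Count == 0)
            throw new ArgumentException("No gradients to combine.", nameof(parts));

        int rows = parts[0].Gradient.GetLength(0);
        int cols = parts[0].Gradient.GetLength(1);
        var result = new double[rows, cols];

        foreach (var (gradient, weight) in parts)
        {
            if (gradient.GetLength(0) != rows || gradient.GetLength(1) != cols)
                throw new ArgumentException("All gradients must have the same shape.", nameof(parts));

            for (int r = 0; r < rows; r++)
                for (int c = 0; c < cols; c++)
                    result[r, c] += weight * gradient[r, c];
        }
        return result;
    }
}
```

Combining gradients {1, 2} and {3, -1} with weights 0.5 and 0.5 yields {2, 0.5}.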
ComputeLoss(Matrix<T>, Matrix<T>, Matrix<T>?)
Computes combined loss from all strategies.
public override T ComputeLoss(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)
Parameters
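studentBatchOutput Matrix<T>
The student model's outputs for the batch.
teacherBatchOutput Matrix<T>
The teacher model's outputs for the same batch.
trueLabelsBatch Matrix<T>
Optional ground-truth labels for the batch. Defaults to null.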
Returns
- T
GetStrategies()
Gets the individual strategies and their weights.
public (IDistillationStrategy<T> Strategy, double Weight)[] GetStrategies()
Returns
- (IDistillationStrategy<T> Strategy, double Weight)[]