Class AdaptiveDistillationStrategyBase<T>
- Namespace: AiDotNet.KnowledgeDistillation.Strategies
- Assembly: AiDotNet.dll
Abstract base class for adaptive distillation strategies with performance tracking.
public abstract class AdaptiveDistillationStrategyBase<T> : DistillationStrategyBase<T>, IDistillationStrategy<T>, IAdaptiveDistillationStrategy<T>
Type Parameters
- T: The numeric type used for calculations (e.g., double, float).
- Inheritance: DistillationStrategyBase<T> → AdaptiveDistillationStrategyBase<T>
- Implements: IDistillationStrategy<T>, IAdaptiveDistillationStrategy<T>
Remarks
For Beginners: This base class provides functionality shared by all adaptive strategies, including per-sample performance tracking with an exponential moving average (EMA) and management of the allowed temperature range.
For Implementers: Derive from this class and implement ComputeAdaptiveTemperature(Vector<T>, Vector<T>) to define your specific adaptation logic.
Shared Features:
- Exponential moving average (EMA) for performance tracking
- Temperature range validation and enforcement
- Performance history management
- Helper methods for confidence, entropy, and accuracy calculations
Constructors
AdaptiveDistillationStrategyBase(double, double, double, double, double)
Initializes a new instance of the AdaptiveDistillationStrategyBase class.
protected AdaptiveDistillationStrategyBase(double baseTemperature = 3, double alpha = 0.3, double minTemperature = 1, double maxTemperature = 5, double adaptationRate = 0.1)
Parameters
- baseTemperature (double): Base temperature for distillation (default: 3.0).
- alpha (double): Balance between the hard (true-label) loss and the soft (teacher) loss (default: 0.3).
- minTemperature (double): Minimum temperature allowed during adaptation (default: 1.0).
- maxTemperature (double): Maximum temperature allowed during adaptation (default: 5.0).
- adaptationRate (double): Rate of the exponential moving average used for performance tracking (default: 0.1).
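The remarks mention temperature range validation, so a constructor with these parameters plausibly rejects inconsistent values. The sketch below shows one possible set of checks. The class name `AdaptiveSettingsSketch` and every rule in it (positive temperatures, `minTemperature <= maxTemperature`, `alpha` in [0, 1], `adaptationRate` in (0, 1]) are assumptions, not the library's documented behavior.

```csharp
using System;

// Hypothetical standalone sketch of constructor-time validation.
// The specific rules are assumptions made for illustration.
public sealed class AdaptiveSettingsSketch
{
    public double BaseTemperature { get; }
    public double Alpha { get; }
    public double MinTemperature { get; }
    public double MaxTemperature { get; }
    public double AdaptationRate { get; }

    public AdaptiveSettingsSketch(
        double baseTemperature = 3.0,
        double alpha = 0.3,
        double minTemperature = 1.0,
        double maxTemperature = 5.0,
        double adaptationRate = 0.1)
    {
        // Assumed: temperatures must be positive and the range non-empty.
        if (minTemperature <= 0)
            throw new ArgumentOutOfRangeException(nameof(minTemperature), "Must be positive.");
        if (baseTemperature <= 0)
            throw new ArgumentOutOfRangeException(nameof(baseTemperature), "Must be positive.");
        if (maxTemperature < minTemperature)
            throw new ArgumentException("maxTemperature must be >= minTemperature.");

        // Assumed: alpha is a mixing weight in [0, 1].
        if (alpha < 0 || alpha > 1)
            throw new ArgumentOutOfRangeException(nameof(alpha), "Must be in [0, 1].");

        // Assumed: an EMA rate outside (0, 1] would not smooth meaningfully.
        if (adaptationRate <= 0 || adaptationRate > 1)
            throw new ArgumentOutOfRangeException(nameof(adaptationRate), "Must be in (0, 1].");

        BaseTemperature = baseTemperature;
        Alpha = alpha;
        MinTemperature = minTemperature;
        MaxTemperature = maxTemperature;
        AdaptationRate = adaptationRate;
    }
}
```

With these rules, `new AdaptiveSettingsSketch(minTemperature: 4.0, maxTemperature: 2.0)` throws an `ArgumentException`, while the defaults pass.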
Properties
AdaptationRate
Gets the adaptation rate used by the exponential moving average that tracks per-sample performance.
public double AdaptationRate { get; }
Property Value
- double
MaxTemperature
Gets the maximum temperature for adaptation.
public double MaxTemperature { get; }
Property Value
- double
MinTemperature
Gets the minimum temperature for adaptation.
public double MinTemperature { get; }
Property Value
- double
Methods
ArgMax(Vector<T>)
Finds the index of the maximum value in a vector.
protected int ArgMax(Vector<T> vector)
Parameters
- vector (Vector<T>): The vector to search.
Returns
- int: The index of the maximum value.
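A minimal ArgMax sketch, assuming a plain `double[]` in place of `Vector<T>` and first-index tie-breaking (this page does not say how ties are resolved). `ArgMaxSketch` is a hypothetical name.

```csharp
using System;

public static class ArgMaxSketch
{
    // Returns the index of the largest element.
    public static int ArgMax(double[] values)
    {
        if (values == null || values.Length == 0)
            throw new ArgumentException("Vector must be non-empty.", nameof(values));

        int best = 0;
        for (int i = 1; i < values.Length; i++)
        {
            // Strict '>' keeps the first index on ties (an assumption).
            if (values[i] > values[best])
                best = i;
        }
        return best;
    }
}
```

For example, `ArgMaxSketch.ArgMax(new[] { 0.1, 0.7, 0.2 })` returns 1.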
ClampTemperature(double)
Clamps a value to the temperature range [MinTemperature, MaxTemperature].
protected double ClampTemperature(double temperature)
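Parameters
- temperature (double): The candidate temperature.
Returns
- double: The temperature clamped to [MinTemperature, MaxTemperature].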
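A sketch of the clamping rule, with the bounds passed explicitly instead of read from `MinTemperature` and `MaxTemperature`. `ClampSketch` is a hypothetical name.

```csharp
using System;

public static class ClampSketch
{
    // Restricts a temperature to the closed interval [min, max].
    public static double ClampTemperature(double temperature, double min, double max)
    {
        // Math.Max/Math.Min instead of Math.Clamp so this also compiles
        // on .NET Framework targets, where Math.Clamp is unavailable.
        return Math.Max(min, Math.Min(max, temperature));
    }
}
```

With the default range [1.0, 5.0], a proposed temperature of 7.0 becomes 5.0 and 0.5 becomes 1.0; values already inside the range pass through unchanged.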
ComputeAdaptiveTemperature(Vector<T>, Vector<T>)
Computes the adaptive temperature for a specific sample.
public abstract double ComputeAdaptiveTemperature(Vector<T> studentOutput, Vector<T> teacherOutput)
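Parameters
- studentOutput (Vector<T>): The student's output for one sample.
- teacherOutput (Vector<T>): The teacher's output for the same sample.
Returns
- double: The temperature to use for this sample.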
Remarks
For Implementers: Override this to define strategy-specific temperature adaptation.
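A sketch of the pattern this abstract member sets up, using a simplified stand-in for the base class (`double[]` instead of `Vector<T>`, and only the members this example needs). Both `AdaptiveStrategySketch` and the confidence-scaled rule in `ConfidenceScaledStrategySketch` are hypothetical; a real strategy would derive from `AdaptiveDistillationStrategyBase<T>`.

```csharp
using System;
using System.Linq;

// Simplified stand-in for the abstract base class.
public abstract class AdaptiveStrategySketch
{
    public double MinTemperature { get; }
    public double MaxTemperature { get; }

    protected AdaptiveStrategySketch(double minTemperature, double maxTemperature)
    {
        MinTemperature = minTemperature;
        MaxTemperature = maxTemperature;
    }

    protected double ClampTemperature(double t) =>
        Math.Max(MinTemperature, Math.Min(MaxTemperature, t));

    public abstract double ComputeAdaptiveTemperature(double[] studentOutput, double[] teacherOutput);
}

// Hypothetical rule: the less confident the student (lower max probability),
// the higher the temperature, so uncertain samples get softer teacher targets.
public sealed class ConfidenceScaledStrategySketch : AdaptiveStrategySketch
{
    public ConfidenceScaledStrategySketch(double minTemperature = 1.0, double maxTemperature = 5.0)
        : base(minTemperature, maxTemperature) { }

    public override double ComputeAdaptiveTemperature(double[] studentOutput, double[] teacherOutput)
    {
        double confidence = studentOutput.Max(); // assumes probabilities in [0, 1]
        double t = MinTemperature + (MaxTemperature - MinTemperature) * (1.0 - confidence);
        return ClampTemperature(t);
    }
}
```

With the default range, a fully confident student output such as `{ 1.0, 0.0 }` maps to 1.0, and a uniform output `{ 0.5, 0.5 }` maps to 3.0.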
ComputeEntropy(Vector<T>)
Computes the entropy of a probability distribution.
protected double ComputeEntropy(Vector<T> probabilities)
Parameters
- probabilities (Vector<T>): A probability distribution.
Returns
- double: The entropy of the distribution.
Remarks
Entropy measures uncertainty: the higher the entropy, the less certain the distribution, which typically indicates a harder sample.
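A sketch of Shannon entropy, H(p) = -Σ p_i ln p_i, assuming the natural logarithm and treating 0 · ln 0 as 0. This page does not state which log base the library uses. `EntropySketch` is a hypothetical name.

```csharp
using System;

public static class EntropySketch
{
    // Shannon entropy in nats. Entries <= 0 are skipped, which treats
    // 0 * ln(0) as 0 (the standard convention).
    public static double ComputeEntropy(double[] probabilities)
    {
        double entropy = 0.0;
        foreach (double p in probabilities)
        {
            if (p > 0)
                entropy -= p * Math.Log(p);
        }
        return entropy;
    }
}
```

A one-hot distribution has entropy 0 (no uncertainty), while a uniform distribution over two classes has entropy ln 2 ≈ 0.693, the maximum for two classes.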
ComputeGradient(Matrix<T>, Matrix<T>, Matrix<T>?)
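Computes the gradient of the distillation loss with respect to the student outputs, using an adaptive temperature.
public override Matrix<T> ComputeGradient(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)
Parameters
- studentBatchOutput (Matrix<T>): The student's outputs for the batch.
- teacherBatchOutput (Matrix<T>): The teacher's outputs for the batch.
- trueLabelsBatch (Matrix<T>?, optional): The true labels for the batch, if available (default: null).
Returns
- Matrix<T>: The gradient with respect to the student outputs.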
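A sketch of an adaptive-temperature distillation gradient under standard knowledge-distillation assumptions, none of which this page confirms: inputs are logits with one row per sample; each row gets its own temperature T from a caller-supplied function (`temperatureFor`, hypothetical); the soft term is T² · KL(teacher ∥ student), whose gradient is T · (q - p); `alpha` weights the hard-label term and `1 - alpha` the soft term; only the soft term is used when labels are absent; and the result is averaged over the batch. `AdaptiveGradientSketch` is a hypothetical name.

```csharp
#nullable enable
using System;

public static class AdaptiveGradientSketch
{
    // Temperature-scaled softmax, shifted by the max for numerical stability.
    static double[] Softmax(double[] logits, double temperature)
    {
        double max = double.NegativeInfinity;
        foreach (double z in logits) max = Math.Max(max, z / temperature);
        var result = new double[logits.Length];
        double sum = 0.0;
        for (int i = 0; i < logits.Length; i++)
        {
            result[i] = Math.Exp(logits[i] / temperature - max);
            sum += result[i];
        }
        for (int i = 0; i < result.Length; i++) result[i] /= sum;
        return result;
    }

    public static double[][] ComputeGradient(
        double[][] student, double[][] teacher, double[][]? labels,
        Func<double[], double[], double> temperatureFor, double alpha)
    {
        int n = student.Length;
        var grad = new double[n][];
        for (int i = 0; i < n; i++)
        {
            double t = temperatureFor(student[i], teacher[i]); // per-sample temperature
            double[] q = Softmax(student[i], t);               // softened student
            double[] p = Softmax(teacher[i], t);               // softened teacher
            double[] qHard = Softmax(student[i], 1.0);         // student at T = 1
            grad[i] = new double[student[i].Length];
            for (int k = 0; k < grad[i].Length; k++)
            {
                double soft = t * (q[k] - p[k]); // d/dz of T^2 * KL(p || q)
                double g = labels == null
                    ? soft
                    : alpha * (qHard[k] - labels[i][k]) + (1 - alpha) * soft;
                grad[i][k] = g / n; // batch mean
            }
        }
        return grad;
    }
}
```

When the student and teacher logits are identical and no labels are given, every gradient entry is zero, which is a quick sanity check for this formulation.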
ComputeLoss(Matrix<T>, Matrix<T>, Matrix<T>?)
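Computes the distillation loss for a batch, using an adaptive temperature.
public override T ComputeLoss(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)
Parameters
- studentBatchOutput (Matrix<T>): The student's outputs for the batch.
- teacherBatchOutput (Matrix<T>): The teacher's outputs for the batch.
- trueLabelsBatch (Matrix<T>?, optional): The true labels for the batch, if available (default: null).
Returns
- T: The distillation loss.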
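A sketch of an adaptive-temperature distillation loss, assuming a standard formulation that this page does not confirm: logits with one row per sample; a per-sample temperature T from a caller-supplied `temperatureFor` callback (hypothetical); a soft term T² · KL(teacher ∥ student); a hard term equal to cross-entropy at T = 1, weighted by `alpha` with `1 - alpha` on the soft term; soft term only when labels are absent; and a batch mean. `AdaptiveLossSketch` is a hypothetical name.

```csharp
#nullable enable
using System;

public static class AdaptiveLossSketch
{
    // Numerically stable log-softmax at a given temperature.
    static double[] LogSoftmax(double[] logits, double temperature)
    {
        double max = double.NegativeInfinity;
        foreach (double z in logits) max = Math.Max(max, z / temperature);
        double sum = 0.0;
        foreach (double z in logits) sum += Math.Exp(z / temperature - max);
        double logSum = max + Math.Log(sum);
        var result = new double[logits.Length];
        for (int i = 0; i < logits.Length; i++) result[i] = logits[i] / temperature - logSum;
        return result;
    }

    public static double ComputeLoss(
        double[][] student, double[][] teacher, double[][]? labels,
        Func<double[], double[], double> temperatureFor, double alpha)
    {
        double total = 0.0;
        for (int i = 0; i < student.Length; i++)
        {
            double t = temperatureFor(student[i], teacher[i]); // per-sample temperature
            double[] logQ = LogSoftmax(student[i], t);
            double[] logP = LogSoftmax(teacher[i], t);

            // KL(p || q) = sum_k p_k * (ln p_k - ln q_k)
            double kl = 0.0;
            for (int k = 0; k < logP.Length; k++)
                kl += Math.Exp(logP[k]) * (logP[k] - logQ[k]);
            double soft = t * t * kl; // T^2 keeps gradient scale comparable across temperatures

            if (labels == null) { total += soft; continue; }

            // Cross-entropy against the true labels at T = 1.
            double[] logQHard = LogSoftmax(student[i], 1.0);
            double ce = 0.0;
            for (int k = 0; k < logQHard.Length; k++)
                ce -= labels[i][k] * logQHard[k];

            total += alpha * ce + (1 - alpha) * soft;
        }
        return total / student.Length;
    }
}
```

The T² factor follows the usual rationale from the distillation literature: softening by T shrinks the soft-term gradients by roughly 1/T², so rescaling keeps the two terms balanced as the temperature adapts.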
ComputePerformance(Vector<T>, Vector<T>?)
Computes a performance metric for the student output.
protected virtual double ComputePerformance(Vector<T> studentOutput, Vector<T>? trueLabel)
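Parameters
- studentOutput (Vector<T>): The student's output for one sample.
- trueLabel (Vector<T>?): The true label, or null if unavailable.
Returns
- double: The performance metric for the sample.
Remarks
For Implementers: Override to define strategy-specific performance metrics.
Default: Returns the maximum confidence (highest probability) in studentOutput.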
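A sketch contrasting the documented default (maximum confidence) with one possible override. The label-aware metric, the probability the student assigns to the true class, is a hypothetical choice, and `PerformanceSketch` and its method names are assumptions; `double[]` stands in for `Vector<T>`.

```csharp
#nullable enable
using System;
using System.Linq;

public static class PerformanceSketch
{
    // Documented default: the student's highest probability.
    public static double DefaultPerformance(double[] studentOutput, double[]? trueLabel) =>
        studentOutput.Max();

    // Hypothetical override: with a label, score the probability mass the
    // student puts on the true class (dot product with a one-hot or soft
    // label); without one, fall back to the default.
    public static double LabelAwarePerformance(double[] studentOutput, double[]? trueLabel)
    {
        if (trueLabel == null)
            return DefaultPerformance(studentOutput, null);

        double score = 0.0;
        for (int k = 0; k < studentOutput.Length; k++)
            score += studentOutput[k] * trueLabel[k];
        return score;
    }
}
```

The difference matters for confidently wrong predictions: for `{ 0.2, 0.8 }` with true class 0, the default reports 0.8, while the label-aware version reports 0.2.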
GetMaxConfidence(Vector<T>)
Gets the maximum confidence (highest probability) from a probability distribution.
protected double GetMaxConfidence(Vector<T> probabilities)
Parameters
- probabilities (Vector<T>): A probability distribution.
Returns
- double: The highest probability in the distribution.
GetPerformance(int)
Gets the current performance metric for a sample.
public virtual double GetPerformance(int sampleIndex)
Parameters
- sampleIndex (int): Index of the sample.
Returns
- double: The tracked performance metric for the sample.
IsCorrect(Vector<T>, Vector<T>)
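Checks whether the student's prediction is correct.
protected bool IsCorrect(Vector<T> studentOutput, Vector<T> trueLabel)
Parameters
- studentOutput (Vector<T>): The student's output for one sample.
- trueLabel (Vector<T>): The true label for the sample.
Returns
- bool: true if the prediction is correct; otherwise, false.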
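A sketch of a correctness check, assuming `trueLabel` is one-hot and that "correct" means the student's highest-scoring class matches the label's highest entry. This page does not document the exact rule; `CorrectnessSketch` is a hypothetical name and `double[]` stands in for `Vector<T>`.

```csharp
using System;

public static class CorrectnessSketch
{
    // Index of the largest element (first index wins on ties).
    static int ArgMax(double[] v)
    {
        int best = 0;
        for (int i = 1; i < v.Length; i++)
            if (v[i] > v[best]) best = i;
        return best;
    }

    // Correct when the predicted class equals the labeled class.
    public static bool IsCorrect(double[] studentOutput, double[] trueLabel) =>
        ArgMax(studentOutput) == ArgMax(trueLabel);
}
```

A correctness flag like this is a natural input to the "accuracy calculations" mentioned in the class remarks, e.g. averaging it over samples.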
UpdatePerformance(int, Vector<T>, Vector<T>?)
Updates the performance metric for a specific sample using exponential moving average.
public virtual void UpdatePerformance(int sampleIndex, Vector<T> studentOutput, Vector<T>? trueLabel = null)
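Parameters
- sampleIndex (int): Index of the sample to update.
- studentOutput (Vector<T>): The student's output for the sample.
- trueLabel (Vector<T>?, optional): The true label, if available (default: null).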
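A sketch of per-sample EMA tracking covering both `UpdatePerformance` and `GetPerformance`. Assumptions: the update is perf ← (1 - rate) · perf + rate · current, with `AdaptationRate` as the weight on the newest value; the first observation seeds the average; unseen samples report 0.0. For brevity the sketch takes an already-computed metric instead of `studentOutput`/`trueLabel`. `PerformanceTrackerSketch` is a hypothetical name.

```csharp
using System;
using System.Collections.Generic;

public sealed class PerformanceTrackerSketch
{
    private readonly Dictionary<int, double> _performance = new Dictionary<int, double>();

    public double AdaptationRate { get; }

    public PerformanceTrackerSketch(double adaptationRate = 0.1) => AdaptationRate = adaptationRate;

    // Blends the new metric into the running average for this sample.
    public void UpdatePerformance(int sampleIndex, double current)
    {
        if (_performance.TryGetValue(sampleIndex, out double previous))
            _performance[sampleIndex] = (1 - AdaptationRate) * previous + AdaptationRate * current;
        else
            _performance[sampleIndex] = current; // first observation seeds the average
    }

    // Returns the tracked value, or 0.0 for samples never updated (assumed default).
    public double GetPerformance(int sampleIndex) =>
        _performance.TryGetValue(sampleIndex, out double value) ? value : 0.0;
}
```

With the default rate of 0.1, updating a sample with 0.5 and then 1.0 yields 0.9 · 0.5 + 0.1 · 1.0 = 0.55: a small rate makes the tracked value change slowly, which damps noise from any single batch.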