Class CurriculumDistillationStrategyBase<T>
- Namespace
- AiDotNet.KnowledgeDistillation.Strategies
- Assembly
- AiDotNet.dll
Abstract base class for curriculum distillation strategies with progressive difficulty adjustment.
public abstract class CurriculumDistillationStrategyBase<T> : DistillationStrategyBase<T>, IDistillationStrategy<T>, ICurriculumDistillationStrategy<T>
Type Parameters
- T: The numeric type for calculations (e.g., double, float).
- Inheritance
- object
- DistillationStrategyBase<T>
- CurriculumDistillationStrategyBase<T>
- Implements
- IDistillationStrategy<T>
- ICurriculumDistillationStrategy<T>
Remarks
For Beginners: This base class provides common functionality for curriculum learning, including progress tracking, sample difficulty management, and temperature progression.
For Implementers: Derive from this class and implement ComputeCurriculumTemperature() and ShouldIncludeSample(int) to define your specific curriculum progression logic.
Shared Features:
- Curriculum progress tracking (0.0 to 1.0)
- Sample difficulty scoring and management
- Temperature range validation
- Step/epoch-based progression
Constructors
CurriculumDistillationStrategyBase(double, double, double, double, int, Dictionary<int, double>?)
Initializes a new instance of the CurriculumDistillationStrategyBase class.
protected CurriculumDistillationStrategyBase(double baseTemperature = 3, double alpha = 0.3, double minTemperature = 2, double maxTemperature = 5, int totalSteps = 100, Dictionary<int, double>? sampleDifficulties = null)
Parameters
- baseTemperature (double): Base temperature for distillation (default: 3.0).
- alpha (double): Balance between hard and soft loss (default: 0.3).
- minTemperature (double): Minimum temperature for the curriculum (default: 2.0).
- maxTemperature (double): Maximum temperature for the curriculum (default: 5.0).
- totalSteps (int): Total number of training steps or epochs in the curriculum (default: 100).
- sampleDifficulties (Dictionary<int, double>?): Optional predefined difficulty scores, keyed by sample index (default: null).
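The constructor page does not state which argument combinations are rejected, although temperature range validation is listed among the shared features. The C# sketch below assumes a few plausible checks (positive temperatures, minTemperature not exceeding maxTemperature, alpha in [0, 1], positive totalSteps). `CurriculumSettingsSketch` is a hypothetical type, not part of AiDotNet, and the library's actual rules may differ.

```csharp
#nullable enable
using System;
using System.Collections.Generic;

// Hypothetical settings holder that mirrors the constructor parameters above.
// The validation rules below are assumptions, not AiDotNet's documented behavior.
public sealed class CurriculumSettingsSketch
{
    public double BaseTemperature { get; }
    public double Alpha { get; }
    public double MinTemperature { get; }
    public double MaxTemperature { get; }
    public int TotalSteps { get; }
    public Dictionary<int, double> SampleDifficulties { get; }

    public CurriculumSettingsSketch(
        double baseTemperature = 3.0,
        double alpha = 0.3,
        double minTemperature = 2.0,
        double maxTemperature = 5.0,
        int totalSteps = 100,
        Dictionary<int, double>? sampleDifficulties = null)
    {
        // Assumed rule: temperatures must be strictly positive.
        if (baseTemperature <= 0)
            throw new ArgumentOutOfRangeException(nameof(baseTemperature), "Temperature must be positive.");
        if (minTemperature <= 0)
            throw new ArgumentOutOfRangeException(nameof(minTemperature), "Temperature must be positive.");
        // Assumed rule: the curriculum range must not be inverted.
        if (minTemperature > maxTemperature)
            throw new ArgumentException("minTemperature must not exceed maxTemperature.", nameof(minTemperature));
        // Assumed rule: alpha is a mixing weight in [0, 1].
        if (alpha < 0 || alpha > 1)
            throw new ArgumentOutOfRangeException(nameof(alpha), "Alpha must lie in [0, 1].");
        if (totalSteps <= 0)
            throw new ArgumentOutOfRangeException(nameof(totalSteps), "totalSteps must be positive.");

        BaseTemperature = baseTemperature;
        Alpha = alpha;
        MinTemperature = minTemperature;
        MaxTemperature = maxTemperature;
        TotalSteps = totalSteps;
        // Copy the caller's dictionary so later changes to it do not leak in.
        SampleDifficulties = sampleDifficulties is null
            ? new Dictionary<int, double>()
            : new Dictionary<int, double>(sampleDifficulties);
    }
}
```

Copying sampleDifficulties is a design choice in this sketch: it keeps later edits to the caller's dictionary from silently altering the curriculum.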
Properties
CurriculumProgress
Gets the current curriculum progress (0.0 to 1.0).
public double CurriculumProgress { get; }
Property Value
- double
MaxTemperature
Gets the maximum temperature for the curriculum.
public double MaxTemperature { get; }
Property Value
- double
MinTemperature
Gets the minimum temperature for the curriculum.
public double MinTemperature { get; }
Property Value
- double
TotalSteps
Gets the total number of steps in the curriculum.
public int TotalSteps { get; }
Property Value
- int
Methods
ClampTemperature(double)
Clamps a value to the temperature range [MinTemperature, MaxTemperature].
protected double ClampTemperature(double temperature)
Parameters
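- temperature (double): The temperature value to clamp.
Returns
- double: The value clamped to the range [MinTemperature, MaxTemperature].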
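ClampTemperature keeps a computed temperature inside the configured range. A minimal sketch of that contract using `Math.Clamp`, with the range passed explicitly because the hypothetical `TemperatureClampSketch` helper has no access to the real MinTemperature and MaxTemperature properties:

```csharp
using System;

// Hypothetical helper mirroring ClampTemperature's documented contract:
// clamp a value into [minTemperature, maxTemperature].
public static class TemperatureClampSketch
{
    public static double ClampTemperature(double temperature, double minTemperature, double maxTemperature)
    {
        // Math.Clamp throws ArgumentException when min > max,
        // which also surfaces an inverted temperature range.
        return Math.Clamp(temperature, minTemperature, maxTemperature);
    }
}
```

With the default range (2.0 to 5.0), a requested temperature of 7.0 becomes 5.0 and 1.0 becomes 2.0.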
ComputeCurriculumTemperature()
Computes the curriculum-adjusted temperature based on current progress.
public abstract double ComputeCurriculumTemperature()
Returns
- double: The curriculum-adjusted temperature.
Remarks
For Implementers: Override to define strategy-specific temperature progression.
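The progression itself is left to derived classes. As an illustration only, assume an easy-to-hard curriculum that starts at MaxTemperature (softer teacher distributions) and anneals linearly to MinTemperature as progress goes from 0.0 to 1.0, that is T(p) = Tmax - p * (Tmax - Tmin). The helper name below is hypothetical, and other strategies may use a different direction or shape.

```csharp
using System;

// Hypothetical linear schedule: T(p) = maxTemperature - p * (maxTemperature - minTemperature).
public static class LinearTemperatureScheduleSketch
{
    public static double ComputeCurriculumTemperature(double progress, double minTemperature, double maxTemperature)
    {
        // Guard against progress values outside [0, 1].
        double p = Math.Clamp(progress, 0.0, 1.0);
        double temperature = maxTemperature - p * (maxTemperature - minTemperature);
        // Clamp to absorb floating-point drift, in the spirit of ClampTemperature.
        return Math.Clamp(temperature, minTemperature, maxTemperature);
    }
}
```

With the defaults (2.0 to 5.0), progress 0.0 gives 5.0, progress 0.5 gives 3.5, and progress 1.0 gives 2.0.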
ComputeGradient(Matrix<T>, Matrix<T>, Matrix<T>?)
Computes gradient with curriculum-adjusted temperature.
public override Matrix<T> ComputeGradient(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)
Parameters
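- studentBatchOutput (Matrix<T>): The student model's outputs for the batch.
- teacherBatchOutput (Matrix<T>): The teacher model's outputs for the batch.
- trueLabelsBatch (Matrix<T>?): Optional ground-truth labels for the batch (default: null).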
Returns
- Matrix<T>
ComputeLoss(Matrix<T>, Matrix<T>, Matrix<T>?)
Computes distillation loss with curriculum-adjusted temperature.
public override T ComputeLoss(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)
Parameters
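- studentBatchOutput (Matrix<T>): The student model's outputs for the batch.
- teacherBatchOutput (Matrix<T>): The teacher model's outputs for the batch.
- trueLabelsBatch (Matrix<T>?): Optional ground-truth labels for the batch (default: null).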
Returns
- T
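The ComputeGradient and ComputeLoss summaries say only that they use the curriculum-adjusted temperature. Assuming they follow standard (Hinton-style) soft-target distillation, the loss at temperature T is the batch mean of T^2 * KL(p_teacher || p_student) with p = softmax(logits / T), and its gradient with respect to the student logits is T * (p_student - p_teacher) divided by the batch size. The sketch below implements only that soft term on plain `double[,]` arrays. `SoftDistillationSketch` is hypothetical, and how AiDotNet combines this term with the hard-label loss through alpha is not shown.

```csharp
using System;

// Hypothetical sketch of standard soft-target distillation at a given curriculum
// temperature. NOT AiDotNet's ComputeLoss/ComputeGradient: the hard-label term and
// the alpha weighting are omitted. Rows are samples, columns are class logits.
public static class SoftDistillationSketch
{
    // Numerically stable softmax of logits[row, *] / temperature.
    private static double[] Softmax(double[,] logits, int row, double temperature)
    {
        int classes = logits.GetLength(1);
        var probs = new double[classes];
        double max = double.NegativeInfinity;
        for (int c = 0; c < classes; c++)
            max = Math.Max(max, logits[row, c] / temperature);
        double sum = 0.0;
        for (int c = 0; c < classes; c++)
        {
            probs[c] = Math.Exp(logits[row, c] / temperature - max);
            sum += probs[c];
        }
        for (int c = 0; c < classes; c++)
            probs[c] /= sum;
        return probs;
    }

    private static void CheckShapes(double[,] student, double[,] teacher)
    {
        if (student.GetLength(0) != teacher.GetLength(0) || student.GetLength(1) != teacher.GetLength(1))
            throw new ArgumentException("Student and teacher outputs must have the same shape.");
    }

    // Batch mean of T^2 * KL(p_teacher || p_student).
    public static double Loss(double[,] student, double[,] teacher, double temperature)
    {
        CheckShapes(student, teacher);
        int batch = student.GetLength(0);
        double total = 0.0;
        for (int r = 0; r < batch; r++)
        {
            double[] ps = Softmax(student, r, temperature);
            double[] pt = Softmax(teacher, r, temperature);
            for (int c = 0; c < ps.Length; c++)
                if (pt[c] > 0) total += pt[c] * Math.Log(pt[c] / ps[c]);
        }
        return temperature * temperature * total / batch;
    }

    // Gradient of Loss with respect to the student logits: T * (p_student - p_teacher) / batch.
    public static double[,] Gradient(double[,] student, double[,] teacher, double temperature)
    {
        CheckShapes(student, teacher);
        int batch = student.GetLength(0), classes = student.GetLength(1);
        var grad = new double[batch, classes];
        for (int r = 0; r < batch; r++)
        {
            double[] ps = Softmax(student, r, temperature);
            double[] pt = Softmax(teacher, r, temperature);
            for (int c = 0; c < classes; c++)
                grad[r, c] = temperature * (ps[c] - pt[c]) / batch;
        }
        return grad;
    }
}
```

Scaling the loss by T^2 keeps the gradient magnitude roughly comparable across temperatures, which matters when a curriculum changes T during training.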
GetSampleDifficulty(int)
Gets the difficulty score for a sample, if set.
public virtual double? GetSampleDifficulty(int sampleIndex)
Parameters
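- sampleIndex (int): The index of the sample.
Returns
- double?: The sample's difficulty score, or null if no score has been set.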
SetSampleDifficulty(int, double)
Sets the difficulty score for a specific sample.
public virtual void SetSampleDifficulty(int sampleIndex, double difficulty)
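Parameters
- sampleIndex (int): The index of the sample.
- difficulty (double): The difficulty score to assign.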
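GetSampleDifficulty returns null when no score has been set, which fits a dictionary lookup keyed by sample index. The sketch below assumes scores are normalized to [0, 1] and rejects anything else; both that range and the hypothetical `SampleDifficultyStoreSketch` type are illustrations, not AiDotNet's documented behavior.

```csharp
#nullable enable
using System;
using System.Collections.Generic;

// Hypothetical difficulty store mirroring GetSampleDifficulty/SetSampleDifficulty.
// Assumption: difficulty scores are normalized to [0, 1].
public class SampleDifficultyStoreSketch
{
    private readonly Dictionary<int, double> _difficulties;

    public SampleDifficultyStoreSketch(Dictionary<int, double>? initial = null)
    {
        _difficulties = initial is null
            ? new Dictionary<int, double>()
            : new Dictionary<int, double>(initial);
    }

    public virtual void SetSampleDifficulty(int sampleIndex, double difficulty)
    {
        if (sampleIndex < 0)
            throw new ArgumentOutOfRangeException(nameof(sampleIndex), "Sample index must be non-negative.");
        // NaN fails every comparison, so check it explicitly.
        if (double.IsNaN(difficulty) || difficulty < 0 || difficulty > 1)
            throw new ArgumentOutOfRangeException(nameof(difficulty), "Difficulty must lie in [0, 1].");
        // Overwrites any previous score for this sample.
        _difficulties[sampleIndex] = difficulty;
    }

    // Returns null when no score has been set, matching the double? return type.
    public virtual double? GetSampleDifficulty(int sampleIndex)
        => _difficulties.TryGetValue(sampleIndex, out double d) ? d : (double?)null;
}
```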
ShouldIncludeSample(int)
Determines whether a sample should be included in the current curriculum stage.
public abstract bool ShouldIncludeSample(int sampleIndex)
Parameters
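- sampleIndex (int): The index of the sample.
Returns
- bool: true if the sample should be included in the current curriculum stage; otherwise, false.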
Remarks
For Implementers: Override to define strategy-specific sample filtering logic.
Because this method is abstract, there is no inherited default; a strategy that applies no filtering returns true for every sample.
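One common curriculum rule admits a sample once training has progressed far enough for its difficulty. A sketch of ShouldIncludeSample under two stated assumptions: a sample is included when its difficulty is at most CurriculumProgress, and samples without a score are always included. `ThresholdCurriculumFilterSketch` is hypothetical.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical easy-to-hard filter: a scored sample is admitted once curriculum
// progress reaches its difficulty. Unscored samples are always included (assumption).
public sealed class ThresholdCurriculumFilterSketch
{
    private readonly IReadOnlyDictionary<int, double> _difficulties;

    // Expected to be kept in [0, 1] by the caller, like CurriculumProgress.
    public double CurriculumProgress { get; set; }

    public ThresholdCurriculumFilterSketch(IReadOnlyDictionary<int, double> difficulties)
        => _difficulties = difficulties ?? throw new ArgumentNullException(nameof(difficulties));

    public bool ShouldIncludeSample(int sampleIndex)
    {
        // No score recorded: do not filter this sample out.
        if (!_difficulties.TryGetValue(sampleIndex, out double difficulty))
            return true;
        return difficulty <= CurriculumProgress;
    }
}
```

With difficulties {0: 0.1, 1: 0.9} and progress 0.5, sample 0 is included, sample 1 is held back until progress reaches 0.9, and an unscored sample 2 is always included.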
UpdateProgress(int)
Updates the current curriculum progress.
public virtual void UpdateProgress(int step)
Parameters
- step (int): The current training step or epoch.
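UpdateProgress maps a step onto the 0.0 to 1.0 CurriculumProgress scale. A minimal sketch, assuming progress is step / TotalSteps clamped to [0, 1]; AiDotNet's exact formula, and whether steps are counted from 0 or 1, are not documented here. `CurriculumProgressSketch` is hypothetical.

```csharp
using System;

// Hypothetical progress tracker: progress = step / totalSteps, clamped to [0, 1].
public class CurriculumProgressSketch
{
    public int TotalSteps { get; }
    public double CurriculumProgress { get; private set; }

    public CurriculumProgressSketch(int totalSteps)
    {
        if (totalSteps <= 0)
            throw new ArgumentOutOfRangeException(nameof(totalSteps), "totalSteps must be positive.");
        TotalSteps = totalSteps;
    }

    public virtual void UpdateProgress(int step)
    {
        // Floating-point division; integer division would truncate to 0 before the end.
        CurriculumProgress = Math.Clamp((double)step / TotalSteps, 0.0, 1.0);
    }
}
```

With TotalSteps = 100, step 25 gives 0.25, and any step at or beyond 100 gives 1.0.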