Class CurriculumLearner<T, TInput, TOutput>
- Namespace: AiDotNet.CurriculumLearning
- Assembly: AiDotNet.dll
Main orchestrator for curriculum learning that coordinates difficulty estimation, scheduling, and model training.
public class CurriculumLearner<T, TInput, TOutput> : ICurriculumLearner<T, TInput, TOutput>
Type Parameters
T: The numeric type used for calculations.
TInput: The type of input features.
TOutput: The type of output labels.
- Inheritance: object → CurriculumLearner<T, TInput, TOutput>
- Implements: ICurriculumLearner<T, TInput, TOutput>
Remarks
For Beginners: Curriculum learning is a training strategy that presents samples from easy to hard, mimicking how humans learn. This class coordinates the entire process:
- Estimates difficulty of each training sample
- Sorts samples from easy to hard
- Trains the model on progressively harder samples
- Monitors progress and can stop early if the model stops improving
Key Components:
- DifficultyEstimator: Measures how hard each sample is
- Scheduler: Controls when to introduce harder samples
- Config: Settings for training (epochs, batch size, early stopping)
Example Usage:
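// Assumes `model`, `config`, `trainingData`, and `validationData` already exist.
var learner = new CurriculumLearner<double, Vector<double>, double>(
    model,   // IFullModel<double, Vector<double>, double> to train
    config,  // ICurriculumLearnerConfig<double>
    new LossBasedDifficultyEstimator<double, Vector<double>, double>());
// No scheduler is passed, so one is created from the config.
var result = learner.Train(trainingData, validationData);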
References:
- Bengio et al. "Curriculum Learning" (ICML 2009)
- Soviany et al. "Curriculum Learning: A Survey" (IJCV 2022)
Constructors
CurriculumLearner(IFullModel<T, TInput, TOutput>, ICurriculumLearnerConfig<T>, IDifficultyEstimator<T, TInput, TOutput>, ICurriculumScheduler<T>?)
Initializes a new instance of the CurriculumLearner<T, TInput, TOutput> class.
public CurriculumLearner(IFullModel<T, TInput, TOutput> baseModel, ICurriculumLearnerConfig<T> config, IDifficultyEstimator<T, TInput, TOutput> difficultyEstimator, ICurriculumScheduler<T>? scheduler = null)
Parameters
baseModel (IFullModel<T, TInput, TOutput>): The model to train.
config (ICurriculumLearnerConfig<T>): Configuration for curriculum learning.
difficultyEstimator (IDifficultyEstimator<T, TInput, TOutput>): Estimator for sample difficulties.
scheduler (ICurriculumScheduler<T>?): Optional custom scheduler. If null, a scheduler is created from the config.
Properties
BaseModel
Gets the underlying model being trained.
public IFullModel<T, TInput, TOutput> BaseModel { get; }
Property Value
- IFullModel<T, TInput, TOutput>
Config
Gets the configuration for the curriculum learner.
public ICurriculumLearnerConfig<T> Config { get; }
Property Value
- ICurriculumLearnerConfig<T>
CurrentEpoch
Gets the current epoch number.
public int CurrentEpoch { get; }
Property Value
- int
CurrentPhase
Gets the current training phase, a value in [0, 1] where 1 means all samples are available.
public T CurrentPhase { get; }
Property Value
- T
DifficultyEstimator
Gets the difficulty estimator used to rank samples.
public IDifficultyEstimator<T, TInput, TOutput> DifficultyEstimator { get; }
Property Value
- IDifficultyEstimator<T, TInput, TOutput>
IsTraining
Gets whether training is currently in progress.
public bool IsTraining { get; }
Property Value
- bool
Scheduler
Gets the curriculum scheduler that controls training progression.
public ICurriculumScheduler<T> Scheduler { get; }
Property Value
- ICurriculumScheduler<T>
Methods
AdvancePhase()
Advances the curriculum to the next phase.
public bool AdvancePhase()
Returns
- bool
true if the curriculum advanced; false if it was already at the final phase.
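To illustrate how phase advancement could behave, here is a minimal C# sketch of a linear schedule. It assumes a fixed number of phases (`numPhases`, hypothetical) and maps phase k (0-based) to a phase value of (k + 1) / numPhases, so the final phase reports 1. That mapping is an assumption, not necessarily what the built-in schedulers do; the local functions only mirror the names `CurrentPhase`, `AdvancePhase`, and `ResetCurriculum` and are not the library's implementation.

```csharp
using System;

// Hypothetical linear schedule: phase k (0-based) of numPhases exposes the
// fraction (k + 1) / numPhases of the data, so the final phase reports 1.0.
const int numPhases = 4;
int phaseIndex = 0;

// Analogue of CurrentPhase: a value in (0, 1], where 1 means all samples.
double CurrentPhase() => (phaseIndex + 1) / (double)numPhases;

// Analogue of AdvancePhase: returns false if already at the final phase.
bool AdvancePhase()
{
    if (phaseIndex >= numPhases - 1)
    {
        return false;
    }
    phaseIndex++;
    return true;
}

// Analogue of ResetCurriculum: return to the first phase.
void ResetCurriculum() => phaseIndex = 0;

while (AdvancePhase())
{
    Console.WriteLine($"phase value: {CurrentPhase()}");
}
Console.WriteLine(AdvancePhase());          // False (already at the final phase)
ResetCurriculum();
Console.WriteLine(CurrentPhase() == 0.25);  // True
```

Returning `false` instead of throwing at the last phase lets a caller drive the schedule with a simple `while (AdvancePhase())` loop.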
EstimateDifficulties(IDataset<T, TInput, TOutput>)
Estimates difficulty scores for all samples in a dataset.
public Vector<T> EstimateDifficulties(IDataset<T, TInput, TOutput> dataset)
Parameters
dataset (IDataset<T, TInput, TOutput>): The dataset to estimate difficulties for.
Returns
- Vector<T>
A vector of difficulty scores, one per sample (higher means harder).
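One common way to score difficulty is by the loss a reference model incurs on each sample. The C# sketch below assumes scalar inputs and targets, a `predict` delegate standing in for the model, and squared error as the loss. `EstimateDifficulties` here is a hypothetical stand-alone function; it is not claimed to match how `LossBasedDifficultyEstimator` works.

```csharp
using System;

// Hypothetical loss-based difficulty: score each sample by the squared error of a
// reference model's prediction, so samples the model fits poorly count as harder.
double[] EstimateDifficulties(Func<double, double> predict, double[] inputs, double[] targets)
{
    if (inputs.Length != targets.Length)
        throw new ArgumentException("inputs and targets must have the same length");

    var scores = new double[inputs.Length];
    for (int i = 0; i < inputs.Length; i++)
    {
        double error = predict(inputs[i]) - targets[i];
        scores[i] = error * error; // higher = harder
    }
    return scores;
}

// Toy data: a model predicting y = 2x fits the first two samples exactly.
double[] xs = { 1.0, 2.0, 3.0, 4.0 };
double[] ys = { 2.0, 4.0, 7.0, 4.0 };
double[] difficulty = EstimateDifficulties(x => 2.0 * x, xs, ys);
Console.WriteLine(string.Join(", ", difficulty)); // 0, 0, 1, 16
```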
GetCurrentCurriculumIndices(Vector<T>)
Gets the indices of samples available at the current curriculum phase.
public int[] GetCurrentCurriculumIndices(Vector<T> allDifficulties)
Parameters
allDifficulties (Vector<T>): Difficulty scores for all samples.
Returns
- int[]
Indices of the samples available for training at the current phase.
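The sketch below shows one plausible selection rule, assuming the phase value is the fraction of the easiest samples to expose: keep the easiest ceil(phase * n) samples, and always at least one. Unlike the documented method, which reads the phase from the learner's own state, this hypothetical version takes `phase` as an explicit parameter.

```csharp
using System;
using System.Linq;

// Hypothetical selection rule: at phase p in [0, 1], keep the easiest
// ceil(p * n) samples (at least one), so p = 1 returns every index.
int[] GetCurrentCurriculumIndices(double[] allDifficulties, double phase)
{
    int n = allDifficulties.Length;
    if (n == 0) return Array.Empty<int>();
    int count = Math.Clamp((int)Math.Ceiling(phase * n), 1, n);

    return Enumerable.Range(0, n)
        .OrderBy(i => allDifficulties[i]) // easy to hard; OrderBy is a stable sort
        .Take(count)
        .ToArray();
}

double[] scores = { 0.9, 0.1, 0.5, 0.3 };
Console.WriteLine(string.Join(", ", GetCurrentCurriculumIndices(scores, 0.5))); // 1, 3
Console.WriteLine(string.Join(", ", GetCurrentCurriculumIndices(scores, 1.0))); // 1, 3, 2, 0
```

Returning indices rather than copies of the samples lets the same dataset be reused across phases.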
GetPhaseHistory()
Gets the training history, with one result per curriculum phase.
public IReadOnlyList<CurriculumPhaseResult<T>> GetPhaseHistory()
Returns
- IReadOnlyList<CurriculumPhaseResult<T>>
List of results from each curriculum phase.
Load(string)
Loads the curriculum learner state.
public void Load(string path)
Parameters
path (string): Path to load the state from.
ResetCurriculum()
Resets the curriculum to the initial phase.
public void ResetCurriculum()
Save(string)
Saves the curriculum learner state.
public void Save(string path)
Parameters
path (string): Path to save the state to.
Train(IDataset<T, TInput, TOutput>)
Trains the model using curriculum learning.
public CurriculumLearningResult<T> Train(IDataset<T, TInput, TOutput> trainingData)
Parameters
trainingData (IDataset<T, TInput, TOutput>): The full training dataset.
Returns
- CurriculumLearningResult<T>
Result containing training metrics and curriculum progression.
Remarks
This method will:
1. Estimate difficulty for all samples.
2. Sort samples by difficulty (easy to hard).
3. Train in phases, gradually introducing harder samples.
4. Track curriculum progression and model performance.
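The four steps this method performs can be sketched end to end on a toy problem. This self-contained C# example is not the library's training loop: it fits a single weight w in y = w * x with plain SGD, treats samples with larger |x| as harder (a hypothetical heuristic), and uses evenly spaced phases controlled by hypothetical `numPhases` and `epochsPerPhase` settings.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Toy regression data that follows y = 3x exactly.
double[] xs = { 0.5, 1.0, 1.5, 2.0, 2.5, 3.0 };
double[] ys = xs.Select(x => 3.0 * x).ToArray();

double w = 0.0;           // single-weight model: prediction = w * x
const double lr = 0.05;   // SGD learning rate
const int numPhases = 3;  // hypothetical: evenly spaced phases
const int epochsPerPhase = 20;

// 1. Estimate difficulty (hypothetical heuristic: larger |x| is harder).
double[] difficulty = xs.Select(x => Math.Abs(x)).ToArray();

// 2. Sort sample indices from easy to hard.
int[] order = Enumerable.Range(0, xs.Length).OrderBy(i => difficulty[i]).ToArray();

var phaseLosses = new List<double>();
for (int phase = 1; phase <= numPhases; phase++)
{
    // 3. Train on a growing prefix of the sorted data: phase / numPhases of it.
    int count = (int)Math.Ceiling(phase * order.Length / (double)numPhases);
    int[] active = order.Take(count).ToArray();

    for (int epoch = 0; epoch < epochsPerPhase; epoch++)
    {
        foreach (int i in active)
        {
            double error = w * xs[i] - ys[i];
            w -= lr * 2.0 * error * xs[i]; // gradient of (w*x - y)^2 with respect to w
        }
    }

    // 4. Track progress: mean squared error over the samples used in this phase.
    double mse = active.Average(i => Math.Pow(w * xs[i] - ys[i], 2));
    phaseLosses.Add(mse);
    Console.WriteLine($"phase {phase}: {count} samples, w = {w:F4}, mse = {mse:E2}");
}
```

Each phase reuses the weights from the previous one, so the easy samples give the model a head start before the harder ones are added.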
Train(IDataset<T, TInput, TOutput>, IDataset<T, TInput, TOutput>?)
Trains the model with a validation set for monitoring.
public CurriculumLearningResult<T> Train(IDataset<T, TInput, TOutput> trainingData, IDataset<T, TInput, TOutput>? validationData)
Parameters
trainingData (IDataset<T, TInput, TOutput>): The full training dataset.
validationData (IDataset<T, TInput, TOutput>?): The validation dataset used for monitoring; may be null.
Returns
- CurriculumLearningResult<T>
Result containing training metrics and curriculum progression.
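A validation set is what makes early stopping possible. Below is a minimal patience-based rule in C#. The `patience` and `minDelta` parameters are hypothetical names, not confirmed config properties, and the validation losses are simulated instead of being computed from `validationData`.

```csharp
using System;

// Hypothetical patience-based early stopping: stop once the validation loss has
// not improved by at least minDelta for `patience` consecutive epochs.
(int stopEpoch, int bestEpoch, double bestLoss) RunWithEarlyStopping(
    double[] validationLosses, int patience, double minDelta)
{
    double best = double.PositiveInfinity;
    int bestAt = -1;
    int epochsWithoutImprovement = 0;

    for (int epoch = 0; epoch < validationLosses.Length; epoch++)
    {
        double loss = validationLosses[epoch]; // in practice: evaluate on validation data
        if (loss < best - minDelta)
        {
            best = loss;
            bestAt = epoch;
            epochsWithoutImprovement = 0;
        }
        else if (++epochsWithoutImprovement >= patience)
        {
            return (epoch, bestAt, best); // stop early
        }
    }
    return (validationLosses.Length - 1, bestAt, best); // ran to completion
}

// Simulated validation losses: improvement stalls after epoch 3.
double[] losses = { 1.0, 0.7, 0.5, 0.45, 0.46, 0.47, 0.45, 0.48, 0.30 };
var (stopEpoch, bestEpochFound, bestValLoss) = RunWithEarlyStopping(losses, patience: 3, minDelta: 0.01);
Console.WriteLine($"stopped at epoch {stopEpoch}, best epoch {bestEpochFound}"); // stopped at epoch 6, best epoch 3
```

Note that the late improvement at epoch 8 is never seen; a larger `patience` trades extra training time for fewer premature stops.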
TrainWithDifficulty(IDataset<T, TInput, TOutput>, Vector<T>)
Trains with pre-computed difficulty scores.
public CurriculumLearningResult<T> TrainWithDifficulty(IDataset<T, TInput, TOutput> trainingData, Vector<T> difficultyScores)
Parameters
trainingData (IDataset<T, TInput, TOutput>): The full training dataset.
difficultyScores (Vector<T>): Pre-computed difficulty scores, one per sample.
Returns
- CurriculumLearningResult<T>
Result containing training metrics and curriculum progression.
Remarks
Use this method when:
- Difficulty scores are computed externally (e.g., by domain experts).
- You want to reuse difficulty scores across training runs.
- Difficulty estimation is too expensive to repeat.
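When difficulty scores come from outside the learner, it can be worth checking them once before reusing them. This hypothetical C# helper (`ValidateDifficultyScores` is not part of the documented API) checks that there is exactly one finite score per sample. Whether `TrainWithDifficulty` performs its own validation is not stated in this reference.

```csharp
using System;

// Hypothetical pre-flight check for externally supplied difficulty scores:
// exactly one finite score per training sample.
void ValidateDifficultyScores(double[] scores, int sampleCount)
{
    if (scores.Length != sampleCount)
        throw new ArgumentException(
            $"Expected {sampleCount} difficulty scores but got {scores.Length}.");

    for (int i = 0; i < scores.Length; i++)
    {
        if (double.IsNaN(scores[i]) || double.IsInfinity(scores[i]))
            throw new ArgumentException($"Difficulty score at index {i} is not finite.");
    }
}

// Hypothetical expert ratings on a 1-5 scale, computed once and reused across runs.
double[] expertScores = { 1, 3, 2, 5, 4 };
ValidateDifficultyScores(expertScores, sampleCount: 5); // passes

try
{
    ValidateDifficultyScores(expertScores, sampleCount: 6);
}
catch (ArgumentException ex)
{
    Console.WriteLine(ex.Message); // Expected 6 difficulty scores but got 5.
}
```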
Events
PhaseCompleted
Event raised when a curriculum phase completes.
public event EventHandler<CurriculumPhaseCompletedEventArgs<T>>? PhaseCompleted
Event Type
- EventHandler<CurriculumPhaseCompletedEventArgs<T>>
PhaseStarted
Event raised when a curriculum phase starts.
public event EventHandler<CurriculumPhaseEventArgs<T>>? PhaseStarted
Event Type
- EventHandler<CurriculumPhaseEventArgs<T>>
TrainingCompleted
Event raised when training completes.
public event EventHandler<CurriculumTrainingCompletedEventArgs<T>>? TrainingCompleted