Class DistillationCheckpointManager<T>
- Namespace
- AiDotNet.KnowledgeDistillation
- Assembly
- AiDotNet.dll
Manages checkpointing during knowledge distillation training.
public class DistillationCheckpointManager<T>
Type Parameters
- T
The numeric type for calculations.
- Inheritance
- object
DistillationCheckpointManager<T>
Remarks
For Beginners: This class handles saving and loading model states during distillation training. It's like the "save game" manager in a video game - it decides when to save, what to save, and how to load progress later.
Key Features:
- Automatic checkpointing at specified intervals
- Keep only the best N checkpoints based on validation metrics
- Save and restore curriculum learning progress
- Support for multi-stage distillation (student → teacher)
- Resume interrupted training
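The "keep only best N" feature can be pictured with a small in-memory sketch. This is an illustration of the general idea, not the library's implementation: `CheckpointRecord` and `BestNTracker` are hypothetical names invented here, and the sketch assumes the tracked metric is a validation loss where lower is better.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical stand-in for a checkpoint entry; not the library's CheckpointMetadata.
public sealed record CheckpointRecord(int Epoch, double ValidationLoss);

// Hypothetical tracker that retains only the N records with the lowest loss.
public sealed class BestNTracker
{
    private readonly int _keepBestN;
    private readonly List<CheckpointRecord> _kept = new();

    public BestNTracker(int keepBestN)
    {
        if (keepBestN < 1) throw new ArgumentOutOfRangeException(nameof(keepBestN));
        _keepBestN = keepBestN;
    }

    // Retained records, ordered from best (lowest loss) to worst.
    public IReadOnlyList<CheckpointRecord> Kept => _kept;

    // Best retained record, or null when nothing has been added yet.
    public CheckpointRecord? Best => _kept.Count > 0 ? _kept[0] : null;

    // Adds a record and returns whatever had to be pruned to stay within the limit.
    public IReadOnlyList<CheckpointRecord> Add(CheckpointRecord record)
    {
        _kept.Add(record);

        // Assumption: lower loss is better; ties are broken by earlier epoch.
        var ordered = _kept.OrderBy(r => r.ValidationLoss).ThenBy(r => r.Epoch).ToList();
        var pruned = ordered.Skip(_keepBestN).ToList();

        _kept.Clear();
        _kept.AddRange(ordered.Take(_keepBestN));
        return pruned;
    }
}
```

With a limit of 2, adding losses 0.9, 0.5, and 0.7 for epochs 0, 1, and 2 prunes the epoch 0 record on the third call and leaves epoch 1 as the best entry. A real manager would also delete the pruned checkpoint files from disk.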
Example Usage:
var config = new DistillationCheckpointConfig
{
    CheckpointDirectory = "./checkpoints",
    SaveEveryEpochs = 5,
    KeepBestN = 3,
    SaveStudent = true
};
var manager = new DistillationCheckpointManager<double>(config);

// During training
for (int epoch = 0; epoch < 100; epoch++)
{
    // ... training code ...
    double validationLoss = EvaluateStudent();
    manager.SaveCheckpointIfNeeded(
        epoch: epoch,
        student: studentModel,
        metrics: new Dictionary<string, double> { { "validation_loss", validationLoss } }
    );
}

// Load best checkpoint
manager.LoadBestCheckpoint(studentModel);
Constructors
DistillationCheckpointManager(DistillationCheckpointConfig)
Initializes a new instance of the DistillationCheckpointManager class.
public DistillationCheckpointManager(DistillationCheckpointConfig config)
Parameters
- config DistillationCheckpointConfig
Configuration for checkpoint management.
Remarks
For Advanced Users: This constructor requires an explicit configuration object. Every setting in DistillationCheckpointConfig has a recommended default.
Exceptions
- ArgumentNullException
Thrown when config is null.
Methods
GetAllCheckpoints()
Gets all saved checkpoint metadata as a read-only collection.
public IReadOnlyList<CheckpointMetadata> GetAllCheckpoints()
Returns
- IReadOnlyList<CheckpointMetadata>
Read-only list of all checkpoint metadata.
Remarks
For Advanced Users: Provides read-only access to all saved checkpoints for custom queries.
GetBestCheckpoint()
Gets the checkpoint with the best metric value.
public CheckpointMetadata? GetBestCheckpoint()
Returns
- CheckpointMetadata
Metadata of the best checkpoint, or null if no checkpoints exist.
Remarks
For Advanced Users: Returns the checkpoint with the best validation metric based on the configuration (e.g., lowest validation loss or highest accuracy).
GetCheckpointByEpoch(int)
Gets a checkpoint for a specific epoch.
public CheckpointMetadata? GetCheckpointByEpoch(int epoch)
Parameters
- epoch int
The epoch number to find.
Returns
- CheckpointMetadata
Metadata of the checkpoint at the specified epoch, or null if not found.
Remarks
For Advanced Users: Returns the checkpoint saved at a specific epoch number.
GetMostRecentCheckpoint()
Gets the most recently saved checkpoint.
public CheckpointMetadata? GetMostRecentCheckpoint()
Returns
- CheckpointMetadata
Metadata of the most recent checkpoint, or null if no checkpoints exist.
Remarks
For Advanced Users: Useful for resuming interrupted training from the last saved state.
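A resume step might look like the sketch below. So that it compiles on its own, it declares empty stand-ins for `CheckpointMetadata` and `ICheckpointableModel` and a hypothetical `ICheckpointStore` interface that mirrors only the `GetMostRecentCheckpoint` and `LoadCheckpoint` signatures documented on this page. `ICheckpointStore` and `ResumeHelper` are not part of AiDotNet; with the library referenced, the same two calls would be made directly on the manager.

```csharp
#nullable enable
using System;

// Empty stand-ins for the library types named on this page (assumption: only
// their names and the two method signatures below are taken from the docs).
public sealed class CheckpointMetadata { }
public interface ICheckpointableModel { }

// Hypothetical interface mirroring two documented manager methods.
public interface ICheckpointStore
{
    CheckpointMetadata? GetMostRecentCheckpoint();
    void LoadCheckpoint(CheckpointMetadata metadata,
                        ICheckpointableModel? student = null,
                        ICheckpointableModel? teacher = null);
}

public static class ResumeHelper
{
    // Returns true if a checkpoint was found and loaded into the student model.
    public static bool TryResume(ICheckpointStore store, ICheckpointableModel student)
    {
        if (store is null) throw new ArgumentNullException(nameof(store));
        if (student is null) throw new ArgumentNullException(nameof(student));

        CheckpointMetadata? last = store.GetMostRecentCheckpoint();
        if (last is null)
        {
            // Nothing saved yet: the caller starts training from scratch.
            return false;
        }

        store.LoadCheckpoint(last, student: student);
        return true;
    }
}
```

The null check matters because `GetMostRecentCheckpoint()` is documented to return null when no checkpoints exist, which is the normal state on a first run.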
LoadBestCheckpoint(ICheckpointableModel?, ICheckpointableModel?)
Loads the best checkpoint based on the configured metric.
public CheckpointMetadata? LoadBestCheckpoint(ICheckpointableModel? student = null, ICheckpointableModel? teacher = null)
Parameters
- student ICheckpointableModel
Student model to load into.
- teacher ICheckpointableModel
Optional teacher model to load into.
Returns
- CheckpointMetadata
Metadata of the loaded checkpoint, or null if no checkpoints exist.
Remarks
For Beginners: Call this after training to load the checkpoint with the best validation performance.
LoadCheckpoint(CheckpointMetadata, ICheckpointableModel?, ICheckpointableModel?)
Loads a specific checkpoint.
public void LoadCheckpoint(CheckpointMetadata metadata, ICheckpointableModel? student = null, ICheckpointableModel? teacher = null)
Parameters
- metadata CheckpointMetadata
Metadata of the checkpoint to load.
- student ICheckpointableModel
Student model to load into.
- teacher ICheckpointableModel
Optional teacher model to load into.
SaveCheckpointIfNeeded(int, ICheckpointableModel?, ICheckpointableModel?, object?, Dictionary<string, double>?, int?, bool)
Saves a checkpoint if the configured schedule calls for one, or if force is true.
public bool SaveCheckpointIfNeeded(int epoch, ICheckpointableModel? student = null, ICheckpointableModel? teacher = null, object? strategy = null, Dictionary<string, double>? metrics = null, int? batch = null, bool force = false)
Parameters
- epoch int
Current epoch number.
- student ICheckpointableModel
Student model to checkpoint (if SaveStudent = true).
- teacher ICheckpointableModel
Teacher model to checkpoint (if SaveTeacher = true).
- strategy object
Distillation strategy to checkpoint.
- metrics Dictionary<string, double>
Training/validation metrics for this checkpoint.
- batch int?
Optional batch number.
- force bool
Force save regardless of schedule.
Returns
- bool
true if a checkpoint was saved; otherwise, false.
Remarks
For Beginners: Call this method periodically during training. It will automatically decide whether to save based on your configuration.