Class DistillationCheckpointManager<T>

Namespace
AiDotNet.KnowledgeDistillation
Assembly
AiDotNet.dll

Manages checkpointing during knowledge distillation training.

public class DistillationCheckpointManager<T>

Type Parameters

T

The numeric type for calculations.

Inheritance
DistillationCheckpointManager<T>
Inherited Members

Remarks

For Beginners: This class handles saving and loading model states during distillation training. It's like the "save game" manager in a video game - it decides when to save, what to save, and how to load progress later.

Key Features:

- Automatic checkpointing at specified intervals
- Keep only the best N checkpoints based on validation metrics
- Save/restore curriculum learning progress
- Support for multi-stage distillation (student → teacher)
- Resume interrupted training

Example Usage:

var config = new DistillationCheckpointConfig
{
    CheckpointDirectory = "./checkpoints",
    SaveEveryEpochs = 5,
    KeepBestN = 3,
    SaveStudent = true
};

var manager = new DistillationCheckpointManager<double>(config);

// During training
for (int epoch = 0; epoch < 100; epoch++)
{
    // ... training code ...

    double validationLoss = EvaluateStudent();
    manager.SaveCheckpointIfNeeded(
        epoch: epoch,
        student: studentModel,
        metrics: new Dictionary<string, double> { { "validation_loss", validationLoss } }
    );
}

// Load best checkpoint
manager.LoadBestCheckpoint(studentModel);
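To make the KeepBestN setting concrete, here is a minimal sketch of a "keep only the best N" retention rule. It is an illustration under stated assumptions, not the manager's actual implementation: SavedCheckpoint and RetentionSketch are hypothetical stand-ins (not library types), and the sketch assumes a lower validation loss is better.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical stand-in for a saved checkpoint record.
// This is NOT the library's CheckpointMetadata type.
public sealed record SavedCheckpoint(int Epoch, double ValidationLoss);

public static class RetentionSketch
{
    // Splits checkpoints into the N best (lowest loss first) and the rest,
    // which a retention policy would then delete from disk.
    public static (List<SavedCheckpoint> Keep, List<SavedCheckpoint> Prune) KeepBest(
        IEnumerable<SavedCheckpoint> all, int keepBestN)
    {
        if (all is null) throw new ArgumentNullException(nameof(all));
        if (keepBestN < 0) throw new ArgumentOutOfRangeException(nameof(keepBestN));

        // Assumption: lower loss is better; ties go to the earlier epoch.
        var ranked = all
            .OrderBy(c => c.ValidationLoss)
            .ThenBy(c => c.Epoch)
            .ToList();

        return (ranked.Take(keepBestN).ToList(), ranked.Skip(keepBestN).ToList());
    }
}
```

With four checkpoints and keepBestN = 3, the checkpoint with the highest loss lands in Prune and the other three stay in Keep, ordered from best to worst.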

Constructors

DistillationCheckpointManager(DistillationCheckpointConfig)

Initializes a new instance of the DistillationCheckpointManager class.

public DistillationCheckpointManager(DistillationCheckpointConfig config)

Parameters

config DistillationCheckpointConfig

Configuration for checkpoint management.

Remarks

For Advanced Users: This constructor requires an explicit configuration object. Every setting in DistillationCheckpointConfig has a recommended default, so you only need to override the values you want to change.

Exceptions

ArgumentNullException

Thrown when config is null.

Methods

GetAllCheckpoints()

Gets the metadata of all saved checkpoints as a read-only collection.

public IReadOnlyList<CheckpointMetadata> GetAllCheckpoints()

Returns

IReadOnlyList<CheckpointMetadata>

Read-only list of all checkpoint metadata.

Remarks

For Advanced Users: Provides read-only access to all saved checkpoints for custom queries.

GetBestCheckpoint()

Gets the checkpoint with the best metric value.

public CheckpointMetadata? GetBestCheckpoint()

Returns

CheckpointMetadata

Metadata of the best checkpoint, or null if no checkpoints exist.

Remarks

For Advanced Users: Returns the checkpoint with the best validation metric based on the configuration (e.g., lowest validation loss or highest accuracy).

GetCheckpointByEpoch(int)

Gets a checkpoint for a specific epoch.

public CheckpointMetadata? GetCheckpointByEpoch(int epoch)

Parameters

epoch int

The epoch number to find.

Returns

CheckpointMetadata

Metadata of the checkpoint at the specified epoch, or null if not found.

Remarks

For Advanced Users: Returns the checkpoint saved at a specific epoch number.

GetMostRecentCheckpoint()

Gets the most recently saved checkpoint.

public CheckpointMetadata? GetMostRecentCheckpoint()

Returns

CheckpointMetadata

Metadata of the most recent checkpoint, or null if no checkpoints exist.

Remarks

For Advanced Users: Useful for resuming interrupted training from the last saved state.
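The resume flow described above can be sketched as a small helper. This is a minimal sketch, not part of the library: ResumeSketch and ResumeFrom are hypothetical names, and the three delegates stand in for calls such as GetMostRecentCheckpoint() and LoadCheckpoint(...). It also assumes that the metadata exposes the epoch it was saved at, which this page does not confirm.

```csharp
using System;

public static class ResumeSketch
{
    // Returns the epoch at which training should continue.
    // getMostRecent: returns the latest checkpoint metadata, or null if none exist.
    // load:          restores model state from that metadata.
    // epochOf:       reads the saved epoch from the metadata (assumed to be available).
    public static int ResumeFrom<TMeta>(
        Func<TMeta?> getMostRecent,
        Action<TMeta> load,
        Func<TMeta, int> epochOf) where TMeta : class
    {
        TMeta? last = getMostRecent();
        if (last is null)
        {
            return 0; // nothing saved yet: start from the first epoch
        }

        load(last);               // restore the saved state
        return epochOf(last) + 1; // continue with the epoch after the saved one
    }
}
```

With the manager, the delegates would wrap manager.GetMostRecentCheckpoint() and manager.LoadCheckpoint(metadata, studentModel); the epoch accessor depends on the actual members of CheckpointMetadata.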

LoadBestCheckpoint(ICheckpointableModel?, ICheckpointableModel?)

Loads the best checkpoint based on the configured metric.

public CheckpointMetadata? LoadBestCheckpoint(ICheckpointableModel? student = null, ICheckpointableModel? teacher = null)

Parameters

student ICheckpointableModel

Optional student model to load into.

teacher ICheckpointableModel

Optional teacher model to load into.

Returns

CheckpointMetadata

Metadata of the loaded checkpoint, or null if no checkpoints exist.

Remarks

For Beginners: Call this after training to load the checkpoint with the best validation performance.

LoadCheckpoint(CheckpointMetadata, ICheckpointableModel?, ICheckpointableModel?)

Loads a specific checkpoint.

public void LoadCheckpoint(CheckpointMetadata metadata, ICheckpointableModel? student = null, ICheckpointableModel? teacher = null)

Parameters

metadata CheckpointMetadata

Metadata of the checkpoint to load.

student ICheckpointableModel

Optional student model to load into.

teacher ICheckpointableModel

Optional teacher model to load into.

SaveCheckpointIfNeeded(int, ICheckpointableModel?, ICheckpointableModel?, object?, Dictionary<string, double>?, int?, bool)

Saves a checkpoint if conditions are met.

public bool SaveCheckpointIfNeeded(int epoch, ICheckpointableModel? student = null, ICheckpointableModel? teacher = null, object? strategy = null, Dictionary<string, double>? metrics = null, int? batch = null, bool force = false)

Parameters

epoch int

Current epoch number.

student ICheckpointableModel

Student model to checkpoint (if SaveStudent = true).

teacher ICheckpointableModel

Teacher model to checkpoint (if SaveTeacher = true).

strategy object

Distillation strategy to checkpoint.

metrics Dictionary<string, double>

Training/validation metrics for this checkpoint.

batch int?

Optional batch number.

force bool

Force save regardless of schedule.

Returns

bool

true if a checkpoint was saved; otherwise, false.

Remarks

For Beginners: Call this method periodically during training. It will automatically decide whether to save based on your configuration.
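As an illustration of what "decide whether to save" could mean, the sketch below shows one plausible schedule check. It is an assumption, not the library's actual logic: ScheduleSketch and ShouldSave are hypothetical names, epochs are assumed to be zero-based, and a save is assumed to happen after every Nth completed epoch unless force is set. The real method may also take metrics or batch numbers into account.

```csharp
public static class ScheduleSketch
{
    // Decides whether a checkpoint is due for the given epoch.
    public static bool ShouldSave(int epoch, int saveEveryEpochs, bool force = false)
    {
        if (force)
        {
            return true; // a forced save bypasses the schedule
        }

        if (saveEveryEpochs <= 0)
        {
            return false; // assumption: a non-positive interval disables periodic saving
        }

        // Assumption: epoch is zero-based, so epoch 4 is the 5th completed epoch.
        return (epoch + 1) % saveEveryEpochs == 0;
    }
}
```

Under these assumptions, an interval of 5 triggers saves at epochs 4, 9, 14, and so on, while force: true saves at any epoch.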