
Class StepScheduler<T>

Namespace
AiDotNet.CurriculumLearning.Schedulers
Assembly
AiDotNet.dll

Curriculum scheduler with discrete step-based progression.

public class StepScheduler<T> : CurriculumSchedulerBase<T>, ICurriculumScheduler<T>

Type Parameters

T

The numeric type used for calculations.

Inheritance
CurriculumSchedulerBase<T>
StepScheduler<T>
Implements
ICurriculumScheduler<T>

Remarks

For Beginners: This scheduler divides training into discrete phases, with the data fraction jumping at specific epochs rather than changing continuously.

Example: With 3 steps over 12 epochs:

  • Epochs 0-3: the easiest 33% of samples
  • Epochs 4-7: the easiest 67% of samples
  • Epochs 8-11: all samples (100%)
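The phase boundaries in this example can be reproduced with integer division. The following C# sketch is illustrative only: it assumes each step spans an equal block of epochs and that step k uses the fraction (k + 1) / numSteps. The helper names StepFor and FractionFor are hypothetical and are not part of AiDotNet.

```csharp
using System;

// Hypothetical sketch (not the AiDotNet implementation): maps an epoch to
// its step and data fraction, assuming equal-length steps and
// fraction = (step + 1) / numSteps.
static int StepFor(int epoch, int totalEpochs, int numSteps)
{
    // Integer division splits the epochs into numSteps equal blocks.
    int step = epoch * numSteps / totalEpochs;
    // Clamp so that epochs past the planned end stay in the last step.
    return Math.Min(step, numSteps - 1);
}

static double FractionFor(int epoch, int totalEpochs, int numSteps) =>
    (StepFor(epoch, totalEpochs, numSteps) + 1) / (double)numSteps;

// Walk the 3-step, 12-epoch schedule from the example.
for (int epoch = 0; epoch < 12; epoch++)
{
    Console.WriteLine($"Epoch {epoch,2}: step {StepFor(epoch, 12, 3)}, fraction {FractionFor(epoch, 12, 3):P0}");
}
```

Because the step index only changes when the integer quotient changes, the fraction stays flat inside each block and jumps at epochs 4 and 8.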

Best For:

  • Clear curriculum phases with distinct difficulty levels
  • When you want the model to fully adapt to each phase before progressing
  • Educational datasets with natural difficulty tiers

Constructors

StepScheduler(int, IEnumerable<T>)

Initializes a new instance with custom step fractions.

public StepScheduler(int totalEpochs, IEnumerable<T> stepFractions)

Parameters

totalEpochs int

Total number of training epochs.

stepFractions IEnumerable<T>

Custom data fractions for each step.
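A minimal sketch of how a list of custom fractions could map onto epochs. It assumes each fraction covers an equal share of totalEpochs, and that fractions must lie in (0, 1] and never decrease; the actual validation rules of this constructor are not documented here. ValidateFractions and CustomFractionFor are hypothetical helpers, and double stands in for T.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical check: every fraction in (0, 1], and no step uses less data
// than the step before it.
static double[] ValidateFractions(IEnumerable<double> stepFractions)
{
    double[] fractions = stepFractions.ToArray();
    if (fractions.Length == 0)
        throw new ArgumentException("At least one step fraction is required.");
    for (int i = 0; i < fractions.Length; i++)
    {
        if (fractions[i] <= 0.0 || fractions[i] > 1.0)
            throw new ArgumentOutOfRangeException(nameof(stepFractions), "Fractions must be in (0, 1].");
        if (i > 0 && fractions[i] < fractions[i - 1])
            throw new ArgumentException("Fractions should not decrease between steps.");
    }
    return fractions;
}

// Assumes each custom fraction covers an equal block of epochs.
static double CustomFractionFor(int epoch, int totalEpochs, double[] fractions)
{
    int step = Math.Min(epoch * fractions.Length / totalEpochs, fractions.Length - 1);
    return fractions[step];
}

double[] steps = ValidateFractions(new[] { 0.25, 0.5, 1.0 });
for (int epoch = 0; epoch < 12; epoch++)
    Console.WriteLine($"Epoch {epoch,2}: {CustomFractionFor(epoch, 12, steps)}");
```

Under these assumptions, { 0.25, 0.5, 1.0 } over 12 epochs gives 25% of the data in epochs 0-3, 50% in epochs 4-7, and all of it in epochs 8-11.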

StepScheduler(int, int, T?, T?)

Initializes a new instance with uniform steps.

public StepScheduler(int totalEpochs, int numSteps = 5, T? minFraction = default, T? maxFraction = default)

Parameters

totalEpochs int

Total number of training epochs.

numSteps int

Number of curriculum phases.

minFraction T

Initial data fraction (default 0.1).

maxFraction T

Final data fraction (default 1.0).
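One plausible reading of this constructor is that the numSteps fractions are spaced evenly from minFraction to maxFraction. The sketch below assumes that linear spacing and uses double in place of T; UniformFractions is a hypothetical helper, and the library may space its steps differently.

```csharp
using System;

// Hypothetical sketch: numSteps fractions spaced linearly from minFraction
// to maxFraction (both inclusive). Defaults mirror the documented ones.
static double[] UniformFractions(int numSteps = 5, double minFraction = 0.1, double maxFraction = 1.0)
{
    if (numSteps < 1)
        throw new ArgumentOutOfRangeException(nameof(numSteps), "At least one step is required.");
    var fractions = new double[numSteps];
    for (int k = 0; k < numSteps; k++)
    {
        // With a single step, go straight to maxFraction.
        double t = numSteps == 1 ? 1.0 : (double)k / (numSteps - 1);
        fractions[k] = minFraction + (maxFraction - minFraction) * t;
    }
    return fractions;
}

Console.WriteLine(string.Join(", ", UniformFractions()));
```

Under this assumption the defaults give 0.1, 0.325, 0.55, 0.775 and 1.0, and UniformFractions(3, 1.0 / 3, 1.0) yields the 1/3, 2/3, 1 schedule from the example in the Remarks.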

Properties

Name

Gets the name of this scheduler.

public override string Name { get; }

Property Value

string

TotalPhases

Gets the total number of phases (steps) in this scheduler.

public override int TotalPhases { get; }

Property Value

int

Methods

GetDataFraction()

Gets the current data fraction based on the current step.

public override T GetDataFraction()

Returns

T
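The returned value is a fraction, not a sample count. To illustrate how such a fraction might be applied, the sketch below assumes each sample has a difficulty score (lower means easier) and keeps the easiest ceiling(fraction * N) samples, with a minimum of one. EasiestSampleIndices and the scores are hypothetical and do not come from AiDotNet.

```csharp
using System;
using System.Linq;

// Hypothetical sketch: selects the indices of the easiest samples for a
// given data fraction (for example, a value returned by GetDataFraction).
static int[] EasiestSampleIndices(double[] difficulty, double fraction)
{
    if (difficulty.Length == 0)
        return Array.Empty<int>();
    int count = (int)Math.Ceiling(fraction * difficulty.Length);
    count = Math.Clamp(count, 1, difficulty.Length); // always keep at least one sample
    return Enumerable.Range(0, difficulty.Length)
        .OrderBy(i => difficulty[i]) // lower score = easier sample
        .Take(count)
        .ToArray();
}

double[] scores = { 0.9, 0.1, 0.5, 0.3, 0.7, 0.2 };
Console.WriteLine(string.Join(", ", EasiestSampleIndices(scores, 0.5)));
```

For these six scores and a fraction of 0.5, the sketch keeps indices 1, 5 and 3. Rounding up rather than down means a small non-zero fraction still selects at least one sample.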

GetStatistics()

Gets scheduler-specific statistics.

public override Dictionary<string, object> GetStatistics()

Returns

Dictionary<string, object>