Class MultiStepLRScheduler
- Namespace
- AiDotNet.LearningRateSchedulers
- Assembly
- AiDotNet.dll
Decays the learning rate by gamma at each milestone step.
public class MultiStepLRScheduler : LearningRateSchedulerBase, ILearningRateScheduler
- Inheritance
- LearningRateSchedulerBase → MultiStepLRScheduler
- Implements
- ILearningRateScheduler
Examples
```csharp
using AiDotNet.LearningRateSchedulers;

// Decay at epochs 30, 80, and 120: the rate goes 0.1 -> 0.01 -> 0.001 -> 0.0001
var scheduler = new MultiStepLRScheduler(
    baseLearningRate: 0.1,
    milestones: new[] { 30, 80, 120 },
    gamma: 0.1
);
```
Remarks
MultiStepLR multiplies the learning rate by gamma each time the step count reaches one of the milestones. This allows non-uniform decay schedules in which you specify exactly when the learning rate should decrease.
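The rule can be written as a formula. This is a sketch under two assumptions that the API reference does not state explicitly: the number of decays at step t is the number of milestones m_i with m_i ≤ t, and minLearningRate acts as a floor applied with a maximum. With base rate η₀ and decay factor γ:

```latex
\eta_t = \max\left(\eta_{\min},\; \eta_0 \, \gamma^{\,\left|\{\, i \,:\, m_i \le t \,\}\right|}\right)
```

For the example above (η₀ = 0.1, γ = 0.1, milestones 30, 80, 120), step 50 has reached one milestone, so η₅₀ = 0.1 · 0.1 = 0.01.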
For Beginners: Unlike StepLR, which decays at regular intervals, MultiStepLR lets you specify exactly which steps the learning rate is decayed at. For example, you might decay at epochs 30, 80, and 120 instead of every 30 epochs. This gives you finer control over the training schedule.
This is useful when you know from experience or experimentation that certain epochs are good points to reduce the learning rate.
Constructors
MultiStepLRScheduler(double, int[], double, double)
Initializes a new instance of the MultiStepLRScheduler class.
public MultiStepLRScheduler(double baseLearningRate, int[] milestones, double gamma = 0.1, double minLearningRate = 0)
Parameters
- baseLearningRate (double): The initial learning rate.
- milestones (int[]): Step indices at which to decay the learning rate. Must be in increasing order.
- gamma (double): Multiplicative factor applied to the learning rate at each milestone. Default: 0.1.
- minLearningRate (double): Lower bound (floor) for the learning rate. Default: 0.
Exceptions
- ArgumentException
Thrown when milestones is empty or not in increasing order.
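The documented validation rules can be sketched as follows. ValidateMilestones is a hypothetical helper, not part of the AiDotNet API, and it assumes "increasing" means strictly increasing, so duplicate milestones are rejected; the library may handle duplicates differently. The sketch uses C# top-level statements (C# 9 or later).

```csharp
using System;

// Hypothetical helper mirroring the documented ArgumentException rules.
// Assumption: "increasing order" means strictly increasing (duplicates rejected).
static void ValidateMilestones(int[] milestones)
{
    if (milestones.Length == 0)
        throw new ArgumentException("At least one milestone is required.", nameof(milestones));

    // Every milestone must be larger than the one before it.
    for (int i = 1; i < milestones.Length; i++)
    {
        if (milestones[i] <= milestones[i - 1])
            throw new ArgumentException(
                $"Milestones must be in increasing order, but {milestones[i]} follows {milestones[i - 1]}.",
                nameof(milestones));
    }
}

ValidateMilestones(new[] { 30, 80, 120 }); // accepted

try
{
    ValidateMilestones(new[] { 80, 30 }); // out of order
}
catch (ArgumentException ex)
{
    Console.WriteLine($"Rejected: {ex.Message}");
}
```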
Properties
Gamma
Gets the multiplicative factor of learning rate decay.
public double Gamma { get; }
Property Value
- double
Milestones
Gets the step indices at which the learning rate is decayed.
public IReadOnlyList<int> Milestones { get; }
Property Value
- IReadOnlyList<int>
Methods
ComputeLearningRate(int)
Computes the learning rate for a given step.
protected override double ComputeLearningRate(int step)
Parameters
- step (int): The step number.
Returns
- double
The computed learning rate.
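The standalone sketch below illustrates what ComputeLearningRate plausibly computes, based on the class description. It does not call AiDotNet: ComputeMultiStepLr is a hypothetical name, and the milestone-counting rule (milestones ≤ step) and the Math.Max floor are assumptions rather than the library's confirmed implementation. It uses C# top-level statements (C# 9 or later).

```csharp
using System;

// Hypothetical re-implementation for illustration only (not the AiDotNet source).
// Assumes: decay count = number of milestones <= step; floor applied with Math.Max.
static double ComputeMultiStepLr(double baseLr, int[] milestones, double gamma, double minLr, int step)
{
    // Count how many milestones the current step has reached.
    int reached = 0;
    foreach (int m in milestones)
    {
        if (step >= m) reached++;
    }

    // Apply gamma once per reached milestone, then clamp to the floor.
    double lr = baseLr * Math.Pow(gamma, reached);
    return Math.Max(lr, minLr);
}

int[] milestones = { 30, 80, 120 };
foreach (int step in new[] { 0, 29, 30, 79, 80, 119, 120 })
{
    Console.WriteLine($"step {step,3}: lr = {ComputeMultiStepLr(0.1, milestones, 0.1, 0.0, step):G4}");
}
```

With the arguments from the Examples section, this sketch yields 0.1 for steps 0 to 29, about 0.01 for steps 30 to 79, about 0.001 for steps 80 to 119, and about 0.0001 from step 120 on; small floating-point rounding differences are expected in the raw values.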
GetState()
Gets the scheduler state for serialization/checkpointing.
public override Dictionary<string, object> GetState()
Returns
- Dictionary<string, object>
A dictionary containing the scheduler state.