Class MultiStepLRScheduler

Namespace
AiDotNet.LearningRateSchedulers
Assembly
AiDotNet.dll

Multiplies the learning rate by gamma at each milestone step.

public class MultiStepLRScheduler : LearningRateSchedulerBase, ILearningRateScheduler
Inheritance
LearningRateSchedulerBase
MultiStepLRScheduler
Implements
ILearningRateScheduler

Examples

// Decay at epochs 30, 80, and 120
var scheduler = new MultiStepLRScheduler(
    baseLearningRate: 0.1,
    milestones: new[] { 30, 80, 120 },
    gamma: 0.1
);

Remarks

MultiStepLR multiplies the learning rate by gamma each time the step count reaches one of the milestones. This allows non-uniform decay schedules in which you specify exactly when the learning rate decreases.

For Beginners: Unlike StepLR, which decays at regular intervals, MultiStepLR lets you choose exactly which steps trigger a decay. For example, you might decay at epochs 30, 80, and 120 rather than every 30 epochs. This gives you more control over the training schedule.

This is useful when you know from experience or experimentation that certain epochs are good points to reduce the learning rate.
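The decay rule described above can be sketched as a standalone function. This is a minimal illustration, not AiDotNet's implementation: the class MultiStepDecaySketch and the method RateAt are hypothetical names, and the sketch assumes the rate is multiplied by gamma once for every milestone that is less than or equal to the current step, then clamped to the minimum learning rate.

```csharp
using System;
using System.Linq;

public static class MultiStepDecaySketch
{
    // Hypothetical helper, not part of AiDotNet.
    // Assumption: one multiplication by gamma per milestone already reached
    // (milestone <= step), with the result floored at minLearningRate.
    public static double RateAt(
        double baseLearningRate,
        int[] milestones,
        double gamma,
        int step,
        double minLearningRate = 0)
    {
        // Count how many milestones the step counter has reached so far.
        int milestonesReached = milestones.Count(m => m <= step);

        // Apply gamma once per reached milestone.
        double rate = baseLearningRate * Math.Pow(gamma, milestonesReached);

        // Never return a rate below the floor.
        return Math.Max(rate, minLearningRate);
    }
}
```

Under these assumptions, a base rate of 0.1 with milestones 30, 80, and 120 and gamma 0.1 stays at 0.1 for steps 0 to 29, drops to roughly 0.01 for steps 30 to 79, to roughly 0.001 for steps 80 to 119, and to roughly 0.0001 from step 120 onward.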

Constructors

MultiStepLRScheduler(double, int[], double, double)

Initializes a new instance of the MultiStepLRScheduler class.

public MultiStepLRScheduler(double baseLearningRate, int[] milestones, double gamma = 0.1, double minLearningRate = 0)

Parameters

baseLearningRate double

The initial learning rate.

milestones int[]

Step indices at which to decay the learning rate. Must be non-empty and in increasing order.

gamma double

Multiplicative factor of learning rate decay. Default: 0.1

minLearningRate double

Minimum learning rate floor. Default: 0

Exceptions

ArgumentException

Thrown when milestones is empty or not in increasing order.
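The check behind this exception might look like the sketch below. It is an illustration only: MilestoneValidationSketch and Validate are hypothetical names, and two details are assumptions rather than documented behavior, namely that a null array is treated like an empty one and that equal neighbouring milestones count as "not in increasing order".

```csharp
using System;

public static class MilestoneValidationSketch
{
    // Hypothetical helper, not AiDotNet's code.
    public static void Validate(int[] milestones)
    {
        // Assumption: null is rejected the same way as an empty array.
        if (milestones == null || milestones.Length == 0)
        {
            throw new ArgumentException(
                "Milestones must not be empty.", nameof(milestones));
        }

        for (int i = 1; i < milestones.Length; i++)
        {
            // Assumption: duplicates are rejected (strictly increasing order).
            if (milestones[i] <= milestones[i - 1])
            {
                throw new ArgumentException(
                    "Milestones must be in increasing order.", nameof(milestones));
            }
        }
    }
}
```

With this sketch, { 30, 80, 120 } passes, while an empty array or { 80, 30 } throws ArgumentException.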

Properties

Gamma

Gets the multiplicative factor of learning rate decay.

public double Gamma { get; }

Property Value

double

Milestones

Gets the step indices at which the learning rate is decayed.

public IReadOnlyList<int> Milestones { get; }

Property Value

IReadOnlyList<int>

Methods

ComputeLearningRate(int)

Computes the learning rate for a given step.

protected override double ComputeLearningRate(int step)

Parameters

step int

The step number for which to compute the learning rate.

Returns

double

The computed learning rate.

GetState()

Gets the scheduler state for serialization/checkpointing.

public override Dictionary<string, object> GetState()

Returns

Dictionary<string, object>

A dictionary containing the scheduler state.
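A dictionary of this shape can be written to a checkpoint file as JSON. The sketch below uses System.Text.Json on a stand-in dictionary, because the keys GetState() actually returns are not listed here; the keys "current_step", "gamma", and "milestones" and the class SchedulerCheckpointSketch are hypothetical.

```csharp
using System.Collections.Generic;
using System.Text.Json;

public static class SchedulerCheckpointSketch
{
    // Stand-in for the dictionary returned by GetState().
    // The keys and values are hypothetical, chosen only for illustration.
    public static Dictionary<string, object> ExampleState()
    {
        return new Dictionary<string, object>
        {
            ["current_step"] = 45,
            ["gamma"] = 0.1,
            ["milestones"] = new[] { 30, 80, 120 }
        };
    }

    // Serializes the state dictionary to a JSON string suitable for a checkpoint file.
    public static string ToJson(Dictionary<string, object> state)
    {
        return JsonSerializer.Serialize(state);
    }
}
```

In real use you would pass the result of scheduler.GetState() to ToJson instead of the stand-in. Reading the JSON back yields JsonElement values rather than the original types, so a restore step would need to convert them.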