
Class CosineAnnealingWarmRestartsScheduler

Namespace: AiDotNet.LearningRateSchedulers
Assembly: AiDotNet.dll

Sets the learning rate using cosine annealing with warm restarts.

public class CosineAnnealingWarmRestartsScheduler : LearningRateSchedulerBase, ILearningRateScheduler
Inheritance
LearningRateSchedulerBase
CosineAnnealingWarmRestartsScheduler
Implements
ILearningRateScheduler

Examples

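// Warm restarts: the first cycle lasts 10 steps and each later cycle is twice as long;
// within each cycle the learning rate decays from 0.1 toward a floor of 0.001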
var scheduler = new CosineAnnealingWarmRestartsScheduler(
    baseLearningRate: 0.1,
    t0: 10,
    tMult: 2,
    etaMin: 0.001
);

Remarks

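This scheduler implements the SGDR (Stochastic Gradient Descent with Warm Restarts) schedule. Within each cycle, the learning rate follows a cosine curve from baseLearningRate down toward etaMin. At the end of the cycle it is reset to baseLearningRate (a "warm restart"), and the next cycle can optionally be made longer by multiplying its length by tMult.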
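For reference, the SGDR paper defines the learning rate at a given step as shown below, where T_cur counts the steps since the last restart and T_i is the length of the current cycle i (starting at i = 0). Mapping eta_max to baseLearningRate, eta_min to etaMin, T_0 to t0, and T_mult to tMult is an assumption based on the parameter names. At T_cur = 0 the cosine term equals 1, giving eta_max; as T_cur approaches T_i it approaches -1, giving eta_min.

$$
\eta_t = \eta_{\min} + \tfrac{1}{2}\left(\eta_{\max} - \eta_{\min}\right)\left(1 + \cos\!\left(\frac{T_{\mathrm{cur}}}{T_i}\,\pi\right)\right),
\qquad T_i = T_0 \cdot T_{\mathrm{mult}}^{\,i}
$$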

For Beginners: Imagine running a race as a series of sprints instead of one continuous run. Each sprint starts fast (high learning rate) and gradually slows to a jog (low learning rate); then you start the next sprint at full speed again. These "warm restarts" can help the model escape poor local minima and often lead to better solutions. The sprints can optionally get longer each time (controlled by tMult), leaving more steps for fine-tuning in later cycles.

Based on the paper "SGDR: Stochastic Gradient Descent with Warm Restarts" by Loshchilov & Hutter.

Constructors

CosineAnnealingWarmRestartsScheduler(double, int, int, double)

Initializes a new instance of the CosineAnnealingWarmRestartsScheduler class.

public CosineAnnealingWarmRestartsScheduler(double baseLearningRate, int t0, int tMult = 1, double etaMin = 0)

Parameters

baseLearningRate double

The initial (maximum) learning rate.

t0 int

Number of steps in the first cycle, that is, before the first restart. Must be positive.

tMult int

Factor by which the cycle length is multiplied after each restart. Must be at least 1. Default: 1 (every cycle has the same length).

etaMin double

Minimum learning rate, the lower bound that the cosine curve approaches at the end of each cycle. Default: 0

Exceptions

ArgumentException

Thrown when t0 is not positive or tMult is less than 1.
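As a worked example of how t0 and tMult combine, the following sketch lists the steps at which restarts occur. It assumes the first cycle covers steps 0 through t0 - 1 and that each later cycle is tMult times longer than the previous one; RestartSteps is a hypothetical helper, not part of AiDotNet. With t0 = 10 and tMult = 2 the cycle lengths are 10, 20, 40, and 80, so restarts fall at steps 10, 30, 70, and 150.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper: returns the first `count` restart steps, assuming the
// first cycle spans steps 0..t0-1 and each later cycle is tMult times longer.
List<int> RestartSteps(int t0, int tMult, int count)
{
    // Mirror the documented constructor validation.
    if (t0 <= 0) throw new ArgumentException("t0 must be positive.", nameof(t0));
    if (tMult < 1) throw new ArgumentException("tMult must be at least 1.", nameof(tMult));

    var restarts = new List<int>();
    int periodLength = t0; // length of the current cycle
    int boundary = 0;      // step at which the current cycle ends (next restart)
    for (int i = 0; i < count; i++)
    {
        boundary += periodLength;
        restarts.Add(boundary);
        periodLength *= tMult; // the next cycle is tMult times longer
    }
    return restarts;
}

Console.WriteLine(string.Join(", ", RestartSteps(10, 2, 4))); // 10, 30, 70, 150
Console.WriteLine(string.Join(", ", RestartSteps(10, 1, 4))); // 10, 20, 30, 40
```

With tMult = 1 the restarts are evenly spaced, which matches the "constant period" default.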

Properties

CurrentCycle

Gets the current cycle number.

public int CurrentCycle { get; }

Property Value

int

EtaMin

Gets the minimum learning rate.

public double EtaMin { get; }

Property Value

double

T0

Gets the length, in steps, of the first cycle.

public int T0 { get; }

Property Value

int

TMult

Gets the factor by which the cycle length is multiplied after each restart.

public int TMult { get; }

Property Value

int

Methods

ComputeLearningRate(int)

Computes the learning rate for a given step.

protected override double ComputeLearningRate(int step)

Parameters

step int

The step number.

Returns

double

The computed learning rate.
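The following is a minimal, self-contained sketch of how a step-based SGDR rule can be computed; it is not the AiDotNet source. It assumes zero-based steps, that step t0 is the first restart, and that the cycle length is multiplied by tMult after every restart. SgdrLearningRate is a hypothetical name chosen to avoid confusion with the protected ComputeLearningRate method.

```csharp
using System;

// Hypothetical sketch of a step-based SGDR rule (not the AiDotNet source).
// Assumes zero-based steps, the first restart at step t0, and cycle lengths
// t0, t0*tMult, t0*tMult^2, ...
double SgdrLearningRate(int step, double baseLearningRate, int t0, int tMult, double etaMin)
{
    if (step < 0) throw new ArgumentOutOfRangeException(nameof(step));
    if (t0 <= 0 || tMult < 1) throw new ArgumentException("t0 must be positive and tMult at least 1.");

    // Walk through completed cycles to find T_cur (steps since the last
    // restart) and T_i (length of the cycle containing this step).
    int periodLength = t0;
    int stepInCycle = step;
    while (stepInCycle >= periodLength)
    {
        stepInCycle -= periodLength;
        periodLength *= tMult;
    }

    // Cosine decay from baseLearningRate (start of cycle) toward etaMin (end of cycle).
    double progress = (double)stepInCycle / periodLength;
    return etaMin + 0.5 * (baseLearningRate - etaMin) * (1.0 + Math.Cos(Math.PI * progress));
}

foreach (int s in new[] { 0, 5, 9, 10, 20, 30 })
{
    Console.WriteLine($"step {s,2}: {SgdrLearningRate(s, 0.1, 10, 2, 0.001):F5}");
}
```

With baseLearningRate = 0.1, t0 = 10, tMult = 2, and etaMin = 0.001, steps 0, 10, and 30 return 0.1 (the start of each cycle), while steps 5 and 20, the midpoints of the first two cycles, return approximately 0.0505. Walking the cycles in a loop keeps the sketch simple; for very large step counts a closed-form cycle index could replace the loop.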

GetState()

Gets the scheduler state for serialization/checkpointing.

public override Dictionary<string, object> GetState()

Returns

Dictionary<string, object>

A dictionary containing the scheduler state.

LoadState(Dictionary<string, object>)

Loads the scheduler state from a checkpoint.

public override void LoadState(Dictionary<string, object> state)

Parameters

state Dictionary<string, object>

The state dictionary to load from.
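In the sketches on this page, the only mutable state is the step counter: the learning rate is a pure function of that counter and the fixed hyperparameters, so saving the counter is enough to resume the schedule. The sketch below illustrates the idea with a plain Dictionary<string, object>. The key names and the SaveState/RestoreStep helpers are hypothetical and do not reflect the keys that AiDotNet's GetState() actually writes.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical checkpoint helpers. Key names are illustrative assumptions,
// not the keys used by AiDotNet's GetState()/LoadState().
Dictionary<string, object> SaveState(int currentStep, int t0, int tMult, double etaMin) =>
    new Dictionary<string, object>
    {
        ["current_step"] = currentStep,
        ["t0"] = t0,
        ["t_mult"] = tMult,
        ["eta_min"] = etaMin,
    };

int RestoreStep(Dictionary<string, object> state, int t0, int tMult)
{
    // Refuse to resume a checkpoint written with a different cycle layout,
    // since the same step would then map to a different learning rate.
    if ((int)state["t0"] != t0 || (int)state["t_mult"] != tMult)
        throw new InvalidOperationException("Checkpoint was saved with different t0/tMult.");
    return (int)state["current_step"];
}

var checkpoint = SaveState(currentStep: 25, t0: 10, tMult: 2, etaMin: 0.001);
int resumedStep = RestoreStep(checkpoint, t0: 10, tMult: 2);
Console.WriteLine($"Resuming at step {resumedStep}");
```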

Reset()

Resets the scheduler to its initial state.

public override void Reset()

Step()

Advances the scheduler by one step and returns the new learning rate.

public override double Step()

Returns

double

The updated learning rate for the next step.
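The Step()/Reset() contract can be illustrated with a small stateful sketch built from a step counter and local functions. Assumptions: Step() advances the counter first and then returns the rate for the new step, and Reset() returns the counter to zero; AiDotNet's internal ordering may differ. RateAt, Step, and Reset here are hypothetical stand-ins, not the library's methods.

```csharp
using System;

// Hypothetical stateful stand-in for the documented Step()/Reset() surface.
// Assumption: Step() increments the counter, then returns the rate for it.
const double baseLearningRate = 0.1, etaMin = 0.001;
const int t0 = 10, tMult = 2;
int currentStep = 0;

// SGDR rate for a zero-based step (same rule as the cosine formula).
double RateAt(int step)
{
    int periodLength = t0, stepInCycle = step;
    while (stepInCycle >= periodLength) { stepInCycle -= periodLength; periodLength *= tMult; }
    return etaMin + 0.5 * (baseLearningRate - etaMin) * (1.0 + Math.Cos(Math.PI * stepInCycle / periodLength));
}

double Step() => RateAt(++currentStep);
void Reset() => currentStep = 0;

for (int i = 0; i < 12; i++)
{
    double lr = Step();
    Console.WriteLine($"step {currentStep,2}: lr = {lr:F5}");
}

Reset();
Console.WriteLine($"after Reset: step {currentStep}, next lr = {Step():F5}");
```

In the printed sequence, the rate decreases over steps 1 through 9 and jumps back to 0.1 at step 10, the first warm restart under these assumptions.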