Class CosineAnnealingWarmRestartsScheduler
- Namespace
- AiDotNet.LearningRateSchedulers
- Assembly
- AiDotNet.dll
Sets the learning rate using cosine annealing with warm restarts.
public class CosineAnnealingWarmRestartsScheduler : LearningRateSchedulerBase, ILearningRateScheduler
- Inheritance
- LearningRateSchedulerBase → CosineAnnealingWarmRestartsScheduler
- Implements
- ILearningRateScheduler
Examples
// Warm restarts with an initial period of 10 steps, doubling each cycle
var scheduler = new CosineAnnealingWarmRestartsScheduler(
    baseLearningRate: 0.1,
    t0: 10,
    tMult: 2,
    etaMin: 0.001
);
Remarks
This scheduler implements the SGDR (Stochastic Gradient Descent with Warm Restarts) algorithm. Within each cycle the learning rate follows a cosine decay from the base learning rate toward etaMin. At the end of the cycle it is reset to the base learning rate, and the next cycle can optionally be longer (controlled by tMult).
For Beginners: Imagine running a race in sprints instead of one continuous run. After each sprint (cycle), you rest (the learning rate restarts) and then sprint again. This "warm restart" approach can help the model escape local minima and often leads to better solutions. The sprints can optionally get longer each time (controlled by tMult), which leaves more room for fine-tuning in later cycles.
Based on the paper "SGDR: Stochastic Gradient Descent with Warm Restarts" by Loshchilov & Hutter (ICLR 2017).
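For reference, the schedule defined in the SGDR paper can be written as below. The mapping to this class is an assumption rather than something this page states: \eta_{\max} is taken to be baseLearningRate, \eta_{\min} to be etaMin, T_0 to be t0, and T_{\text{mult}} to be tMult, with T_{\text{cur}} counting steps since the last restart and T_i the length of cycle i.

```latex
\eta_t = \eta_{\min} + \frac{1}{2}\left(\eta_{\max} - \eta_{\min}\right)
         \left(1 + \cos\!\left(\frac{T_{\text{cur}}}{T_i}\,\pi\right)\right),
\qquad
T_i = T_0 \cdot T_{\text{mult}}^{\,i}, \quad i = 0, 1, 2, \dots
```

At T_{\text{cur}} = 0 the cosine term equals 1 and the rate is \eta_{\max}; as T_{\text{cur}} approaches T_i the rate approaches \eta_{\min}, after which T_{\text{cur}} returns to 0 and the next cycle begins.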
Constructors
CosineAnnealingWarmRestartsScheduler(double, int, int, double)
Initializes a new instance of the CosineAnnealingWarmRestartsScheduler class.
public CosineAnnealingWarmRestartsScheduler(double baseLearningRate, int t0, int tMult = 1, double etaMin = 0)
Parameters
baseLearningRate (double): The initial (maximum) learning rate.
t0 (int): Number of steps in the first cycle, that is, before the first restart.
tMult (int): Factor by which the cycle length is multiplied after each restart. Default: 1 (constant period).
etaMin (double): Minimum learning rate. Default: 0.
Exceptions
- ArgumentException
Thrown when t0 is not positive or tMult is less than 1.
Properties
CurrentCycle
Gets the current cycle number.
public int CurrentCycle { get; }
Property Value
- int
EtaMin
Gets the minimum learning rate.
public double EtaMin { get; }
Property Value
- double
T0
Gets the initial period.
public int T0 { get; }
Property Value
- int
TMult
Gets the period multiplier.
public int TMult { get; }
Property Value
- int
Methods
ComputeLearningRate(int)
Computes the learning rate for a given step.
protected override double ComputeLearningRate(int step)
Parameters
step (int): The step number.
Returns
- double
The computed learning rate.
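This page does not show how ComputeLearningRate is implemented. As a rough illustration only, the following standalone sketch evaluates the SGDR formula for a zero-based step. The class SgdrSketch and the method LearningRateAt are hypothetical names introduced here and are not part of AiDotNet; the library's actual behavior (for example, its step indexing at cycle boundaries) may differ.

```csharp
using System;

// Hypothetical helper for illustration; not part of AiDotNet.
public static class SgdrSketch
{
    // Evaluates the SGDR cosine schedule at a zero-based step (assumed indexing).
    public static double LearningRateAt(
        int step, double baseLearningRate, int t0, int tMult = 1, double etaMin = 0.0)
    {
        if (step < 0) throw new ArgumentOutOfRangeException(nameof(step));
        if (t0 <= 0) throw new ArgumentException("t0 must be positive.", nameof(t0));
        if (tMult < 1) throw new ArgumentException("tMult must be at least 1.", nameof(tMult));

        long cycleLength = t0;   // T_i: length of the current cycle
        long stepInCycle = step; // T_cur: steps since the last restart

        // Walk past completed cycles; each one is tMult times longer than the last.
        // long is used so the multiplication cannot overflow for int inputs.
        while (stepInCycle >= cycleLength)
        {
            stepInCycle -= cycleLength;
            cycleLength *= tMult;
        }

        // Cosine goes from 1 (cycle start) toward -1 (cycle end).
        double cosine = Math.Cos(Math.PI * stepInCycle / cycleLength);
        return etaMin + 0.5 * (baseLearningRate - etaMin) * (1.0 + cosine);
    }
}
```

Under the assumptions of this sketch, t0: 10 and tMult: 2 give cycle lengths of 10, 20, and 40 steps, so the rate returns to baseLearningRate at steps 10, 30, and 70.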
GetState()
Gets the scheduler state for serialization/checkpointing.
public override Dictionary<string, object> GetState()
Returns
- Dictionary<string, object>
A dictionary containing the scheduler state.
LoadState(Dictionary<string, object>)
Loads the scheduler state from a checkpoint.
public override void LoadState(Dictionary<string, object> state)
Parameters
state (Dictionary<string, object>): The state dictionary to load from.
Reset()
Resets the scheduler to its initial state.
public override void Reset()
Step()
Advances the scheduler by one step and returns the new learning rate.
public override double Step()
Returns
- double
The updated learning rate for the next step.