Class CosineAnnealingLRScheduler
- Namespace
- AiDotNet.LearningRateSchedulers
- Assembly
- AiDotNet.dll
Sets the learning rate using a cosine annealing schedule.
public class CosineAnnealingLRScheduler : LearningRateSchedulerBase, ILearningRateScheduler
- Inheritance
- LearningRateSchedulerBase
CosineAnnealingLRScheduler
- Implements
- ILearningRateScheduler
Examples
// Cosine annealing over 100 epochs
var scheduler = new CosineAnnealingLRScheduler(
    baseLearningRate: 0.1,
    tMax: 100,
    etaMin: 0.001
);
Remarks
CosineAnnealingLRScheduler uses a cosine function to decrease the learning rate smoothly from the base value (baseLearningRate) to a minimum value (etaMin) over tMax steps. Cosine annealing is widely used in modern deep learning and in many settings performs better than step-based decay schedules.
For Beginners: Instead of making sudden drops in learning rate, cosine annealing provides a smooth, curved decrease that follows the shape of a cosine wave. The learning rate starts high, decreases slowly at first, then more rapidly in the middle, and finally slows down again as it approaches the minimum. This smooth transition often leads to better model performance than abrupt changes.
Formula: lr = min_lr + 0.5 * (base_lr - min_lr) * (1 + cos(π * step / T_max))
Here base_lr corresponds to baseLearningRate, min_lr to etaMin, and T_max to tMax.
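To make the formula concrete, the sketch below evaluates it directly for the settings used in the Examples section. CosineAnnealing is a hypothetical standalone helper written for illustration; it is not part of AiDotNet and only mirrors the formula as documented.

```csharp
using System;

// Hypothetical helper (not an AiDotNet API): evaluates the documented formula as written.
static double CosineAnnealing(double baseLr, double minLr, int tMax, int step)
{
    return minLr + 0.5 * (baseLr - minLr) * (1 + Math.Cos(Math.PI * step / tMax));
}

// Same settings as the example: base 0.1, minimum 0.001, 100 steps.
foreach (int step in new[] { 0, 25, 50, 75, 100 })
{
    double lr = CosineAnnealing(0.1, 0.001, 100, step);
    Console.WriteLine($"step {step,3}: lr = {lr:F4}");
}
```

With these settings the rate starts at 0.1 at step 0, passes roughly 0.0505 at the midpoint (step 50), and reaches 0.001 at step 100.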
Constructors
CosineAnnealingLRScheduler(double, int, double)
Initializes a new instance of the CosineAnnealingLRScheduler class.
public CosineAnnealingLRScheduler(double baseLearningRate, int tMax, double etaMin = 0)
Parameters
- baseLearningRate double
The initial (maximum) learning rate.
- tMax int
Maximum number of steps (typically total epochs or iterations).
- etaMin double
Minimum learning rate. Default: 0.
Exceptions
- ArgumentException
Thrown when tMax is not positive.
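The constructor contract above can be sketched with a small stand-in class. SketchCosineScheduler is hypothetical and is not the AiDotNet implementation; it reproduces only the documented parameters and the documented ArgumentException for a non-positive tMax. The exception message text is an assumption.

```csharp
using System;

try
{
    // tMax: 0 violates the documented precondition, so construction should fail.
    _ = new SketchCosineScheduler(baseLearningRate: 0.1, tMax: 0);
}
catch (ArgumentException ex)
{
    Console.WriteLine($"Rejected: {ex.Message}");
}

// Hypothetical stand-in, NOT the library class: mirrors only the documented constructor contract.
sealed class SketchCosineScheduler
{
    public double BaseLearningRate { get; }
    public int TMax { get; }
    public double EtaMin { get; }

    public SketchCosineScheduler(double baseLearningRate, int tMax, double etaMin = 0)
    {
        // Documented behavior: ArgumentException when tMax is not positive.
        if (tMax <= 0)
            throw new ArgumentException("tMax must be positive.", nameof(tMax));

        BaseLearningRate = baseLearningRate;
        TMax = tMax;
        EtaMin = etaMin;
    }
}
```

Validating tMax up front matters because the formula divides by it; a zero value would otherwise surface later as a NaN or infinite learning rate rather than a clear error.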
Properties
EtaMin
Gets the minimum learning rate.
public double EtaMin { get; }
Property Value
- double
TMax
Gets the maximum number of steps.
public int TMax { get; }
Property Value
- int
Methods
ComputeLearningRate(int)
Computes the learning rate for a given step.
protected override double ComputeLearningRate(int step)
Parameters
- step int
The step number.
Returns
- double
The computed learning rate.
GetState()
Gets the scheduler state for serialization/checkpointing.
public override Dictionary<string, object> GetState()
Returns
- Dictionary<string, object>
A dictionary containing the scheduler state.
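A Dictionary<string, object> like the one returned here can be written to a checkpoint as JSON. The sketch below uses System.Text.Json and builds the dictionary by hand; the keys (currentStep, baseLearningRate, tMax, etaMin) are assumptions for illustration, since the actual keys produced by GetState() are not listed in this section.

```csharp
using System;
using System.Collections.Generic;
using System.Text.Json;

// Hypothetical state contents; the real keys returned by GetState() are not documented here.
var state = new Dictionary<string, object>
{
    ["currentStep"] = 42,
    ["baseLearningRate"] = 0.1,
    ["tMax"] = 100,
    ["etaMin"] = 0.001,
};

// Object-typed values are serialized according to their runtime types.
string json = JsonSerializer.Serialize(state);

// On the way back, values arrive as JsonElement and need an explicit typed read.
var restored = JsonSerializer.Deserialize<Dictionary<string, JsonElement>>(json)
    ?? throw new InvalidOperationException("Checkpoint JSON was empty.");

int step = restored["currentStep"].GetInt32();
double etaMin = restored["etaMin"].GetDouble();

Console.WriteLine(json);
Console.WriteLine($"step={step}, etaMin={etaMin}");
```

Because the values are typed as object, a round trip through JSON does not preserve the original .NET types automatically, which is why the restore side reads each entry with GetInt32 or GetDouble.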