Class ReduceOnPlateauScheduler
- Namespace
- AiDotNet.LearningRateSchedulers
- Assembly
- AiDotNet.dll
Reduces learning rate when a metric has stopped improving.
public class ReduceOnPlateauScheduler : LearningRateSchedulerBase, ILearningRateScheduler
- Inheritance
- LearningRateSchedulerBase → ReduceOnPlateauScheduler
- Implements
- ILearningRateScheduler
Examples
var scheduler = new ReduceOnPlateauScheduler(
    baseLearningRate: 0.1,
    factor: 0.1,
    patience: 10,
    mode: ReduceOnPlateauScheduler.Mode.Min
);

for (int epoch = 0; epoch < 100; epoch++)
{
    Train(model, scheduler.CurrentLearningRate);
    double valLoss = Validate(model);
    scheduler.Step(valLoss); // Scheduler decides whether to reduce LR
}
Remarks
ReduceOnPlateau monitors a quantity (usually validation loss) and reduces the learning rate when no improvement is seen for a 'patience' number of evaluations. This is a reactive scheduler that adapts based on training progress rather than a fixed schedule.
For Beginners: Unlike other schedulers that follow a fixed schedule, this one watches your model's performance and only reduces the learning rate when training gets "stuck" (plateaus). If the model keeps improving, it keeps the learning rate the same. If improvement stops for a while (patience epochs), it reduces the learning rate to allow finer adjustments. Think of it like slowing down only when you notice you're not making progress.
This scheduler requires you to call the Step(metric) overload with the monitored value.
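As an illustration of the plateau test, the sketch below shows how improvement can be judged for Mode.Min. The exact internal comparison is an assumption here, modeled on the usual ReduceOnPlateau convention for relative versus absolute thresholds.
// Illustrative only: assumed improvement test for Mode.Min, mirroring the
// common ReduceOnPlateau convention; not necessarily the exact internal logic.
static bool IsImprovement(double metric, double best, double threshold, bool relativeThreshold)
{
    // Relative: metric must beat the best value by a fraction 'threshold'.
    // Absolute: metric must beat the best value by at least 'threshold'.
    return relativeThreshold
        ? metric < best * (1 - threshold)
        : metric < best - threshold;
}
// After 'patience' consecutive non-improving evaluations, the scheduler multiplies
// the learning rate by 'factor', never going below 'minLearningRate'.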
Constructors
ReduceOnPlateauScheduler(double, double, int, double, ThresholdMode, int, Mode, double)
Initializes a new instance of the ReduceOnPlateauScheduler class.
public ReduceOnPlateauScheduler(double baseLearningRate, double factor = 0.1, int patience = 10, double threshold = 0.0001, ReduceOnPlateauScheduler.ThresholdMode thresholdMode = ThresholdMode.Relative, int cooldown = 0, ReduceOnPlateauScheduler.Mode mode = Mode.Min, double minLearningRate = 0)
Parameters
- baseLearningRate (double): The initial learning rate.
- factor (double): Factor by which the learning rate is reduced. Default: 0.1
- patience (int): Number of epochs with no improvement after which LR is reduced. Default: 10
- threshold (double): Threshold for measuring improvement. Default: 1e-4
- thresholdMode (ReduceOnPlateauScheduler.ThresholdMode): How to compare with threshold. Default: Relative
- cooldown (int): Number of epochs to wait before resuming normal operation after LR reduction. Default: 0
- mode (ReduceOnPlateauScheduler.Mode): Optimization mode (min or max). Default: Min
- minLearningRate (double): Minimum learning rate floor. Default: 0
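As a usage sketch, a fully specified construction might look like the following; the values are illustrative, not recommendations. When a reduction fires, the learning rate is multiplied by factor and clamped so it never falls below minLearningRate.
// Illustrative parameter choices (not recommendations):
var scheduler = new ReduceOnPlateauScheduler(
    baseLearningRate: 0.01,
    factor: 0.5,            // halve the LR on each reduction
    patience: 5,            // wait 5 evaluations without improvement
    threshold: 1e-3,
    thresholdMode: ReduceOnPlateauScheduler.ThresholdMode.Relative,
    cooldown: 2,            // skip plateau checks for 2 evaluations after a reduction
    mode: ReduceOnPlateauScheduler.Mode.Min,
    minLearningRate: 1e-6   // LR never drops below this floor
);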
Properties
BestValue
Gets the best metric value seen so far.
public double BestValue { get; }
Property Value
- double
Factor
Gets the reduction factor.
public double Factor { get; }
Property Value
- double
NumBadEpochs
Gets the current number of bad epochs.
public int NumBadEpochs { get; }
Property Value
- int
Patience
Gets the patience value.
public int Patience { get; }
Property Value
- int
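These read-only properties are convenient for logging. A hypothetical diagnostic line inside the training loop from the example above might look like this:
// Hypothetical logging during the training loop shown in the Examples section:
Console.WriteLine(
    $"epoch {epoch}: lr={scheduler.CurrentLearningRate}, " +
    $"best={scheduler.BestValue}, " +
    $"bad epochs={scheduler.NumBadEpochs}/{scheduler.Patience}");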
Methods
ComputeLearningRate(int)
Computes the learning rate for a given step.
protected override double ComputeLearningRate(int step)
Parameters
- step (int): The step number.
Returns
- double
The computed learning rate.
GetState()
Gets the scheduler state for serialization/checkpointing.
public override Dictionary<string, object> GetState()
Returns
- Dictionary<string, object>
A dictionary containing the scheduler state.
LoadState(Dictionary<string, object>)
Loads the scheduler state from a checkpoint.
public override void LoadState(Dictionary<string, object> state)
Parameters
- state (Dictionary<string, object>): The state dictionary to load from.
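GetState and LoadState pair naturally for checkpointing. The sketch below keeps the state dictionary in memory; persisting it is left to whatever serialization the surrounding application already uses.
// Capture the scheduler state when saving a model checkpoint:
var schedulerState = scheduler.GetState();

// ... later, when resuming training with a fresh scheduler instance ...
var resumed = new ReduceOnPlateauScheduler(baseLearningRate: 0.1);
resumed.LoadState(schedulerState);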
Reset()
Resets the scheduler to its initial state.
public override void Reset()
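A typical use is reusing one scheduler instance across independent runs (for example, cross-validation folds). The exact fields cleared are an assumption based on the description above:
// Assumed behavior: Reset() restores the base learning rate and clears
// the tracked best value and bad-epoch counter before a new run.
scheduler.Reset();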
Step()
Advances the scheduler by one step and returns the new learning rate.
public override double Step()
Returns
- double
The updated learning rate for the next step.
Remarks
Note: For ReduceOnPlateau, the parameterless Step() does not evaluate the plateau logic and will not reduce the learning rate. Call Step(double metric) with the monitored value instead.
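In other words, the parameterless overload advances the step counter but leaves the learning rate as-is for this scheduler; the plateau logic lives in Step(double). Reusing valLoss from the example above:
// Parameterless Step(): advances the step count but performs no plateau check.
double unchangedLr = scheduler.Step();
// Step(metric): compares the metric against the best value seen and may reduce the LR.
double lr = scheduler.Step(valLoss);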
Step(double)
Steps the scheduler with a metric value.
public double Step(double metric)
Parameters
- metric (double): The monitored metric value (e.g., validation loss).
Returns
- double
The current learning rate, which may have been reduced if the metric plateaued.
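When the monitored quantity should increase (validation accuracy, for example), the "min or max" description suggests a Mode.Max member that treats larger values as improvement. The sketch below assumes that, along with a hypothetical ValidateAccuracy helper in the style of the Validate call in the example above.
// Hypothetical: monitor validation accuracy, treating larger values as improvement.
var accScheduler = new ReduceOnPlateauScheduler(
    baseLearningRate: 0.05,
    factor: 0.5,
    patience: 5,
    mode: ReduceOnPlateauScheduler.Mode.Max
);

for (int epoch = 0; epoch < 50; epoch++)
{
    Train(model, accScheduler.CurrentLearningRate);
    double valAccuracy = ValidateAccuracy(model);   // hypothetical helper
    accScheduler.Step(valAccuracy);                 // reduces LR when accuracy plateaus
}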