Class MixedPrecisionTrainingLoop<T>
- Namespace
- AiDotNet.MixedPrecision
- Assembly
- AiDotNet.dll
Implements a mixed-precision training loop for neural networks, following NVIDIA's approach.
public class MixedPrecisionTrainingLoop<T>
Type Parameters
T: The numeric type (must be float for mixed-precision training).
- Inheritance
- object → MixedPrecisionTrainingLoop<T>
Examples
// Create the training loop
var trainLoop = new MixedPrecisionTrainingLoop<float>(
    network,
    optimizer,
    lossFunction,
    mixedPrecisionContext);

// Train for one step
bool success = trainLoop.TrainStep(inputTensor, targetTensor);
if (!success)
{
    Console.WriteLine("Step skipped due to gradient overflow");
}
Remarks
For Beginners: This class implements the complete mixed-precision training workflow:
1. Cast weights to FP16: convert the FP32 master weights to FP16 working weights.
2. Forward pass in FP16: fast computation using 16-bit precision.
3. Compute loss in FP32: calculate the error in 32-bit precision for stability.
4. Scale loss: multiply by a large factor (e.g., 2^16) to prevent gradient underflow.
5. Backward pass in FP16: compute gradients in 16-bit precision.
6. Unscale and cast gradients to FP32: convert gradients back to 32 bits and divide by the scale.
7. Check for overflow: detect NaN/Inf and adjust the loss scale if needed.
8. Update master weights in FP32: apply gradients to the 32-bit master weights.
This workflow provides a 2-3x speedup on modern GPUs while maintaining model accuracy.
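The loss-scaling arithmetic described above can be shown in a self-contained sketch using .NET's Half (FP16) type. All names here are illustrative; this is not AiDotNet API:

```csharp
using System;

// A self-contained sketch of the loss-scaling arithmetic described above,
// using .NET's Half type for FP16. Names are illustrative, not AiDotNet API.
class LossScalingSketch
{
    static void Main()
    {
        double lossScale = 65536.0;          // 2^16, a common initial scale

        // A tiny FP32 gradient below Half's smallest subnormal (~5.96e-8):
        float tinyGrad = 1e-8f;
        Half withoutScaling = (Half)tinyGrad;             // underflows to 0

        // Scaling the loss scales every gradient by the same factor,
        // moving it into FP16's representable range:
        Half withScaling = (Half)(tinyGrad * lossScale);  // survives in FP16

        // Before the optimizer step, gradients are cast back to FP32 and
        // divided by the scale, recovering (approximately) the true value:
        float recovered = (float)withScaling / (float)lossScale;

        Console.WriteLine($"without scaling: {(float)withoutScaling}");
        Console.WriteLine($"recovered:       {recovered}");
    }
}
```

Without scaling, the gradient underflows to exactly zero in FP16; with scaling, the recovered FP32 value is within FP16 rounding error of the original 1e-8.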
Constructors
MixedPrecisionTrainingLoop(NeuralNetworkBase<T>, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>, ILossFunction<T>, MixedPrecisionContext)
Initializes a new mixed-precision training loop.
public MixedPrecisionTrainingLoop(NeuralNetworkBase<T> network, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>> optimizer, ILossFunction<T> lossFunction, MixedPrecisionContext context)
Parameters
network (NeuralNetworkBase<T>): The neural network to train.
optimizer (IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>): The optimizer to use for parameter updates.
lossFunction (ILossFunction<T>): The loss function to minimize.
context (MixedPrecisionContext): The mixed-precision training context.
Exceptions
- ArgumentException
Thrown when T is not float.
Properties
CurrentLossScale
Gets the current loss scale factor.
public double CurrentLossScale { get; }
Property Value
- double
LastLoss
Gets the last computed loss value.
public T? LastLoss { get; }
Property Value
- T
SkippedSteps
Gets the number of steps skipped due to gradient overflow.
public int SkippedSteps { get; }
Property Value
- int
TotalSteps
Gets the total number of training steps performed.
public int TotalSteps { get; }
Property Value
- int
Methods
GetStatistics()
Gets statistics about the training process.
public string GetStatistics()
Returns
- string
A string containing training statistics.
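The exact format of the returned string is implementation-defined; the same information can be derived from the counters the class exposes (TotalSteps, SkippedSteps, CurrentLossScale). FormatStats below is a hypothetical helper, not part of the AiDotNet API:

```csharp
using System;

// Hypothetical helper showing the kind of summary GetStatistics() can report,
// built from the counters the class already exposes. The real method's output
// format is implementation-defined.
class TrainingStatsSketch
{
    static string FormatStats(int totalSteps, int skippedSteps, double lossScale)
    {
        // Guard against division by zero before any steps have run.
        double skipRate = totalSteps == 0 ? 0.0 : (double)skippedSteps / totalSteps;
        return $"steps={totalSteps}, skipped={skippedSteps} ({skipRate * 100:F1}%), lossScale={lossScale}";
    }

    static void Main() => Console.WriteLine(FormatStats(1000, 12, 65536.0));
}
```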
TrainStep(Tensor<T>, Tensor<T>)
Performs one training step with mixed-precision.
public bool TrainStep(Tensor<T> input, Tensor<T> target)
Parameters
input (Tensor<T>): The input tensor.
target (Tensor<T>): The target tensor.
Returns
- bool
True if the step was successful; false if skipped due to gradient overflow.
Remarks
For Beginners: This method performs one complete training iteration: forward pass → backward pass → parameter update.
If gradient overflow is detected (gradients become NaN or infinite), the step is skipped and the loss scale is automatically reduced. This is normal and expected to happen occasionally.
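The skip-and-reduce behavior can be sketched as a dynamic loss-scale policy. This is illustrative, not AiDotNet internals; the growth interval of 2000 and the halve/double factors are common defaults in other mixed-precision frameworks, assumed here:

```csharp
using System;
using System.Linq;

// Illustrative sketch (not AiDotNet internals) of a dynamic loss-scale policy:
// on overflow, skip the step and halve the scale; after a long run of clean
// steps, double it again. GrowthInterval = 2000 is an assumed common default.
class DynamicLossScaleSketch
{
    public double Scale { get; private set; } = 65536.0;
    private int _goodSteps;
    private const int GrowthInterval = 2000;

    // Returns false when the step must be skipped, mirroring TrainStep.
    public bool Step(float[] unscaledGradients)
    {
        bool overflow = unscaledGradients.Any(g => float.IsNaN(g) || float.IsInfinity(g));
        if (overflow)
        {
            Scale = Math.Max(Scale / 2.0, 1.0);  // back off, never below 1
            _goodSteps = 0;
            return false;                        // skip the optimizer update
        }

        if (++_goodSteps >= GrowthInterval)
        {
            Scale *= 2.0;                        // probe a larger scale
            _goodSteps = 0;
        }
        // ...apply the unscaled gradients to the FP32 master weights here...
        return true;
    }

    static void Main()
    {
        var policy = new DynamicLossScaleSketch();
        bool ok = policy.Step(new[] { float.NaN, 0.1f });
        Console.WriteLine($"step ok: {ok}, scale now: {policy.Scale}");
    }
}
```

A NaN gradient causes the first step to be skipped and the scale to drop from 65536 to 32768; clean steps are applied normally.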