Class MixedPrecisionTrainingLoop<T>

Namespace: AiDotNet.MixedPrecision
Assembly: AiDotNet.dll

Implements a mixed-precision training loop for neural networks, following NVIDIA's approach.

public class MixedPrecisionTrainingLoop<T>

Type Parameters

T

The numeric type (must be float for mixed-precision training).

Inheritance
object
MixedPrecisionTrainingLoop<T>

Examples

// Create training loop
var trainLoop = new MixedPrecisionTrainingLoop<float>(
    network,
    optimizer,
    lossFunction,
    mixedPrecisionContext
);

// Train for one step
bool success = trainLoop.TrainStep(inputTensor, targetTensor);
if (!success)
{
    Console.WriteLine("Step skipped due to gradient overflow");
}

Remarks

For Beginners: This class implements the complete mixed-precision training workflow:

  1. Cast weights to FP16 - Convert FP32 master weights to FP16 working weights
  2. Forward pass in FP16 - Fast computation using 16-bit precision
  3. Compute loss in FP32 - Calculate error using 32-bit precision for stability
  4. Scale loss - Multiply the loss by a large factor (e.g., 2^16) to prevent gradient underflow
  5. Backward pass in FP16 - Compute gradients in 16-bit precision
  6. Unscale and cast gradients to FP32 - Convert gradients back to 32-bit and divide by the scale factor
  7. Check for overflow - Detect NaN/Inf and adjust loss scale if needed
  8. Update master weights in FP32 - Apply gradients to 32-bit master weights

This workflow provides 2-3x speedup on modern GPUs while maintaining model accuracy.
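The loss-scaling steps above (4 and 6) exist because tiny gradients underflow to zero in FP16. A self-contained sketch of that arithmetic using System.Half (plain .NET types, not this library's tensors):

```csharp
using System;

class LossScalingSketch
{
    static void Main()
    {
        // A small FP32 gradient below FP16's smallest representable value (~6e-8).
        float gradient = 1e-8f;

        // Without scaling, the FP16 copy underflows to zero and the update is lost.
        Half unscaled = (Half)gradient;
        Console.WriteLine($"Unscaled FP16: {unscaled}");   // 0

        // Step 4: scale by 2^16 before casting, lifting the value into FP16 range.
        float scale = 65536f; // 2^16
        Half scaled = (Half)(gradient * scale);
        Console.WriteLine($"Scaled FP16:   {scaled}");

        // Step 6: unscale in FP32 before applying the master-weight update.
        float recovered = (float)scaled / scale;
        Console.WriteLine($"Recovered:     {recovered}");  // ~1e-8
    }
}
```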

Constructors

MixedPrecisionTrainingLoop(NeuralNetworkBase<T>, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>, ILossFunction<T>, MixedPrecisionContext)

Initializes a new mixed-precision training loop.

public MixedPrecisionTrainingLoop(NeuralNetworkBase<T> network, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>> optimizer, ILossFunction<T> lossFunction, MixedPrecisionContext context)

Parameters

network NeuralNetworkBase<T>

The neural network to train.

optimizer IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>

The optimizer to use for parameter updates.

lossFunction ILossFunction<T>

The loss function to minimize.

context MixedPrecisionContext

The mixed-precision training context.

Exceptions

ArgumentException

Thrown when T is not float.

Properties

CurrentLossScale

Gets the current loss scale factor.

public double CurrentLossScale { get; }

Property Value

double

LastLoss

Gets the last computed loss value.

public T? LastLoss { get; }

Property Value

T

SkippedSteps

Gets the number of steps skipped due to gradient overflow.

public int SkippedSteps { get; }

Property Value

int

TotalSteps

Gets the total number of training steps performed.

public int TotalSteps { get; }

Property Value

int

Methods

GetStatistics()

Gets statistics about the training process.

public string GetStatistics()

Returns

string

A string containing training statistics.
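The exact format of the statistics string is not specified here; for custom monitoring, the documented SkippedSteps and TotalSteps counters can be combined directly. A minimal sketch with hypothetical counter values:

```csharp
using System;

class SkipRateSketch
{
    static void Main()
    {
        // Hypothetical values mirroring the TotalSteps and SkippedSteps properties.
        int totalSteps = 1000;
        int skippedSteps = 3;

        // A skip rate persistently above a few percent usually means the
        // loss scale is repeatedly overshooting for this model and data.
        double skipRate = (double)skippedSteps / totalSteps;
        Console.WriteLine($"Skipped {skippedSteps}/{totalSteps} steps ({skipRate:P1})");
    }
}
```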

TrainStep(Tensor<T>, Tensor<T>)

Performs one training step with mixed-precision.

public bool TrainStep(Tensor<T> input, Tensor<T> target)

Parameters

input Tensor<T>

Input tensor.

target Tensor<T>

Target tensor.

Returns

bool

True if the step was successful; false if skipped due to gradient overflow.

Remarks

For Beginners: This method performs one complete training iteration: forward pass → backward pass → parameter update.

If gradient overflow is detected (gradients become NaN or infinite), the step is skipped and the loss scale is automatically reduced. Occasional skipped steps are normal and expected.
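The automatic adjustment described above (back off on overflow, grow again after a run of clean steps) can be sketched as standalone logic. The 2000-step growth interval below is an assumption borrowed from common dynamic loss-scaling defaults, not necessarily what MixedPrecisionContext uses:

```csharp
using System;

class DynamicLossScaler
{
    double scale = 65536;            // start at 2^16
    int goodSteps = 0;
    const int GrowthInterval = 2000; // assumed default; the library's value may differ

    public double Scale => scale;

    // Returns true if the step should be applied, false if it should be skipped.
    public bool Update(bool overflowDetected)
    {
        if (overflowDetected)
        {
            scale /= 2;              // back off: scaled gradients exceeded FP16 range
            goodSteps = 0;
            return false;            // skip this step, mirroring TrainStep returning false
        }
        if (++goodSteps >= GrowthInterval)
        {
            scale *= 2;              // probe a higher scale for better small-gradient precision
            goodSteps = 0;
        }
        return true;
    }

    static void Main()
    {
        var scaler = new DynamicLossScaler();
        scaler.Update(overflowDetected: true);   // overflow halves the scale
        Console.WriteLine(scaler.Scale);
    }
}
```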