
Class MixedPrecisionContext

Namespace: AiDotNet.MixedPrecision
Assembly: AiDotNet.dll

Manages master weights (FP32) and working weights (FP16) for mixed-precision training.

public class MixedPrecisionContext : IDisposable
Inheritance
object → MixedPrecisionContext

Implements
IDisposable

Remarks

For Beginners: Mixed-precision training uses two copies of model parameters:

  1. Master Weights (FP32):

    • High-precision copy of all parameters
    • Used for parameter updates to maintain accuracy
    • Stored in memory but not used for forward/backward passes
  2. Working Weights (FP16):

    • Low-precision copy used for computation
    • Used in forward and backward passes
    • Faster and uses less memory
    • Synced from master weights before each forward pass

The workflow:

  1. Cast master weights (FP32) to working weights (FP16)
  2. Forward pass using FP16 weights → faster, less memory
  3. Backward pass in FP16 → computes FP16 gradients
  4. Cast gradients to FP32 and unscale
  5. Update master weights in FP32 → maintains precision
  6. Repeat from step 1

This approach combines the speed of FP16 with the numerical stability of FP32.
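
A minimal end-to-end sketch of this workflow follows. It assumes a hypothetical model object with a GetParameters() helper and a hypothetical ComputeFP16Gradients(...) function for the FP16 forward/backward passes; those names are illustrative, not part of this API.

using var context = new MixedPrecisionContext();
context.Initialize(model.GetParameters(), "params");            // register FP32 master weights

for (int step = 0; step < totalSteps; step++)
{
    // Step 1: sync FP16 working weights from the FP32 masters.
    context.CastWeightsToFP16();
    Vector<Half> workingWeights = context.GetWorkingWeights("params");

    // Steps 2-3: forward and backward passes in FP16 (your training code),
    // with the loss scaled via context.LossScaler before backpropagation.
    Vector<Half> scaledGradients = ComputeFP16Gradients(model, workingWeights, batch);

    // Step 4: cast the gradients to FP32, unscale, and check for overflow.
    if (context.PrepareGradientsForUpdate(scaledGradients, out Vector<float> gradients))
    {
        // Step 5: apply the update to the FP32 master weights.
        context.UpdateMasterWeights(gradients, learningRate: 0.01f, "params");
    }
    // If overflow was detected, the update is skipped and the loss scale is adjusted.
}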

Technical Details: The context maintains:

  • A dictionary mapping parameter names to FP32 master copies
  • A dictionary mapping parameter names to FP16 working copies
  • Synchronization methods to cast between precisions
  • Integration with LossScaler for gradient management

Constructors

MixedPrecisionContext(MixedPrecisionConfig?)

Initializes a new mixed-precision training context.

public MixedPrecisionContext(MixedPrecisionConfig? config = null)

Parameters

config MixedPrecisionConfig

Configuration for mixed-precision training (optional, uses defaults if null).

Remarks

For Beginners: Create one context per neural network model. The context will manage all the parameter conversions automatically.
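
For example, create a context with default settings, or with an explicit configuration (assuming MixedPrecisionConfig exposes a parameterless constructor):

// Defaults: passing nothing leaves config null, so defaults apply.
var context = new MixedPrecisionContext();

// Or supply a configuration object explicitly.
var configured = new MixedPrecisionContext(new MixedPrecisionConfig());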

Properties

Config

Configuration for mixed-precision training.

public MixedPrecisionConfig Config { get; }

Property Value

MixedPrecisionConfig

IsInitialized

Whether the context has been initialized with parameters.

public bool IsInitialized { get; }

Property Value

bool

LossScaler

Loss scaler for gradient scaling and overflow detection.

public LossScaler<float> LossScaler { get; }

Property Value

LossScaler<float>

ParameterCount

Number of parameters managed by this context.

public int ParameterCount { get; }

Property Value

int

ParameterNames

Gets the names of all parameters being managed.

public IReadOnlyCollection<string> ParameterNames { get; }

Property Value

IReadOnlyCollection<string>

Methods

CastWeightsToFP16()

Converts master weights (FP32) to working weights (FP16) for the forward pass.

public void CastWeightsToFP16()

Remarks

For Beginners: Call this before each forward pass to sync the FP16 working weights from the FP32 master weights. This ensures the working weights reflect the latest parameter updates.
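
A typical call site (a sketch; "params" is the default group name):

// Refresh the FP16 working weights from the FP32 masters before the forward pass.
context.CastWeightsToFP16();
Vector<Half> fp16Weights = context.GetWorkingWeights("params");
// ... run the forward pass with fp16Weights ...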

Dispose()

Disposes of the context and releases resources.

public void Dispose()

GetMasterWeights(string)

Gets the master weights (FP32) for a parameter group.

public Vector<float> GetMasterWeights(string parameterName = "params")

Parameters

parameterName string

Name of the parameter group.

Returns

Vector<float>

The master weights in FP32.

GetWorkingWeights(string)

Gets the working weights (FP16) for a parameter group.

public Vector<Half> GetWorkingWeights(string parameterName = "params")

Parameters

parameterName string

Name of the parameter group.

Returns

Vector<Half>

The working weights in FP16.
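
For example, reading back both copies of the default parameter group:

Vector<float> master = context.GetMasterWeights();    // FP32 copy, updated during the optimizer step
Vector<Half> working = context.GetWorkingWeights();    // FP16 copy, used in forward/backward passes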

Initialize(Vector<float>, string)

Initializes the context with model parameters.

public void Initialize(Vector<float> parameters, string parameterName = "params")

Parameters

parameters Vector<float>

The model parameters in FP32.

parameterName string

Optional name for the parameters (default: "params").

Remarks

For Beginners: Call this once after creating your model to register the parameters with the mixed-precision context. The parameters you pass should be in FP32.
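
For example (model.GetParameters() is a placeholder for however your model exposes its FP32 parameters as a Vector<float>):

var context = new MixedPrecisionContext();
context.Initialize(model.GetParameters());   // registers under the default group name "params"
// context.IsInitialized is now true, and context.ParameterCount reflects the registered parameters.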

Initialize(Dictionary<string, Vector<float>>)

Initializes the context with multiple named parameter groups.

public void Initialize(Dictionary<string, Vector<float>> namedParameters)

Parameters

namedParameters Dictionary<string, Vector<float>>

Dictionary mapping parameter names to parameter vectors.

Remarks

For Beginners: Use this when you want to manage multiple parameter groups separately, for example, different layers or different types of parameters (weights vs. biases).
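
For example, registering weights and biases as separate groups (the vectors are placeholders for your own data):

context.Initialize(new Dictionary<string, Vector<float>>
{
    ["encoder.weights"] = encoderWeights,
    ["encoder.biases"] = encoderBiases
});
// Each group can then be fetched and updated by name:
Vector<Half> encoderWorking = context.GetWorkingWeights("encoder.weights");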

PrepareGradientsForUpdate(Vector<Half>, out Vector<float>)

Converts FP16 gradients to FP32, unscales them, and checks for overflow.

public bool PrepareGradientsForUpdate(Vector<Half> gradientsHalf, out Vector<float> gradientsFloat)

Parameters

gradientsHalf Vector<Half>

The scaled gradients in FP16.

gradientsFloat Vector<float>

Output: unscaled gradients in FP32 (if no overflow).

Returns

bool

True if gradients are valid and update should proceed; false if overflow detected.

Remarks

For Beginners: This is a key method in the mixed-precision training loop. It performs three steps:

  1. Converts FP16 gradients to FP32 (lossless)
  2. Unscales the gradients (divides by the loss scale)
  3. Checks for NaN/infinity and adjusts the loss scale if needed

If this returns false, you should skip the parameter update for this iteration.
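
A sketch of the update step built around this check (gradientsHalf is assumed to come from your FP16 backward pass):

if (context.PrepareGradientsForUpdate(gradientsHalf, out Vector<float> gradients))
{
    context.UpdateMasterWeights(gradients, learningRate: 0.001f);
}
else
{
    // Overflow detected: skip this update; the LossScaler has already adjusted the scale.
}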

Reset()

Resets the context, clearing all weights and statistics.

public void Reset()

ToString()

Gets a summary of the context's current state.

public override string ToString()

Returns

string

A string describing the current state.

UpdateMasterWeights(Vector<float>, float, string)

Updates master weights with FP32 gradients after unscaling.

public void UpdateMasterWeights(Vector<float> gradients, float learningRate, string parameterName = "params")

Parameters

gradients Vector<float>

The unscaled gradients in FP32.

learningRate float

Learning rate for the update.

parameterName string

Name of the parameter group to update.

Remarks

For Beginners: This applies the gradient descent update to the FP32 master weights. The formula is: weights = weights - learningRate * gradients
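
For example (gradients is assumed to hold the unscaled FP32 gradients returned by PrepareGradientsForUpdate):

context.UpdateMasterWeights(gradients, learningRate: 0.01f);
// The next CastWeightsToFP16() call propagates this update to the FP16 working weights.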