
Class FSDPOptimizer<T, TInput, TOutput>

Namespace
AiDotNet.DistributedTraining
Assembly
AiDotNet.dll

Implements an FSDP (Fully Sharded Data Parallel) optimizer wrapper that coordinates optimization across multiple processes.

public class FSDPOptimizer<T, TInput, TOutput> : ShardedOptimizerBase<T, TInput, TOutput>, IShardedOptimizer<T, TInput, TOutput>, IOptimizer<T, TInput, TOutput>, IModelSerializer

Type Parameters

T

The numeric type (for example, double)

TInput

The input type for the model

TOutput

The output type for the model

Inheritance
ShardedOptimizerBase<T, TInput, TOutput>
FSDPOptimizer<T, TInput, TOutput>
Implements
IShardedOptimizer<T, TInput, TOutput>
IOptimizer<T, TInput, TOutput>
IModelSerializer

Remarks

Strategy Overview: The FSDP optimizer works together with FSDPModel to fully shard optimizer state. Momentum buffers, variance estimates, and any other optimizer-specific state are split across processes, so each process stores only its share. This minimizes memory usage while preserving the training behavior of the wrapped optimizer.
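The following sketch shows what sharding optimizer state means in practice. It assumes a simple contiguous partition of a flat parameter vector across ranks. `StateSharding`, `ShardRange`, and `AdamStateShard` are hypothetical names used for illustration; they are not AiDotNet APIs, and the library's actual partitioning scheme may differ.

```csharp
using System;

// Hypothetical helper: splits a flat parameter vector of `totalParams`
// elements into `worldSize` contiguous shards, one per rank.
public static class StateSharding
{
    // Returns the [Start, End) range of parameter indices owned by `rank`.
    // The first (totalParams % worldSize) ranks receive one extra element.
    public static (int Start, int End) ShardRange(int totalParams, int worldSize, int rank)
    {
        if (worldSize <= 0) throw new ArgumentOutOfRangeException(nameof(worldSize));
        if (rank < 0 || rank >= worldSize) throw new ArgumentOutOfRangeException(nameof(rank));

        int baseSize = totalParams / worldSize;
        int remainder = totalParams % worldSize;
        int start = rank * baseSize + Math.Min(rank, remainder);
        int size = baseSize + (rank < remainder ? 1 : 0);
        return (start, start + size);
    }
}

// Hypothetical per-rank Adam state: moments are allocated only for the owned shard.
public sealed class AdamStateShard
{
    public int Start { get; }
    public double[] FirstMoment { get; }   // m: running mean of gradients
    public double[] SecondMoment { get; }  // v: running mean of squared gradients

    public AdamStateShard(int totalParams, int worldSize, int rank)
    {
        var (start, end) = StateSharding.ShardRange(totalParams, worldSize, rank);
        Start = start;
        FirstMoment = new double[end - start];
        SecondMoment = new double[end - start];
    }
}
```

With 10 parameters and 4 ranks, the ranks own indices [0, 3), [3, 6), [6, 8), and [8, 10). Each rank allocates moment buffers for only its own slice rather than for all 10 parameters.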

For Beginners: This class wraps any existing optimizer (such as Adam or SGD) and makes it work with the FSDP strategy across multiple GPUs or machines. It automatically handles:

  • Synchronizing gradients across all processes
  • Sharding optimizer states (momentum, variance) to save memory
  • Coordinating parameter updates
  • Ensuring all processes stay in sync
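You can picture gradient synchronization as an all-reduce that averages each rank's local gradients. The sketch below simulates this in a single process. `GradientSync.AllReduceAverage` is a hypothetical helper, not the communication backend's API; a real backend would exchange the data between processes instead of reading one shared array.

```csharp
using System;

public static class GradientSync
{
    // Simulates an all-reduce with averaging: localGradients[rank][i] is the
    // gradient of parameter i computed on that rank's slice of the batch.
    // After the call, every rank holds the same averaged gradient vector.
    public static void AllReduceAverage(double[][] localGradients)
    {
        if (localGradients == null || localGradients.Length == 0)
            throw new ArgumentException("At least one rank is required.", nameof(localGradients));

        int worldSize = localGradients.Length;
        int length = localGradients[0].Length;
        var sum = new double[length];

        // Reduce: sum element-wise across ranks.
        foreach (var grad in localGradients)
        {
            if (grad.Length != length)
                throw new ArgumentException("All ranks must have gradients of equal length.");
            for (int i = 0; i < length; i++) sum[i] += grad[i];
        }

        // Distribute: write the average back to every rank.
        foreach (var grad in localGradients)
            for (int i = 0; i < length; i++) grad[i] = sum[i] / worldSize;
    }
}
```

FSDP implementations commonly use a reduce-scatter instead, so that each rank receives only the averaged gradients for the shard it owns. Whether this class does so is not stated here.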

Think of it like a team of coaches working together. Every coach follows the same playbook (the wrapped optimizer), they share only the essential information, and each coach keeps only their own portion of the detailed notes (optimizer states) to save space.

Use Cases:

  • Training very large models with optimizers that have significant state (Adam, RMSprop)
  • Maximizing memory efficiency when using stateful optimizers
  • Scaling to hundreds or thousands of GPUs

Trade-offs:

  • Memory: Excellent; optimizer states are sharded across processes
  • Communication: Moderate; gradients are synchronized, with occasional state synchronization
  • Complexity: Moderate; state sharding is automatic
  • Best for: Large models with stateful optimizers (Adam, RMSprop, etc.)
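Here is the memory trade-off in numbers. Adam keeps two extra values per parameter (the first and second moments). With P parameters stored as 8-byte doubles, the full state is 2 × 8 × P bytes, and sharding it over N processes leaves roughly 1/N of that on each one. The hypothetical helper below computes the per-process figure for the largest shard. It counts optimizer state only, not parameters, gradients, or activations.

```csharp
using System;

public static class OptimizerMemory
{
    // Estimated bytes of optimizer state held by the most heavily loaded rank.
    // statesPerParam: e.g. 2 for Adam (first and second moments).
    // bytesPerValue: e.g. 8 for double, 4 for float.
    public static long StateBytesPerRank(long paramCount, int statesPerParam, int bytesPerValue, int worldSize)
    {
        if (worldSize <= 0) throw new ArgumentOutOfRangeException(nameof(worldSize));

        // Ceiling division: the largest shard determines peak per-rank memory.
        long largestShard = (paramCount + worldSize - 1) / worldSize;
        return largestShard * statesPerParam * bytesPerValue;
    }
}
```

For example, 1 billion parameters with Adam in double precision need 16,000,000,000 bytes (16 GB) of state on a single process. Across 8 processes, that drops to 2,000,000,000 bytes (2 GB) per process.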

Example:

// Original optimizer
var optimizer = new AdamOptimizer<double, Tensor<double>, Tensor<double>>(model, options);

// Wrap it for FSDP distributed training
var backend = new InMemoryCommunicationBackend<double>(rank: 0, worldSize: 4);
var config = new ShardingConfiguration<double>(backend);
var fsdpOptimizer = new FSDPOptimizer<double, Tensor<double>, Tensor<double>>(optimizer, config);

// Now optimize as usual - FSDP magic happens automatically!
var result = fsdpOptimizer.Optimize(inputData);

Constructors

FSDPOptimizer(IOptimizer<T, TInput, TOutput>, IShardingConfiguration<T>)

Creates a new FSDP optimizer wrapping an existing optimizer.

public FSDPOptimizer(IOptimizer<T, TInput, TOutput> wrappedOptimizer, IShardingConfiguration<T> config)

Parameters

wrappedOptimizer IOptimizer<T, TInput, TOutput>

The optimizer to wrap with FSDP capabilities

config IShardingConfiguration<T>

Configuration for sharding and communication

Remarks

For Beginners: This constructor takes your existing optimizer and makes it distributed using the FSDP strategy. You provide:

  1. The optimizer you want to make distributed
  2. A configuration that describes how to distribute the work

The optimizer will automatically synchronize across all processes during optimization and shard optimizer states to minimize memory usage.

Exceptions

ArgumentNullException

Thrown if wrappedOptimizer or config is null.
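The constructor describes a wrapper (decorator) arrangement: validate the arguments, keep a reference to the wrapped optimizer, and add distributed behavior around its update. The sketch below shows that shape with three hypothetical types, none of which are AiDotNet types:

  • `IStepOptimizer`, a minimal optimizer interface
  • `Sgd`, a stand-in optimizer to wrap
  • `DistributedWrapper`, whose gradient-synchronization callback takes the place of the real communication backend

```csharp
using System;

// Hypothetical minimal optimizer contract, standing in for IOptimizer<T, TInput, TOutput>.
public interface IStepOptimizer
{
    double[] Step(double[] parameters, double[] gradients);
}

// Hypothetical plain SGD used as the wrapped optimizer.
public sealed class Sgd : IStepOptimizer
{
    private readonly double _learningRate;
    public Sgd(double learningRate) => _learningRate = learningRate;

    public double[] Step(double[] parameters, double[] gradients)
    {
        var updated = new double[parameters.Length];
        for (int i = 0; i < parameters.Length; i++)
            updated[i] = parameters[i] - _learningRate * gradients[i];
        return updated;
    }
}

// Hypothetical wrapper: validates its arguments like the documented constructor,
// then delegates the actual update to the wrapped optimizer.
public sealed class DistributedWrapper : IStepOptimizer
{
    private readonly IStepOptimizer _inner;
    private readonly Func<double[], double[]> _synchronizeGradients;

    public DistributedWrapper(IStepOptimizer wrappedOptimizer, Func<double[], double[]> synchronizeGradients)
    {
        _inner = wrappedOptimizer ?? throw new ArgumentNullException(nameof(wrappedOptimizer));
        _synchronizeGradients = synchronizeGradients ?? throw new ArgumentNullException(nameof(synchronizeGradients));
    }

    public double[] Step(double[] parameters, double[] gradients)
    {
        // Synchronize first so every rank applies the same update.
        var synced = _synchronizeGradients(gradients);
        return _inner.Step(parameters, synced);
    }
}
```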

Methods

Deserialize(byte[])

Loads a previously serialized model from binary data.

public override void Deserialize(byte[] data)

Parameters

data byte[]

The byte array containing the serialized model data.

Remarks

This method takes binary data created by the Serialize method and uses it to restore a model to its previous state.

For Beginners: This is like opening a saved file to continue your work.

When you call this method:

  • You provide the binary data (bytes) that was previously created by Serialize
  • The model rebuilds itself using this data
  • After deserializing, the model is exactly as it was when serialized
  • It's ready to make predictions without needing to be trained again

For example:

  • You download a pre-trained model file for detecting spam emails
  • You deserialize this file into your application
  • Immediately, your application can detect spam without any training
  • The model has all the knowledge that was built into it by its original creator

This is particularly useful when:

  • You want to use a model that took days to train
  • You need to deploy the same model across multiple devices
  • You're creating an application that non-technical users will use

Think of it like installing the brain of a trained expert directly into your application.
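This page does not document the class's actual binary layout. As a sketch of the general idea only, the hypothetical `ParameterFormat.Deserialize` below assumes a simple layout: a 32-bit count followed by that many doubles. It rebuilds a parameter vector from that layout and rejects data that is too short to be valid.

```csharp
using System;
using System.IO;

public static class ParameterFormat
{
    // Hypothetical format for illustration: a 32-bit element count followed by
    // that many 64-bit doubles, as written by BinaryWriter.
    public static double[] Deserialize(byte[] data)
    {
        if (data == null) throw new ArgumentNullException(nameof(data));

        using var stream = new MemoryStream(data);
        using var reader = new BinaryReader(stream);

        // Throws EndOfStreamException if fewer than 4 bytes are present.
        int count = reader.ReadInt32();
        if (count < 0 || (long)count * sizeof(double) > stream.Length - stream.Position)
            throw new InvalidDataException("Serialized data is truncated or corrupted.");

        var values = new double[count];
        for (int i = 0; i < count; i++) values[i] = reader.ReadDouble();
        return values;
    }
}
```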

GetOptions()

Gets the configuration options for the optimization algorithm.

public override OptimizationAlgorithmOptions<T, TInput, TOutput> GetOptions()

Returns

OptimizationAlgorithmOptions<T, TInput, TOutput>

The configuration options for the optimization algorithm.

Remarks

These options control how the optimization algorithm behaves, including parameters like learning rate, maximum iterations, and convergence criteria.

For Beginners: This provides the "settings" or "rules" that the optimizer follows. Just like a recipe has instructions (bake at 350°F for 30 minutes), an optimizer has settings (learn at rate 0.01, stop after 1000 tries).

Common optimization options include:

  • Learning rate: How large each adjustment is (the step size)
  • Maximum iterations: How many attempts to make before giving up
  • Tolerance: How small an improvement is considered "good enough" to stop
  • Regularization: Settings that prevent the model from becoming too complex
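As a sketch of how such settings might be grouped, the hypothetical `SketchOptimizerOptions` class below holds the four settings listed above with illustrative defaults. The real OptimizationAlgorithmOptions<T, TInput, TOutput> may name and organize them differently.

```csharp
using System;

// Hypothetical options object mirroring the settings listed above; the real
// OptimizationAlgorithmOptions<T, TInput, TOutput> may use different names.
public sealed class SketchOptimizerOptions
{
    public double LearningRate { get; init; } = 0.01;     // step size
    public int MaxIterations { get; init; } = 1000;       // give up after this many steps
    public double Tolerance { get; init; } = 1e-6;        // "good enough" improvement threshold
    public double L2Regularization { get; init; } = 0.0;  // penalty on large weights

    // True when an improvement is too small to justify continuing.
    public bool IsConverged(double previousLoss, double currentLoss) =>
        Math.Abs(previousLoss - currentLoss) < Tolerance;
}
```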

LoadModel(string)

Loads the model from a file.

public override void LoadModel(string filePath)

Parameters

filePath string

The path to the file containing the saved model.

Remarks

This method provides a convenient way to load a model directly from disk. It combines file I/O operations with deserialization.

For Beginners: This is like clicking "Open" in a document editor. Instead of manually reading from a file and then calling Deserialize(), this method does both steps for you.

Exceptions

FileNotFoundException

Thrown when the specified file does not exist.

IOException

Thrown when an I/O error occurs while reading from the file or when the file contains corrupted or invalid model data.
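The following minimal sketch shows the two steps LoadModel combines, with the object's Deserialize(byte[]) passed in as a callback. `ModelFileIO` is a hypothetical name. File.ReadAllBytes raises FileNotFoundException for a missing file. The sketch also wraps format errors in an IOException to match the documented contract; treating InvalidDataException and EndOfStreamException as the deserializer's way of reporting bad data is an assumption.

```csharp
using System;
using System.IO;

public static class ModelFileIO
{
    // Reads the file and hands the bytes to a deserializer, mirroring the
    // "read file, then Deserialize()" steps described above. `deserialize`
    // stands in for the object's own Deserialize(byte[]) method.
    public static void LoadModel(string filePath, Action<byte[]> deserialize)
    {
        if (deserialize == null) throw new ArgumentNullException(nameof(deserialize));

        // Throws FileNotFoundException if the file does not exist.
        byte[] data = File.ReadAllBytes(filePath);

        try
        {
            deserialize(data);
        }
        catch (Exception ex) when (ex is InvalidDataException || ex is EndOfStreamException)
        {
            // Surface corrupted content as an IOException, matching the documented contract.
            throw new IOException($"File '{filePath}' does not contain valid model data.", ex);
        }
    }
}
```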

Optimize(OptimizationInputData<T, TInput, TOutput>)

Performs the optimization process to find the best parameters for a model.

public override OptimizationResult<T, TInput, TOutput> Optimize(OptimizationInputData<T, TInput, TOutput> inputData)

Parameters

inputData OptimizationInputData<T, TInput, TOutput>

The data needed for optimization, including the objective function, initial parameters, and any constraints.

Returns

OptimizationResult<T, TInput, TOutput>

The result of the optimization process, including the optimized parameters and performance metrics.

Remarks

This method takes input data and attempts to find the optimal parameters that minimize or maximize the objective function.

For Beginners: This is where the actual "learning" happens. The optimizer looks at your data and tries different parameter values to find the ones that make your model perform best.

The process typically involves:

  1. Evaluating how well the current parameters perform
  2. Calculating how to change the parameters to improve performance
  3. Updating the parameters
  4. Repeating until the model performs well enough or reaches a maximum number of attempts
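The four steps above map onto a plain gradient-descent loop. This sketch minimizes the one-dimensional function f(x) = (x - 3)^2, whose gradient is 2(x - 3). `GradientDescentSketch` is a hypothetical illustration, not the wrapped optimizer's implementation. In the distributed case, the gradient from step 2 would typically also be synchronized across processes before the update.

```csharp
using System;

public static class GradientDescentSketch
{
    // Minimizes f(x) = (x - 3)^2 by plain gradient descent, following the four
    // steps listed above. Returns the final x and the number of iterations used.
    public static (double X, int Iterations) Minimize(double x0, double learningRate, int maxIterations, double tolerance)
    {
        double x = x0;
        double loss = Loss(x);                         // 1. evaluate current parameters
        for (int iter = 1; iter <= maxIterations; iter++)
        {
            double gradient = 2.0 * (x - 3.0);        // 2. compute how to change them
            x -= learningRate * gradient;              // 3. update
            double newLoss = Loss(x);
            if (Math.Abs(loss - newLoss) < tolerance)  // 4. stop when improvement is tiny
                return (x, iter);
            loss = newLoss;
        }
        return (x, maxIterations);                     // 4. or stop at the iteration limit
    }

    private static double Loss(double x) => (x - 3.0) * (x - 3.0);
}
```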

SaveModel(string)

Saves the model to a file.

public override void SaveModel(string filePath)

Parameters

filePath string

The path where the model should be saved.

Remarks

This method provides a convenient way to save the model directly to disk. It combines serialization with file I/O operations.

For Beginners: This is like clicking "Save As" in a document editor. Instead of manually calling Serialize() and then writing to a file, this method does both steps for you.

Exceptions

IOException

Thrown when an I/O error occurs while writing to the file.

UnauthorizedAccessException

Thrown when the caller does not have the required permission to write to the specified file path.
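The following minimal sketch shows the two steps SaveModel combines, with the object's Serialize() passed in as a callback. `ModelFileWriter` is a hypothetical name. The two documented exceptions are the ones File.WriteAllBytes itself can raise.

```csharp
using System;
using System.IO;

public static class ModelFileWriter
{
    // Serializes first, then writes the bytes to disk, mirroring the two steps
    // described above. `serialize` stands in for the object's own Serialize() method.
    public static void SaveModel(string filePath, Func<byte[]> serialize)
    {
        if (serialize == null) throw new ArgumentNullException(nameof(serialize));

        byte[] data = serialize();

        // Creates or overwrites the file; throws IOException on I/O failures and
        // UnauthorizedAccessException when write access is denied.
        File.WriteAllBytes(filePath, data);
    }
}
```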

Serialize()

Converts the current state of a machine learning model into a binary format.

public override byte[] Serialize()

Returns

byte[]

A byte array containing the serialized model data.

Remarks

This method captures all the essential information about a trained model and converts it into a sequence of bytes that can be stored or transmitted.

For Beginners: This is like exporting your work to a file.

When you call this method:

  • The model's current state (all its learned patterns and parameters) is captured
  • This information is converted into a compact binary format (bytes)
  • You can then save these bytes to a file, database, or send them over a network

For example:

  • After training a model to recognize cats vs. dogs in images
  • You can serialize the model to save all its learned knowledge
  • Later, you can use this saved data to recreate the model exactly as it was
  • The recreated model will make the same predictions as the original

Think of it like taking a snapshot of your model's brain at a specific moment in time.
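This page does not document the byte layout FSDPOptimizer actually uses. The sketch below serializes a parameter vector with BinaryWriter under an assumed layout: a 32-bit count followed by the values as 64-bit doubles. `ParameterSnapshot` is a hypothetical name.

```csharp
using System;
using System.IO;

public static class ParameterSnapshot
{
    // Hypothetical format for illustration: a 32-bit element count followed by
    // each value as a 64-bit double.
    public static byte[] Serialize(double[] parameters)
    {
        if (parameters == null) throw new ArgumentNullException(nameof(parameters));

        using var stream = new MemoryStream();
        using var writer = new BinaryWriter(stream);

        writer.Write(parameters.Length);
        foreach (double value in parameters) writer.Write(value);

        // Make sure everything has reached the underlying stream before copying it out.
        writer.Flush();
        return stream.ToArray();
    }
}
```

Under this layout, a vector of n values always serializes to 4 + 8n bytes.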

ShouldEarlyStop()

Determines whether the optimization process should stop early.

public override bool ShouldEarlyStop()

Returns

bool

True if the optimization process should stop early; otherwise, false.

Remarks

Early stopping is a technique to prevent overfitting by stopping the optimization process before it completes all iterations if certain conditions are met.

For Beginners: This is like knowing when to stop cooking - if the model is "done" (trained well enough), this method says "stop now" instead of continuing unnecessarily.

Common reasons for early stopping include:

  • The model's performance isn't improving anymore
  • The model's performance on validation data is getting worse (overfitting)
  • The changes in parameters are becoming very small (convergence)

Early stopping helps:

  • Save computation time
  • Prevent the model from becoming too specialized to the training data
  • Produce models that generalize better to new data
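One common early-stopping rule is patience-based: stop when the validation loss has not improved meaningfully for a set number of consecutive checks. This page does not say which criteria ShouldEarlyStop() uses. The hypothetical `EarlyStopping` class below is a sketch of the patience rule only.

```csharp
using System;

// Hypothetical patience-based early-stopping tracker: stop when validation loss
// has not improved by at least `minDelta` for `patience` consecutive checks.
public sealed class EarlyStopping
{
    private readonly int _patience;
    private readonly double _minDelta;
    private double _bestLoss = double.PositiveInfinity;
    private int _checksWithoutImprovement;

    public EarlyStopping(int patience, double minDelta)
    {
        if (patience < 1) throw new ArgumentOutOfRangeException(nameof(patience));
        _patience = patience;
        _minDelta = minDelta;
    }

    // Record the latest validation loss; returns true when training should stop.
    public bool Update(double validationLoss)
    {
        if (validationLoss < _bestLoss - _minDelta)
        {
            _bestLoss = validationLoss;      // meaningful improvement: reset the counter
            _checksWithoutImprovement = 0;
        }
        else
        {
            _checksWithoutImprovement++;     // stalled or getting worse
        }
        return _checksWithoutImprovement >= _patience;
    }
}
```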

SynchronizeOptimizerState()

Synchronizes optimizer state (like momentum buffers) across all processes.

public override void SynchronizeOptimizerState()

Remarks

For Beginners: Some optimizers (like Adam) keep track of past gradients to make smarter updates. This method makes sure all processes have the same optimizer state, so they stay coordinated. It's like making sure all team members are reading from the same playbook.
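This page does not say how the state is exchanged. One standard collective that gives every process the same copy of a sharded buffer is an all-gather: each process contributes its shard, and all processes receive the concatenation. The single-process simulation below assumes contiguous shards in rank order. `StateSync.AllGather` is hypothetical; a real backend would move the shards between processes.

```csharp
using System;

public static class StateSync
{
    // Simulates an all-gather: shards[rank] is the slice of a state buffer
    // (e.g. momentum) owned by that rank, in rank order. Returns the full
    // buffer that every rank would hold after the exchange.
    public static double[] AllGather(double[][] shards)
    {
        if (shards == null) throw new ArgumentNullException(nameof(shards));

        int total = 0;
        foreach (var shard in shards) total += shard.Length;

        var full = new double[total];
        int offset = 0;
        foreach (var shard in shards)
        {
            Array.Copy(shard, 0, full, offset, shard.Length);
            offset += shard.Length;
        }
        return full;
    }
}
```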