
Class ShardedOptimizerBase<T, TInput, TOutput>

Namespace
AiDotNet.DistributedTraining
Assembly
AiDotNet.dll

Provides base implementation for distributed optimizers with parameter sharding.

public abstract class ShardedOptimizerBase<T, TInput, TOutput> : IShardedOptimizer<T, TInput, TOutput>, IOptimizer<T, TInput, TOutput>, IModelSerializer

Type Parameters

T

The numeric type for operations

TInput

The input type for the model

TOutput

The output type for the model

Inheritance
ShardedOptimizerBase<T, TInput, TOutput>
Implements
IShardedOptimizer<T, TInput, TOutput>
IOptimizer<T, TInput, TOutput>
IModelSerializer

Remarks

This abstract class implements common functionality for all sharded optimizers, including optimizer wrapping, parameter synchronization, consensus-based early stopping, and serialization. Derived classes can customize the optimization strategy, implement different sharding approaches (FSDP, ZeRO, etc.), or add optimizer-specific features.

For Beginners: This is the foundation that all distributed optimizers build upon.

Think of this as a template for coordinating optimization across multiple computers or GPUs. It handles common tasks like:

  • Wrapping regular optimizers to work in distributed mode
  • Syncing parameters across all processes after updates
  • Making sure all processes agree on when to stop training
  • Saving and loading distributed optimizer state

Specific types of distributed optimizers (like data-parallel or ZeRO) inherit from this and add their own strategies. This prevents code duplication and ensures all distributed optimizers work consistently.

Constructors

ShardedOptimizerBase(IOptimizer<T, TInput, TOutput>, IShardingConfiguration<T>)

Initializes a new instance of the ShardedOptimizerBase class.

protected ShardedOptimizerBase(IOptimizer<T, TInput, TOutput> wrappedOptimizer, IShardingConfiguration<T> config)

Parameters

wrappedOptimizer IOptimizer<T, TInput, TOutput>

The optimizer to wrap with distributed capabilities

config IShardingConfiguration<T>

Configuration for sharding and communication

Remarks

This constructor wraps an existing optimizer with distributed training capabilities. It initializes the communication backend if needed and prepares for distributed optimization.

For Beginners: This constructor takes your regular optimizer and makes it distributed.

You provide:

  1. The optimizer you want to distribute (like Adam, SGD, etc.)
  2. Configuration that tells us how to distribute it

The constructor automatically:

  • Sets up communication if not already done
  • Prepares the optimizer for coordinated training
  • Ensures all processes can work together

Exceptions

ArgumentNullException

Thrown if wrappedOptimizer or config is null.
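A constructor that rejects null arguments with ArgumentNullException usually does so with guard clauses. The sketch below shows that pattern using `ArgumentNullException.ThrowIfNull` (available in .NET 6 and later). `CreateWrapper` is a hypothetical stand-in, and `object` replaces the real `IOptimizer<T, TInput, TOutput>` and `IShardingConfiguration<T>` types, so this is not the library's actual constructor.

```csharp
using System;

// Hypothetical stand-in for the protected constructor: object replaces the
// real IOptimizer<T, TInput, TOutput> and IShardingConfiguration<T> types.
(object Optimizer, object Config) CreateWrapper(object wrappedOptimizer, object config)
{
    // Each guard throws ArgumentNullException; ParamName is filled in from the
    // argument expression (e.g. "config").
    ArgumentNullException.ThrowIfNull(wrappedOptimizer);
    ArgumentNullException.ThrowIfNull(config);
    return (wrappedOptimizer, config);
}

try
{
    CreateWrapper(new object(), null!);
}
catch (ArgumentNullException ex)
{
    Console.WriteLine($"Rejected null argument: {ex.ParamName}"); // prints "Rejected null argument: config"
}
```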

Fields

Config

The sharding configuration containing communication backend and settings.

protected readonly IShardingConfiguration<T> Config

Field Value

IShardingConfiguration<T>

NumOps

Provides numeric operations for type T.

protected readonly INumericOperations<T> NumOps

Field Value

INumericOperations<T>

Properties

LastComputedGradients

Gets the gradients computed during the last optimization step.

public virtual Vector<T> LastComputedGradients { get; }

Property Value

Vector<T>

Remarks

Sharded optimizers delegate gradient access to the wrapped optimizer. If the wrapped optimizer is gradient-based, this will return the actual computed gradients. Otherwise, it returns an empty vector.

Rank

Gets the rank of this process in the distributed group.

public int Rank { get; }

Property Value

int

Remarks

For Beginners: Each process has a unique ID (rank). This tells you which process you are. Rank 0 is typically the "coordinator" process.

ShardingConfiguration

Gets the sharding configuration for this optimizer.

public IShardingConfiguration<T> ShardingConfiguration { get; }

Property Value

IShardingConfiguration<T>

WorldSize

Gets the total number of processes in the distributed group.

public int WorldSize { get; }

Property Value

int

Remarks

For Beginners: This is how many processes are working together to optimize the model. For example, if you have 4 GPUs, WorldSize would be 4.
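Rank and WorldSize are what a sharded optimizer needs to decide which slice of the parameters each process owns. The sketch below assumes a simple contiguous, as-even-as-possible split of a flat parameter vector; the partitioning actually used by derived classes (FSDP, ZeRO, etc.) is not specified on this page, and `ShardRange` is a hypothetical helper.

```csharp
using System;

// Returns the half-open range [Start, End) of a flat parameter vector owned by
// `rank` when `totalParameters` are split as evenly as possible over `worldSize`
// processes. The first (totalParameters % worldSize) ranks get one extra element.
(int Start, int End) ShardRange(int totalParameters, int rank, int worldSize)
{
    if (worldSize <= 0) throw new ArgumentOutOfRangeException(nameof(worldSize));
    if (rank < 0 || rank >= worldSize) throw new ArgumentOutOfRangeException(nameof(rank));

    int baseSize = totalParameters / worldSize;
    int remainder = totalParameters % worldSize;
    int start = rank * baseSize + Math.Min(rank, remainder);
    int size = baseSize + (rank < remainder ? 1 : 0);
    return (start, start + size);
}

// 10 parameters on 4 GPUs (WorldSize = 4): ranks 0 and 1 own 3 values, ranks 2 and 3 own 2.
for (int rank = 0; rank < 4; rank++)
{
    var (start, end) = ShardRange(10, rank, 4);
    Console.WriteLine($"Rank {rank}: parameters [{start}, {end})");
}
```

Every parameter index belongs to exactly one rank, and the shard sizes differ by at most one element.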

WrappedOptimizer

Gets the underlying wrapped optimizer.

public IOptimizer<T, TInput, TOutput> WrappedOptimizer { get; }

Property Value

IOptimizer<T, TInput, TOutput>

Remarks

For Beginners: This is the original optimizer (like Adam, SGD, etc.) that we're adding distributed training capabilities to. Think of it as the "core brain" that we're helping to work across multiple processes.

WrappedOptimizerInternal

Gives derived classes protected access to the wrapped optimizer.

protected IOptimizer<T, TInput, TOutput> WrappedOptimizerInternal { get; }

Property Value

IOptimizer<T, TInput, TOutput>

Methods

ApplyGradients(Vector<T>, IFullModel<T, TInput, TOutput>)

Applies pre-computed gradients to a model's parameters.

public virtual IFullModel<T, TInput, TOutput> ApplyGradients(Vector<T> gradients, IFullModel<T, TInput, TOutput> model)

Parameters

gradients Vector<T>

The gradients to apply

model IFullModel<T, TInput, TOutput>

The model to update

Returns

IFullModel<T, TInput, TOutput>

The updated model

Remarks

Sharded optimizers delegate gradient application to the wrapped optimizer. If the wrapped optimizer is gradient-based, it applies the gradients; otherwise, a NotSupportedException is thrown.

Exceptions

NotSupportedException

Thrown if the wrapped optimizer is not gradient-based.
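The delegation described above can be sketched with plain arrays. In this hypothetical version, a null `gradientStep` delegate stands for a wrapped optimizer that is not gradient-based, `double` replaces `T`, and plain SGD (p_new = p - learningRate * g) stands in for whatever update rule the real wrapped optimizer uses.

```csharp
using System;

// Hypothetical sketch of the delegation described above. A null `gradientStep`
// stands for a wrapped optimizer that is not gradient-based.
double[] ApplyGradients(double[] parameters, double[] gradients, Func<double[], double[], double[]> gradientStep)
{
    if (gradientStep is null)
        throw new NotSupportedException("The wrapped optimizer is not gradient-based.");
    if (gradients.Length != parameters.Length)
        throw new ArgumentException("Gradient length must match parameter count.", nameof(gradients));
    return gradientStep(parameters, gradients);
}

// Plain SGD stands in for the wrapped optimizer: p_new = p - learningRate * g.
const double learningRate = 0.1;
Func<double[], double[], double[]> sgd = (p, g) =>
{
    var updated = new double[p.Length];
    for (int i = 0; i < p.Length; i++)
        updated[i] = p[i] - learningRate * g[i];
    return updated;
};

var result = ApplyGradients(new[] { 1.0, 2.0 }, new[] { 0.5, -1.0 }, sgd);
Console.WriteLine(string.Join(", ", result));
```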

Deserialize(byte[])

Loads a previously serialized model from binary data.

public abstract void Deserialize(byte[] data)

Parameters

data byte[]

The byte array containing the serialized model data.

Remarks

This method takes binary data created by the Serialize method and uses it to restore a model to its previous state.

For Beginners: This is like opening a saved file to continue your work.

When you call this method:

  • You provide the binary data (bytes) that was previously created by Serialize
  • The model rebuilds itself using this data
  • After deserializing, the model is exactly as it was when serialized
  • It's ready to make predictions without needing to be trained again

For example:

  • You download a pre-trained model file for detecting spam emails
  • You deserialize this file into your application
  • Immediately, your application can detect spam without any training
  • The model has all the knowledge that was built into it by its original creator

This is particularly useful when:

  • You want to use a model that took days to train
  • You need to deploy the same model across multiple devices
  • You're creating an application that non-technical users will use

Think of it like installing the brain of a trained expert directly into your application.

GetOptions()

Gets the configuration options for the optimization algorithm.

public virtual OptimizationAlgorithmOptions<T, TInput, TOutput> GetOptions()

Returns

OptimizationAlgorithmOptions<T, TInput, TOutput>

The configuration options for the optimization algorithm.

Remarks

These options control how the optimization algorithm behaves, including parameters like learning rate, maximum iterations, and convergence criteria.

For Beginners: This provides the "settings" or "rules" that the optimizer follows. Just like a recipe has instructions (bake at 350°F for 30 minutes), an optimizer has settings (learn at rate 0.01, stop after 1000 tries).

Common optimization options include:

  • Learning rate: How large each adjustment is (the step size)
  • Maximum iterations: How many attempts to make before giving up
  • Tolerance: How small an improvement is considered "good enough" to stop
  • Regularization: Settings that prevent the model from becoming too complex

LoadModel(string)

Loads the model from a file.

public virtual void LoadModel(string filePath)

Parameters

filePath string

The path to the file containing the saved model.

Remarks

This method provides a convenient way to load a model directly from disk. It combines file I/O operations with deserialization.

For Beginners: This is like clicking "Open" in a document editor. Instead of manually reading from a file and then calling Deserialize(), this method does both steps for you.

Exceptions

FileNotFoundException

Thrown when the specified file does not exist.

IOException

Thrown when an I/O error occurs while reading from the file or when the file contains corrupted or invalid model data.

Optimize(OptimizationInputData<T, TInput, TOutput>)

Performs the optimization process to find the best parameters for a model.

public abstract OptimizationResult<T, TInput, TOutput> Optimize(OptimizationInputData<T, TInput, TOutput> inputData)

Parameters

inputData OptimizationInputData<T, TInput, TOutput>

The data needed for optimization, including the objective function, initial parameters, and any constraints.

Returns

OptimizationResult<T, TInput, TOutput>

The result of the optimization process, including the optimized parameters and performance metrics.

Remarks

This method takes input data and attempts to find the optimal parameters that minimize or maximize the objective function.

For Beginners: This is where the actual "learning" happens. The optimizer looks at your data and tries different parameter values to find the ones that make your model perform best.

The process typically involves:

  1. Evaluating how well the current parameters perform
  2. Calculating how to change the parameters to improve performance
  3. Updating the parameters
  4. Repeating until the model performs well enough or reaches a maximum number of attempts
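The four steps above can be sketched as a single-process gradient-descent loop on one parameter. This is an illustration of the general process, not the algorithm any particular wrapped optimizer uses: the objective f(x) = (x - 3)^2, the learning rate, and the tolerance are all assumptions chosen for the example.

```csharp
using System;

// Minimal gradient-descent loop following the four steps above, applied to
// f(x) = (x - 3)^2, whose minimum is at x = 3.
double Loss(double x) => (x - 3) * (x - 3);
double Gradient(double x) => 2 * (x - 3);

(double X, int Iterations) Minimize(double x, double learningRate, double tolerance, int maxIterations)
{
    for (int iteration = 1; iteration <= maxIterations; iteration++)
    {
        double before = Loss(x);                     // 1. evaluate the current parameters
        double step = learningRate * Gradient(x);    // 2. work out how to change them
        x -= step;                                   // 3. update the parameters
        if (Math.Abs(before - Loss(x)) < tolerance)  // 4. stop once improvement is negligible
            return (x, iteration);
    }
    return (x, maxIterations);                       // 4. ...or when the iteration budget runs out
}

var (best, iterations) = Minimize(x: 0.0, learningRate: 0.1, tolerance: 1e-12, maxIterations: 1000);
Console.WriteLine($"x = {best:F4} after {iterations} iterations");
```

With learningRate = 0.1, each step shrinks the distance to the minimum by a factor of 0.8, so the loop stops well before the 1000-iteration cap.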

Reset()

Resets the optimizer state to prepare for a fresh optimization run.

public virtual void Reset()

Remarks

This method clears accumulated state including:

  • Model cache (prevents retrieving solutions from previous runs)
  • Fitness history (accumulated scores from previous optimizations)
  • Iteration history (logs from previous runs)
  • Adaptive parameters (learning rate, momentum reset to initial values)

For Beginners: Think of this like "clearing the whiteboard" before starting a new problem. When you run optimization multiple times (like during cross-validation), you want each run to start fresh without being influenced by previous runs. This method ensures that.

When to call Reset():

  • Before each cross-validation fold (ensures independent fold evaluations)
  • Before training the final model after cross-validation
  • Any time you want to reuse an optimizer for a completely new optimization task

Why this matters:

  • Prevents state contamination between independent training runs
  • Ensures reproducible results regardless of how many times you've used the optimizer
  • Keeps accumulated history from growing without bound across runs
  • Maintains correct adaptive learning rate dynamics
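The state categories listed above can be illustrated with a small sketch. The fields here (`modelCache`, `fitnessHistory`, `iterationLog`, `learningRate`) are hypothetical stand-ins for whatever the real optimizer stores; the point is that Reset() empties accumulated collections and restores adaptive parameters to their initial values.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical optimizer state covering the four categories listed above.
const double InitialLearningRate = 0.01;
var modelCache = new Dictionary<string, double>();  // cached solutions
var fitnessHistory = new List<double>();             // scores from earlier runs
var iterationLog = new List<string>();               // per-iteration logs
double learningRate = InitialLearningRate;           // adaptive parameter

void Reset()
{
    modelCache.Clear();
    fitnessHistory.Clear();
    iterationLog.Clear();
    learningRate = InitialLearningRate; // undo any adaptive decay from the last run
}

// Simulate one run that leaves state behind, then reset before the next fold.
modelCache["fold-1"] = 0.42;
fitnessHistory.Add(0.42);
iterationLog.Add("iteration 1: fitness 0.42");
learningRate *= 0.5;

Reset();
Console.WriteLine($"{fitnessHistory.Count} scores, learning rate {learningRate}");
```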

SaveModel(string)

Saves the model to a file.

public virtual void SaveModel(string filePath)

Parameters

filePath string

The path where the model should be saved.

Remarks

This method provides a convenient way to save the model directly to disk. It combines serialization with file I/O operations.

For Beginners: This is like clicking "Save As" in a document editor. Instead of manually calling Serialize() and then writing to a file, this method does both steps for you.

Exceptions

IOException

Thrown when an I/O error occurs while writing to the file.

UnauthorizedAccessException

Thrown when the caller does not have the required permission to write to the specified file path.
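SaveModel and LoadModel combine Serialize()/Deserialize(byte[]) with file I/O. The sketch below shows that combination with hypothetical local helpers that take the serialization steps as delegates; it is not the library's implementation. `File.WriteAllBytes` and `File.ReadAllBytes` are the standard .NET calls whose exceptions (FileNotFoundException, IOException, UnauthorizedAccessException) match the ones documented for these methods.

```csharp
using System;
using System.IO;

// Hypothetical helpers mirroring SaveModel/LoadModel: they wrap whatever
// Serialize()/Deserialize(byte[]) implementations are supplied as delegates.
void SaveModel(string filePath, Func<byte[]> serialize) =>
    File.WriteAllBytes(filePath, serialize());

void LoadModel(string filePath, Action<byte[]> deserialize)
{
    // File.ReadAllBytes throws FileNotFoundException if the file is missing.
    deserialize(File.ReadAllBytes(filePath));
}

string path = Path.Combine(Path.GetTempPath(), "sharded-optimizer-demo.bin");
byte[] state = { 1, 2, 3, 4 };
byte[] restored = Array.Empty<byte>();

SaveModel(path, () => state);
LoadModel(path, bytes => restored = bytes);
Console.WriteLine(restored.Length); // 4
File.Delete(path);
```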

Serialize()

Converts the current state of a machine learning model into a binary format.

public abstract byte[] Serialize()

Returns

byte[]

A byte array containing the serialized model data.

Remarks

This method captures all the essential information about a trained model and converts it into a sequence of bytes that can be stored or transmitted.

For Beginners: This is like exporting your work to a file.

When you call this method:

  • The model's current state (all its learned patterns and parameters) is captured
  • This information is converted into a compact binary format (bytes)
  • You can then save these bytes to a file, database, or send them over a network

For example:

  • After training a model to recognize cats vs. dogs in images
  • You can serialize the model to save all its learned knowledge
  • Later, you can use this saved data to recreate the model exactly as it was
  • The recreated model will make the same predictions as the original

Think of it like taking a snapshot of your model's brain at a specific moment in time.
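Serialize and Deserialize are abstract here, so each derived class chooses its own byte layout. A minimal sketch of a round trip, assuming a hypothetical layout of rank, world size, parameter count, and then the parameter values (with `double` in place of `T`), could use `BinaryWriter` and `BinaryReader`:

```csharp
using System;
using System.IO;

// Hypothetical layout: rank, world size, parameter count, then each value.
// The real AiDotNet byte format is not documented on this page.
byte[] Serialize(int rank, int worldSize, double[] parameters)
{
    using var stream = new MemoryStream();
    using var writer = new BinaryWriter(stream);
    writer.Write(rank);
    writer.Write(worldSize);
    writer.Write(parameters.Length);
    foreach (double value in parameters)
        writer.Write(value);
    writer.Flush();
    return stream.ToArray();
}

(int Rank, int WorldSize, double[] Parameters) Deserialize(byte[] data)
{
    using var reader = new BinaryReader(new MemoryStream(data));
    int rank = reader.ReadInt32();
    int worldSize = reader.ReadInt32();
    var parameters = new double[reader.ReadInt32()];
    for (int i = 0; i < parameters.Length; i++)
        parameters[i] = reader.ReadDouble();
    return (rank, worldSize, parameters);
}

byte[] bytes = Serialize(rank: 1, worldSize: 4, parameters: new[] { 0.5, -2.25 });
var restored = Deserialize(bytes);
Console.WriteLine($"{bytes.Length} bytes -> rank {restored.Rank} of {restored.WorldSize}, {restored.Parameters.Length} parameters");
// prints "28 bytes -> rank 1 of 4, 2 parameters"
```

Deserialize must read fields in exactly the order Serialize wrote them; storing the count before the values lets the reader size the array without a separate schema.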

ShouldEarlyStop()

Determines whether the optimization process should stop early.

public virtual bool ShouldEarlyStop()

Returns

bool

True if the optimization process should stop early; otherwise, false.

Remarks

Early stopping is a technique to prevent overfitting by stopping the optimization process before it completes all iterations if certain conditions are met.

For Beginners: This is like knowing when to stop cooking - if the model is "done" (trained well enough), this method says "stop now" instead of continuing unnecessarily.

Common reasons for early stopping include:

  • The model's performance isn't improving anymore
  • The model's performance on validation data is getting worse (overfitting)
  • The changes in parameters are becoming very small (convergence)

Early stopping helps:

  • Save computation time
  • Prevent the model from becoming too specialized to the training data
  • Produce models that generalize better to new data
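The class remarks mention consensus-based early stopping, meaning all processes must agree before training stops. The sketch below makes two assumptions that this page does not confirm: each rank uses a patience rule on its loss history, and the group stops only if every rank votes to stop (with a real backend, an AllReduce using MIN over 0/1 flags). Ranks are simulated in memory, and the helper names are hypothetical.

```csharp
using System;
using System.Linq;

// Local rule (assumed): stop when the last `patience` losses show no improvement
// over the best loss seen before them.
bool LocalShouldStop(double[] lossHistory, int patience)
{
    if (lossHistory.Length <= patience) return false;
    double bestBefore = lossHistory.Take(lossHistory.Length - patience).Min();
    return lossHistory.Skip(lossHistory.Length - patience).All(loss => loss >= bestBefore);
}

// Consensus rule (assumed): stop only if every rank votes to stop. With a real
// backend this is an AllReduce using MIN over 0/1 flags; here it is simulated.
bool ConsensusShouldStop(bool[] votesPerRank) => votesPerRank.All(vote => vote);

var rank0 = new[] { 1.0, 0.8, 0.7, 0.71, 0.72 }; // stalled for 2 steps
var rank1 = new[] { 1.0, 0.9, 0.8, 0.75, 0.70 }; // still improving
bool[] votes = { LocalShouldStop(rank0, patience: 2), LocalShouldStop(rank1, patience: 2) };
Console.WriteLine(ConsensusShouldStop(votes)); // False: rank 1 is still improving
```

Reducing the votes to one shared decision means every process leaves the training loop at the same iteration, so no rank is left waiting on a collective call that the others have stopped making.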

SynchronizeOptimizerState()

Synchronizes optimizer state (like momentum buffers) across all processes.

public abstract void SynchronizeOptimizerState()

Remarks

For Beginners: Some optimizers (like Adam) keep track of past gradients to make smarter updates. This method makes sure all processes have the same optimizer state, so they stay coordinated. It's like making sure all team members are reading from the same playbook.

SynchronizeParameters(IFullModel<T, TInput, TOutput>?)

Synchronizes model parameters across all processes using AllReduce with averaging.

protected virtual void SynchronizeParameters(IFullModel<T, TInput, TOutput>? model)

Parameters

model IFullModel<T, TInput, TOutput>

The model whose parameters to synchronize

Remarks

This method averages parameters across all processes, ensuring consistency. It's called after optimization steps to keep all processes synchronized.

For Beginners: After each process updates its model, we need to make sure everyone has the same parameters.

This method averages the parameters from all processes. For example, if GPU 0 calculated parameter value 1.0 and GPU 1 calculated 1.2, after sync both will have 1.1 (the average).
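The averaging described above can be simulated in a single process: sum each parameter across ranks, then divide by WorldSize. In this sketch, `AllReduceAverage` is a hypothetical helper, `double` replaces `T`, and every rank's vector is already in local memory, whereas a real communication backend would exchange them over the network and hand each rank the same result.

```csharp
using System;

// Simulated AllReduce with averaging: element-wise sum across ranks, then divide
// by WorldSize. A real backend would exchange these vectors over the network;
// here every rank's parameters are already in local memory.
double[] AllReduceAverage(double[][] parametersPerRank)
{
    int worldSize = parametersPerRank.Length;
    int length = parametersPerRank[0].Length;
    var averaged = new double[length];
    foreach (double[] rankParameters in parametersPerRank)
    {
        if (rankParameters.Length != length)
            throw new ArgumentException("All ranks must hold the same number of parameters.");
        for (int i = 0; i < length; i++)
            averaged[i] += rankParameters[i];
    }
    for (int i = 0; i < length; i++)
        averaged[i] /= worldSize;
    return averaged; // after AllReduce every rank would receive this same vector
}

// The example from the text: GPU 0 holds 1.0, GPU 1 holds 1.2.
double[] synced = AllReduceAverage(new[] { new[] { 1.0 }, new[] { 1.2 } });
Console.WriteLine(synced[0]);
```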