Class ShardedOptimizerBase<T, TInput, TOutput>
- Namespace: AiDotNet.DistributedTraining
- Assembly: AiDotNet.dll
Provides base implementation for distributed optimizers with parameter sharding.
public abstract class ShardedOptimizerBase<T, TInput, TOutput> : IShardedOptimizer<T, TInput, TOutput>, IOptimizer<T, TInput, TOutput>, IModelSerializer
Type Parameters
T: The numeric type for operations.
TInput: The input type for the model.
TOutput: The output type for the model.
- Inheritance
- object → ShardedOptimizerBase<T, TInput, TOutput>
- Implements
- IShardedOptimizer<T, TInput, TOutput>, IOptimizer<T, TInput, TOutput>, IModelSerializer
Remarks
This abstract class implements common functionality for all sharded optimizers, including optimizer wrapping, parameter synchronization, consensus-based early stopping, and serialization. Derived classes can customize the optimization strategy, implement different sharding approaches (FSDP, ZeRO, etc.), or add optimizer-specific features.
For Beginners: This is the foundation that all distributed optimizers build upon.
Think of this as a template for coordinating optimization across multiple computers or GPUs. It handles common tasks like:
- Wrapping regular optimizers to work in distributed mode
- Syncing parameters across all processes after updates
- Making sure all processes agree on when to stop training
- Saving and loading distributed optimizer state
Specific types of distributed optimizers (like data-parallel or ZeRO) inherit from this and add their own strategies. This prevents code duplication and ensures all distributed optimizers work consistently.
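For example, a concrete subclass might look like the sketch below. The overridden member signatures are taken from this page; the class name MyDataParallelOptimizer, the result property BestSolution, and the method bodies are hypothetical illustrations only, not the library's actual implementation.
using System;
using AiDotNet.DistributedTraining;

// Hypothetical derived optimizer: averages parameters after each local step.
public class MyDataParallelOptimizer<T, TInput, TOutput>
    : ShardedOptimizerBase<T, TInput, TOutput>
{
    public MyDataParallelOptimizer(
        IOptimizer<T, TInput, TOutput> wrappedOptimizer,
        IShardingConfiguration<T> config)
        : base(wrappedOptimizer, config)
    {
    }

    public override OptimizationResult<T, TInput, TOutput> Optimize(
        OptimizationInputData<T, TInput, TOutput> inputData)
    {
        // Let the wrapped optimizer run on this process's portion of the work
        // (assumes IOptimizer exposes Optimize as documented for this class).
        var result = WrappedOptimizerInternal.Optimize(inputData);

        // Average parameters across processes so every rank holds the same model.
        // 'BestSolution' is a hypothetical property name for the optimized model.
        SynchronizeParameters(result.BestSolution);
        return result;
    }

    public override void SynchronizeOptimizerState()
    {
        // A real implementation would all-reduce momentum buffers and similar
        // state through the communication backend configured in Config.
    }

    public override byte[] Serialize()
    {
        // A real implementation would persist the wrapped optimizer's state
        // together with any sharding metadata.
        throw new NotImplementedException();
    }

    public override void Deserialize(byte[] data)
    {
        throw new NotImplementedException();
    }
}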
Constructors
ShardedOptimizerBase(IOptimizer<T, TInput, TOutput>, IShardingConfiguration<T>)
Initializes a new instance of the ShardedOptimizerBase class.
protected ShardedOptimizerBase(IOptimizer<T, TInput, TOutput> wrappedOptimizer, IShardingConfiguration<T> config)
Parameters
wrappedOptimizer (IOptimizer<T, TInput, TOutput>): The optimizer to wrap with distributed capabilities.
config (IShardingConfiguration<T>): Configuration for sharding and communication.
Remarks
This constructor wraps an existing optimizer with distributed training capabilities. It initializes the communication backend if needed and prepares for distributed optimization.
For Beginners: This constructor takes your regular optimizer and makes it distributed.
You provide:
- The optimizer you want to distribute (like Adam, SGD, etc.)
- Configuration that tells us how to distribute it
The constructor automatically:
- Sets up communication if not already done
- Prepares the optimizer for coordinated training
- Ensures all processes can work together
Exceptions
- ArgumentNullException
Thrown if wrappedOptimizer or config is null.
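For example, wrapping an existing optimizer could look like the following sketch. MyDataParallelOptimizer is the hypothetical subclass sketched above, and AdamOptimizer and MyShardingConfiguration stand in for whatever concrete IOptimizer and IShardingConfiguration implementations you actually use.
// Hypothetical concrete types; substitute the real optimizer and configuration types.
IOptimizer<double, double[], double[]> baseOptimizer =
    new AdamOptimizer<double, double[], double[]>();
IShardingConfiguration<double> config = new MyShardingConfiguration<double>();

// Passing null for either argument throws ArgumentNullException.
var distributed = new MyDataParallelOptimizer<double, double[], double[]>(baseOptimizer, config);

Console.WriteLine($"Rank {distributed.Rank} of {distributed.WorldSize} is ready.");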
Fields
Config
The sharding configuration containing communication backend and settings.
protected readonly IShardingConfiguration<T> Config
Field Value
- IShardingConfiguration<T>
NumOps
Provides numeric operations for type T.
protected readonly INumericOperations<T> NumOps
Field Value
- INumericOperations<T>
Properties
LastComputedGradients
Gets the gradients computed during the last optimization step.
public virtual Vector<T> LastComputedGradients { get; }
Property Value
- Vector<T>
Remarks
Sharded optimizers delegate gradient access to the wrapped optimizer. If the wrapped optimizer is gradient-based, this will return the actual computed gradients. Otherwise, it returns an empty vector.
Rank
Gets the rank of this process in the distributed group.
public int Rank { get; }
Property Value
- int
Remarks
For Beginners: Each process has a unique ID (rank). This tells you which process you are. Rank 0 is typically the "coordinator" process.
ShardingConfiguration
Gets the sharding configuration for this optimizer.
public IShardingConfiguration<T> ShardingConfiguration { get; }
Property Value
- IShardingConfiguration<T>
WorldSize
Gets the total number of processes in the distributed group.
public int WorldSize { get; }
Property Value
- int
Remarks
For Beginners: This is how many processes are working together to optimize the model. For example, if you have 4 GPUs, WorldSize would be 4.
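A common pattern is to gate coordinator-only work on rank 0, as in the sketch below (assuming 'distributed' is a sharded optimizer instance like the one constructed in the earlier example):
if (distributed.Rank == 0)
{
    // Only the coordinator logs and writes checkpoints, so the work
    // is not duplicated across all WorldSize processes.
    Console.WriteLine($"Training with {distributed.WorldSize} processes.");
    distributed.SaveModel("checkpoint.bin");
}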
WrappedOptimizer
Gets the underlying wrapped optimizer.
public IOptimizer<T, TInput, TOutput> WrappedOptimizer { get; }
Property Value
- IOptimizer<T, TInput, TOutput>
Remarks
For Beginners: This is the original optimizer (like Adam, SGD, etc.) that we're adding distributed training capabilities to. Think of it as the "core brain" that we're helping to work across multiple processes.
WrappedOptimizerInternal
Provides protected access to the wrapped optimizer for derived classes.
protected IOptimizer<T, TInput, TOutput> WrappedOptimizerInternal { get; }
Property Value
- IOptimizer<T, TInput, TOutput>
Methods
ApplyGradients(Vector<T>, IFullModel<T, TInput, TOutput>)
Applies pre-computed gradients to a model's parameters.
public virtual IFullModel<T, TInput, TOutput> ApplyGradients(Vector<T> gradients, IFullModel<T, TInput, TOutput> model)
Parameters
gradients (Vector<T>): The gradients to apply.
model (IFullModel<T, TInput, TOutput>): The model to update.
Returns
- IFullModel<T, TInput, TOutput>
The updated model.
Remarks
Sharded optimizers delegate gradient application to the wrapped optimizer. If the wrapped optimizer is gradient-based, this will apply the gradients. Otherwise, it throws a NotSupportedException.
Exceptions
- NotSupportedException
Thrown if the wrapped optimizer is not gradient-based.
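As a hedged sketch, a caller that already has gradients (here taken from LastComputedGradients) and an existing IFullModel instance named 'model' might guard against non-gradient-based optimizers like this:
try
{
    // Empty if the wrapped optimizer is not gradient-based.
    Vector<double> gradients = distributed.LastComputedGradients;
    model = distributed.ApplyGradients(gradients, model);
}
catch (NotSupportedException)
{
    // The wrapped optimizer is not gradient-based; fall back to a full Optimize(...) run.
}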
Deserialize(byte[])
Loads a previously serialized model from binary data.
public abstract void Deserialize(byte[] data)
Parameters
data (byte[]): The byte array containing the serialized model data.
Remarks
This method takes binary data created by the Serialize method and uses it to restore a model to its previous state.
For Beginners: This is like opening a saved file to continue your work.
When you call this method:
- You provide the binary data (bytes) that was previously created by Serialize
- The model rebuilds itself using this data
- After deserializing, the model is exactly as it was when serialized
- It's ready to make predictions without needing to be trained again
For example:
- You download a pre-trained model file for detecting spam emails
- You deserialize this file into your application
- Immediately, your application can detect spam without any training
- The model has all the knowledge that was built into it by its original creator
This is particularly useful when:
- You want to use a model that took days to train
- You need to deploy the same model across multiple devices
- You're creating an application that non-technical users will use
Think of it like installing the brain of a trained expert directly into your application.
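A typical round trip might look like this sketch, where 'distributedRestored' is assumed to be a freshly constructed optimizer of the same concrete type:
// Capture the optimizer's full state as bytes.
byte[] snapshot = distributed.Serialize();

// ... store 'snapshot' in a file, a database, or send it over the network ...

// Restore the state into another instance; it now behaves exactly like the original.
distributedRestored.Deserialize(snapshot);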
GetOptions()
Gets the configuration options for the optimization algorithm.
public virtual OptimizationAlgorithmOptions<T, TInput, TOutput> GetOptions()
Returns
- OptimizationAlgorithmOptions<T, TInput, TOutput>
The configuration options for the optimization algorithm.
Remarks
These options control how the optimization algorithm behaves, including parameters like learning rate, maximum iterations, and convergence criteria.
For Beginners: This provides the "settings" or "rules" that the optimizer follows. Just like a recipe has instructions (bake at 350°F for 30 minutes), an optimizer has settings (learn at rate 0.01, stop after 1000 tries).
Common optimization options include:
- Learning rate: How big of adjustments to make (step size)
- Maximum iterations: How many attempts to make before giving up
- Tolerance: How small an improvement is considered "good enough" to stop
- Regularization: Settings that prevent the model from becoming too complex
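For example, the current settings can be retrieved and inspected like this (the specific properties available depend on the options type and are not shown here):
// Read back the settings the wrapped optimizer is using; inspect properties such
// as the learning rate or iteration limit as exposed by the options type.
OptimizationAlgorithmOptions<double, double[], double[]> options = distributed.GetOptions();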
LoadModel(string)
Loads the model from a file.
public virtual void LoadModel(string filePath)
Parameters
filePath (string): The path to the file containing the saved model.
Remarks
This method provides a convenient way to load a model directly from disk. It combines file I/O operations with deserialization.
For Beginners: This is like clicking "Open" in a document editor. Instead of manually reading from a file and then calling Deserialize(), this method does both steps for you.
Exceptions
- FileNotFoundException
Thrown when the specified file does not exist.
- IOException
Thrown when an I/O error occurs while reading from the file or when the file contains corrupted or invalid model data.
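For example, resuming from a checkpoint while tolerating a missing file might look like this sketch (the path is illustrative):
try
{
    distributed.LoadModel("checkpoints/latest.bin");
}
catch (FileNotFoundException)
{
    // No checkpoint yet: start training from scratch instead of resuming.
}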
Optimize(OptimizationInputData<T, TInput, TOutput>)
Performs the optimization process to find the best parameters for a model.
public abstract OptimizationResult<T, TInput, TOutput> Optimize(OptimizationInputData<T, TInput, TOutput> inputData)
Parameters
inputData (OptimizationInputData<T, TInput, TOutput>): The data needed for optimization, including the objective function, initial parameters, and any constraints.
Returns
- OptimizationResult<T, TInput, TOutput>
The result of the optimization process, including the optimized parameters and performance metrics.
Remarks
This method takes input data and attempts to find the optimal parameters that minimize or maximize the objective function.
For Beginners: This is where the actual "learning" happens. The optimizer looks at your data and tries different parameter values to find the ones that make your model perform best.
The process typically involves:
- Evaluating how well the current parameters perform
- Calculating how to change the parameters to improve performance
- Updating the parameters
- Repeating until the model performs well enough or reaches a maximum number of attempts
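A minimal call might look like the following sketch; BuildInputData is a hypothetical helper standing in for however you assemble the OptimizationInputData for your problem:
// Assemble the optimization inputs (hypothetical helper).
OptimizationInputData<double, double[], double[]> inputData = BuildInputData();

// Run the distributed optimization and keep the result for inspection.
OptimizationResult<double, double[], double[]> result = distributed.Optimize(inputData);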
Reset()
Resets the optimizer state to prepare for a fresh optimization run.
public virtual void Reset()
Remarks
This method clears accumulated state including:
- Model cache (prevents retrieving solutions from previous runs)
- Fitness history (accumulated scores from previous optimizations)
- Iteration history (logs from previous runs)
- Adaptive parameters (learning rate, momentum reset to initial values)
For Beginners: Think of this like "clearing the whiteboard" before starting a new problem. When you run optimization multiple times (like during cross-validation), you want each run to start fresh without being influenced by previous runs. This method ensures that.
When to call Reset():
- Before each cross-validation fold (ensures independent fold evaluations)
- Before training the final model after cross-validation
- Any time you want to reuse an optimizer for a completely new optimization task
Why this matters:
- Prevents state contamination between independent training runs
- Ensures reproducible results regardless of how many times you've used the optimizer
- Avoids memory leaks from accumulated history
- Maintains correct adaptive learning rate dynamics
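For example, a cross-validation loop might reset the optimizer before each fold, as sketched below ('folds' is an assumed collection of per-fold OptimizationInputData):
foreach (var fold in folds)
{
    // Clear caches, histories, and adaptive parameters from the previous fold.
    distributed.Reset();

    var foldResult = distributed.Optimize(fold);
    // ... evaluate foldResult on the held-out portion of this fold ...
}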
SaveModel(string)
Saves the model to a file.
public virtual void SaveModel(string filePath)
Parameters
filePath (string): The path where the model should be saved.
Remarks
This method provides a convenient way to save the model directly to disk. It combines serialization with file I/O operations.
For Beginners: This is like clicking "Save As" in a document editor. Instead of manually calling Serialize() and then writing to a file, this method does both steps for you.
Exceptions
- IOException
Thrown when an I/O error occurs while writing to the file.
- UnauthorizedAccessException
Thrown when the caller does not have the required permission to write to the specified file path.
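A common distributed pattern is to save from the coordinator only and then have every process load the same file, as in this sketch (the synchronization step between the two parts is backend-specific and only hinted at here):
if (distributed.Rank == 0)
{
    distributed.SaveModel("checkpoints/epoch-10.bin");
}

// A real setup would wait here (e.g. a barrier) until rank 0 has finished writing.
distributed.LoadModel("checkpoints/epoch-10.bin");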
Serialize()
Converts the current state of a machine learning model into a binary format.
public abstract byte[] Serialize()
Returns
- byte[]
A byte array containing the serialized model data.
Remarks
This method captures all the essential information about a trained model and converts it into a sequence of bytes that can be stored or transmitted.
For Beginners: This is like exporting your work to a file.
When you call this method:
- The model's current state (all its learned patterns and parameters) is captured
- This information is converted into a compact binary format (bytes)
- You can then save these bytes to a file, database, or send them over a network
For example:
- After training a model to recognize cats vs. dogs in images
- You can serialize the model to save all its learned knowledge
- Later, you can use this saved data to recreate the model exactly as it was
- The recreated model will make the same predictions as the original
Think of it like taking a snapshot of your model's brain at a specific moment in time.
ShouldEarlyStop()
Determines whether the optimization process should stop early.
public virtual bool ShouldEarlyStop()
Returns
- bool
True if the optimization process should stop early; otherwise, false.
Remarks
Early stopping is a technique to prevent overfitting by stopping the optimization process before it completes all iterations if certain conditions are met.
For Beginners: This is like knowing when to stop cooking - if the model is "done" (trained well enough), this method says "stop now" instead of continuing unnecessarily.
Common reasons for early stopping include:
- The model's performance isn't improving anymore
- The model's performance on validation data is getting worse (overfitting)
- The changes in parameters are becoming very small (convergence)
Early stopping helps:
- Save computation time
- Prevent the model from becoming too specialized to the training data
- Produce models that generalize better to new data
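An epoch loop that honors consensus-based early stopping might look like this sketch ('maxEpochs' and 'inputData' are assumed to be defined elsewhere):
for (int epoch = 0; epoch < maxEpochs; epoch++)
{
    var epochResult = distributed.Optimize(inputData);

    // Because the decision is consensus-based, every rank gets the same answer
    // and all processes leave the loop together.
    if (distributed.ShouldEarlyStop())
    {
        break;
    }
}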
SynchronizeOptimizerState()
Synchronizes optimizer state (like momentum buffers) across all processes.
public abstract void SynchronizeOptimizerState()
Remarks
For Beginners: Some optimizers (like Adam) keep track of past gradients to make smarter updates. This method makes sure all processes have the same optimizer state, so they stay coordinated. It's like making sure all team members are reading from the same playbook.
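One plausible use, sketched under the assumption that state was just restored on this process, is to resynchronize immediately afterwards so no rank is left with stale buffers:
// Restore state from a previously serialized snapshot, then make sure every
// process ends up with identical optimizer state (momentum buffers, etc.).
distributed.Deserialize(snapshot);
distributed.SynchronizeOptimizerState();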
SynchronizeParameters(IFullModel<T, TInput, TOutput>?)
Synchronizes model parameters across all processes using AllReduce with averaging.
protected virtual void SynchronizeParameters(IFullModel<T, TInput, TOutput>? model)
Parameters
model (IFullModel<T, TInput, TOutput>?): The model whose parameters to synchronize.
Remarks
This method averages parameters across all processes, ensuring consistency. It's called after optimization steps to keep all processes synchronized.
For Beginners: After each process updates its model, we need to make sure everyone has the same parameters.
This method averages the parameters from all processes. For example, if GPU 0 calculated parameter value 1.0 and GPU 1 calculated 1.2, after sync both will have 1.1 (the average).
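Because the method is protected, it is called from derived classes. The sketch below shows a hypothetical helper on a derived class that applies gradients locally and then averages the result across ranks:
// Hypothetical helper inside a class derived from ShardedOptimizerBase.
protected IFullModel<double, double[], double[]> StepAndSync(
    IFullModel<double, double[], double[]> model, Vector<double> gradients)
{
    // Local update on this rank; e.g. rank 0 may produce 1.0 and rank 1 may produce 1.2.
    var updated = ApplyGradients(gradients, model);

    // AllReduce with averaging across ranks: both ranks now hold 1.1.
    SynchronizeParameters(updated);
    return updated;
}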