Class ElasticOptimizer<T, TInput, TOutput>
- Namespace: AiDotNet.DistributedTraining
- Assembly: AiDotNet.dll
Elastic optimizer that supports dynamic worker addition and removal during training.
public class ElasticOptimizer<T, TInput, TOutput> : ShardedOptimizerBase<T, TInput, TOutput>, IShardedOptimizer<T, TInput, TOutput>, IOptimizer<T, TInput, TOutput>, IModelSerializer
Type Parameters
- T: The numeric type
- TInput: The input type for the model
- TOutput: The output type for the model
- Inheritance
- ShardedOptimizerBase<T, TInput, TOutput> → ElasticOptimizer<T, TInput, TOutput>
- Implements
- IShardedOptimizer<T, TInput, TOutput>, IOptimizer<T, TInput, TOutput>, IModelSerializer
Remarks
Strategy Overview: Elastic training (TorchElastic, Horovod Elastic) enables dynamic scaling of workers during training. Workers can be added or removed without stopping the training job, supporting:
- Fault tolerance: Replace failed workers automatically
- Auto-scaling: Add workers during peak hours, remove during off-peak
- Spot instance usage: Tolerate preemptions and use cheaper compute
When world size changes, the optimizer handles re-sharding parameters and optimizer states across the new worker set. This requires checkpointing and careful state management.
For Beginners: Elastic training is like having a flexible team size. Workers can join or leave during training without stopping everything:
Scenario 1 - Fault tolerance:
- Start with 8 GPUs training your model
- GPU 3 fails → automatically detected
- Training continues with 7 GPUs (parameters redistributed)
- New GPU joins → training scales back to 8 GPUs
Scenario 2 - Cloud cost optimization:
- Use cheap "spot instances" that can be taken away anytime
- When instance is preempted, training continues with remaining workers
- New instance joins when available
This is critical for long training jobs where failures are expected.
Use Cases:
- Long training jobs (days/weeks) where failures will occur
- Cloud training with spot/preemptible instances (save 60-90% on cost)
- Auto-scaling based on load or time of day
- Fault tolerance for production training pipelines
Trade-offs:
- Memory: Must handle dynamic re-sharding
- Communication: Overhead during worker changes (re-sharding, sync)
- Complexity: Very high - requires membership management and state re-distribution
- Convergence: Learning rate scheduling must account for the dynamic world size
- Best for: Long jobs, cost-sensitive scenarios, production ML pipelines
- Limitation: Worker changes create a temporary slowdown during re-sharding
Implementation Note: This framework provides the elastic optimizer infrastructure. Full production deployment additionally requires:
1. A membership/discovery service (etcd, ZooKeeper, or cloud-native)
2. Automatic checkpointing before worker changes
3. State re-sharding algorithms
4. A rendezvous mechanism for worker coordination
This implementation demonstrates the elastic pattern.
Example:
var optimizer = new AdamOptimizer<double, Tensor<double>, Tensor<double>>(model, options);
var backend = new InMemoryCommunicationBackend<double>(rank: 0, worldSize: 4);
var config = new ShardingConfiguration<double>(backend);
var elasticOptimizer = new ElasticOptimizer<double, Tensor<double>, Tensor<double>>(
    optimizer, config,
    minWorkers: 2,   // Can run with as few as 2 workers
    maxWorkers: 16); // Can scale up to 16 workers
// Training continues through worker changes:
// 4 workers → 3 workers (one fails) → 5 workers (two join) → ...
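As a hedged sketch of how the pieces above fit together (the surrounding training loop, the checkpoint storage, and the inputData variable are assumptions, not part of this API), a worker change is typically bracketed by a snapshot and a state re-sync using only the members documented on this page:
// Hedged sketch: inputData and the surrounding loop are assumed to exist elsewhere.
byte[] checkpoint = elasticOptimizer.Serialize();   // snapshot before a possible worker change
var result = elasticOptimizer.Optimize(inputData);  // run one optimization pass
// If a worker joined or left during the pass, restore the snapshot and re-align state:
elasticOptimizer.Deserialize(checkpoint);
elasticOptimizer.SynchronizeOptimizerState();
Console.WriteLine($"Continuing with {elasticOptimizer.CurrentWorkers} workers");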
Constructors
ElasticOptimizer(IOptimizer<T, TInput, TOutput>, IShardingConfiguration<T>, int, int)
Creates an elastic optimizer.
public ElasticOptimizer(IOptimizer<T, TInput, TOutput> wrappedOptimizer, IShardingConfiguration<T> config, int minWorkers = 1, int maxWorkers = 1024)
Parameters
- wrappedOptimizer (IOptimizer<T, TInput, TOutput>): The optimizer to wrap with elastic capabilities
- config (IShardingConfiguration<T>): Configuration for sharding and communication
- minWorkers (int): Minimum number of workers (default: 1)
- maxWorkers (int): Maximum number of workers (default: 1024)
Properties
CanScaleDown
Gets whether the optimizer can tolerate losing workers.
public bool CanScaleDown { get; }
Property Value
- bool
CanScaleUp
Gets whether the optimizer can accept more workers.
public bool CanScaleUp { get; }
Property Value
- bool
CurrentWorkers
Gets the current number of active workers.
public int CurrentWorkers { get; }
Property Value
- int
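A minimal sketch of reading these properties; the scale-up/scale-down decisions around them are assumptions about your own orchestration code, not part of this API:
// Hedged sketch: the orchestration logic reacting to these flags is an assumption.
Console.WriteLine($"Active workers: {elasticOptimizer.CurrentWorkers}");
if (elasticOptimizer.CanScaleUp)
{
    // Below maxWorkers: a new worker could be admitted here.
}
if (!elasticOptimizer.CanScaleDown)
{
    // Already at minWorkers: losing another worker would halt training.
}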
Methods
Deserialize(byte[])
Loads a previously serialized model from binary data.
public override void Deserialize(byte[] data)
Parameters
- data (byte[]): The byte array containing the serialized model data.
Remarks
This method takes binary data created by the Serialize method and uses it to restore a model to its previous state.
For Beginners: This is like opening a saved file to continue your work.
When you call this method:
- You provide the binary data (bytes) that was previously created by Serialize
- The model rebuilds itself using this data
- After deserializing, the model is exactly as it was when serialized
- It's ready to make predictions without needing to be trained again
For example:
- You download a pre-trained model file for detecting spam emails
- You deserialize this file into your application
- Immediately, your application can detect spam without any training
- The model has all the knowledge that was built into it by its original creator
This is particularly useful when:
- You want to use a model that took days to train
- You need to deploy the same model across multiple devices
- You're creating an application that non-technical users will use
Think of it like installing the brain of a trained expert directly into your application.
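A minimal sketch, assuming the bytes were produced earlier by Serialize() and saved to a file; the path and the use of System.IO are illustrative:
// Hedged sketch: "elastic-checkpoint.bin" is an illustrative path.
byte[] data = File.ReadAllBytes("elastic-checkpoint.bin");
elasticOptimizer.Deserialize(data);   // the optimizer is now in the saved state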
Optimize(OptimizationInputData<T, TInput, TOutput>)
Performs the optimization process to find the best parameters for a model.
public override OptimizationResult<T, TInput, TOutput> Optimize(OptimizationInputData<T, TInput, TOutput> inputData)
Parameters
- inputData (OptimizationInputData<T, TInput, TOutput>): The data needed for optimization, including the objective function, initial parameters, and any constraints.
Returns
- OptimizationResult<T, TInput, TOutput>
The result of the optimization process, including the optimized parameters and performance metrics.
Remarks
This method takes input data and attempts to find the optimal parameters that minimize or maximize the objective function.
For Beginners: This is where the actual "learning" happens. The optimizer looks at your data and tries different parameter values to find the ones that make your model perform best.
The process typically involves:
- Evaluating how well the current parameters perform
- Calculating how to change the parameters to improve performance
- Updating the parameters
- Repeating until the model performs well enough or reaches a maximum number of attempts
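A hedged sketch of a single call; constructing inputData (objective function, initial parameters, constraints) is assumed to happen elsewhere and is not shown:
// Hedged sketch: inputData is assumed to be built elsewhere.
OptimizationResult<double, Tensor<double>, Tensor<double>> result =
    elasticOptimizer.Optimize(inputData);
// result carries the optimized parameters and performance metrics described above.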
Serialize()
Converts the current state of a machine learning model into a binary format.
public override byte[] Serialize()
Returns
- byte[]
A byte array containing the serialized model data.
Remarks
This method captures all the essential information about a trained model and converts it into a sequence of bytes that can be stored or transmitted.
For Beginners: This is like exporting your work to a file.
When you call this method:
- The model's current state (all its learned patterns and parameters) is captured
- This information is converted into a compact binary format (bytes)
- You can then save these bytes to a file, database, or send them over a network
For example:
- After training a model to recognize cats vs. dogs in images
- You can serialize the model to save all its learned knowledge
- Later, you can use this saved data to recreate the model exactly as it was
- The recreated model will make the same predictions as the original
Think of it like taking a snapshot of your model's brain at a specific moment in time.
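A minimal sketch of persisting the snapshot to disk; the file path and the use of System.IO are illustrative assumptions:
// Hedged sketch: "elastic-checkpoint.bin" is an illustrative path.
byte[] snapshot = elasticOptimizer.Serialize();
File.WriteAllBytes("elastic-checkpoint.bin", snapshot);   // reload later with Deserialize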
SynchronizeOptimizerState()
Synchronizes optimizer state (like momentum buffers) across all processes.
public override void SynchronizeOptimizerState()
Remarks
For Beginners: Some optimizers (like Adam) keep track of past gradients to make smarter updates. This method makes sure all processes have the same optimizer state, so they stay coordinated. It's like making sure all team members are reading from the same playbook.
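A minimal sketch; calling this right after the worker set changes (detecting that change is assumed to come from your orchestration layer) keeps momentum-style state aligned:
// Hedged sketch: invoked after a worker joins or leaves so state such as
// Adam's moment estimates matches on every process before the next step.
elasticOptimizer.SynchronizeOptimizerState();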