Class ZeRO2Model<T, TInput, TOutput>
Namespace: AiDotNet.DistributedTraining
Assembly: AiDotNet.dll
Implements a ZeRO Stage 2 model wrapper that shards optimizer states and gradients across processes.
public class ZeRO2Model<T, TInput, TOutput> : ShardedModelBase<T, TInput, TOutput>, IShardedModel<T, TInput, TOutput>, IFullModel<T, TInput, TOutput>, IModel<TInput, TOutput, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, TInput, TOutput>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, TInput, TOutput>>, IGradientComputable<T, TInput, TOutput>, IJitCompilable<T>
Type Parameters
T: The numeric type.
TInput: The input type for the model.
TOutput: The output type for the model.
- Inheritance
  - ShardedModelBase<T, TInput, TOutput>
  - ZeRO2Model<T, TInput, TOutput>
- Implements
  - IShardedModel<T, TInput, TOutput>
  - IFullModel<T, TInput, TOutput>
  - IModel<TInput, TOutput, ModelMetadata<T>>
  - IParameterizable<T, TInput, TOutput>
  - ICloneable<IFullModel<T, TInput, TOutput>>
  - IGradientComputable<T, TInput, TOutput>
Remarks
Strategy Overview: ZeRO Stage 2 builds on ZeRO-1 by additionally sharding gradients across processes. Parameters are still replicated for the forward pass, but gradients are reduced and scattered (ReduceScatter) so each process only stores a portion. This saves significant memory compared to ZeRO-1, especially for large models.
For Beginners: This implements ZeRO Stage 2, which saves even more memory than ZeRO-1. The model parameters are still fully replicated (like DDP and ZeRO-1), but now both the optimizer state AND the gradients are split across processes. After computing gradients, they're immediately reduced and scattered so each process only keeps its portion.
Think of it like a team where everyone has the full playbook (parameters), but when taking notes during practice (gradients), they divide up the note-taking so each person is responsible for recording only certain plays. This saves everyone from having to write everything down.
Use Cases:
- Larger models where gradient memory becomes significant
- Workloads that need substantial memory savings at a moderate communication cost
- Preparing for a ZeRO-3/FSDP migration
Trade-offs:
- Memory: Very good; saves both optimizer states and gradients
- Communication: Moderate; uses ReduceScatter instead of AllReduce
- Complexity: Moderate; gradient sharding adds some complexity
- Best for: Large models where gradient memory is significant
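The memory trade-off can be estimated with simple arithmetic. The C# sketch below counts the values a single rank keeps resident for a model with psi parameters on N ranks. It assumes that parameters, gradients, and optimizer state all use the same numeric type T and that the optimizer keeps k state values per parameter (k = 2 for Adam-style moment estimates). The helper names are hypothetical, and the estimate ignores activations, communication buffers, and the temporary full-size gradient that exists before the ReduceScatter.

```csharp
using System;

// Largest shard when psi elements are split over n ranks and the remainder
// goes to the first ranks: ceil(psi / n).
static long MaxShard(long psi, long n) => (psi + n - 1) / n;

// Per-rank element counts. Assumption: parameters, gradients, and optimizer
// state share the numeric type T; the optimizer keeps k values per parameter.
static long DdpElements(long psi, long k) =>
    psi + psi + k * psi;                                   // everything replicated
static long Zero1Elements(long psi, long n, long k) =>
    psi + psi + k * MaxShard(psi, n);                      // optimizer state sharded
static long Zero2Elements(long psi, long n, long k) =>
    psi + MaxShard(psi, n) + k * MaxShard(psi, n);         // gradients sharded as well

const long Psi = 1_000_000; // parameter count
const long N = 4;           // world size
const long K = 2;           // optimizer state values per parameter (Adam-style)

Console.WriteLine($"DDP:    {DdpElements(Psi, K)}");
Console.WriteLine($"ZeRO-1: {Zero1Elements(Psi, N, K)}");
Console.WriteLine($"ZeRO-2: {Zero2Elements(Psi, N, K)}");
```

With one million parameters on four ranks this prints 4000000, 2500000, and 1750000 for DDP, ZeRO-1, and ZeRO-2. The ZeRO-2 saving over ZeRO-1 is exactly the gradient term shrinking from psi to psi / N, while the replicated parameters stay at full size.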
Example:
var model = new NeuralNetworkModel<double>(...);
var backend = new InMemoryCommunicationBackend<double>(rank: 0, worldSize: 4);
var config = new ShardingConfiguration<double>(backend);
var zero2Model = new ZeRO2Model<double, Tensor<double>, Tensor<double>>(model, config);
// Use with ZeRO2Optimizer for full ZeRO-2 benefits
// ('optimizer' is an existing optimizer instance created elsewhere; not shown)
var zero2Optimizer = new ZeRO2Optimizer<double, Tensor<double>, Tensor<double>>(optimizer, config);
Constructors
ZeRO2Model(IFullModel<T, TInput, TOutput>, IShardingConfiguration<T>)
public ZeRO2Model(IFullModel<T, TInput, TOutput> wrappedModel, IShardingConfiguration<T> config)
Parameters
wrappedModel (IFullModel<T, TInput, TOutput>): The model to wrap.
config (IShardingConfiguration<T>): The sharding configuration, including the communication backend.
Properties
ParameterDeltaShard
Gets the local parameter delta shard for this rank after synchronization.
public Vector<T>? ParameterDeltaShard { get; }
Property Value
- Vector<T>
Remarks
DEPRECATED: This property is no longer used. ZeRO2Model now uses true gradient semantics via IFullModel.ComputeGradients() which properly separates gradient computation from parameter updates. The implementation stores true gradients (not parameter deltas) in the internal _gradientShard field.
In the current implementation:
1. ComputeGradients() computes true gradients via backpropagation without modifying parameters.
2. Gradients are sharded via ReduceScatter so each rank stores only its portion.
3. Each rank updates only its parameter shard using its gradient shard.
4. All ranks perform AllGather to reconstruct the full updated parameter vector.
This property always returns null in the current implementation. For gradient access, use the Train() workflow, which manages _gradientShard internally.
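Step 3 of the workflow works because the update is applied element by element, so updating disjoint shards independently yields the same vector as one full update. The sketch below checks this in a single process. The plain SGD rule, the learning rate, and the `UpdateShard` helper are assumptions for illustration, not ZeRO2Model's code; writing every rank's slice into one shared array stands in for the AllGather of step 4.

```csharp
using System;
using System.Linq;

// Step 3 in isolation: a rank touches only the slice it owns, using only its
// gradient shard. Assumption: plain SGD (p -= lr * g); the real rule comes
// from the wrapped optimizer.
static void UpdateShard(double[] parameters, double[] gradientShard, int start, double lr)
{
    for (int i = 0; i < gradientShard.Length; i++)
        parameters[start + i] -= lr * gradientShard[i];
}

// 10 parameters over 3 ranks: shards start at 0, 4, 7 with sizes 4, 3, 3
// (the layout described under InitializeSharding).
int[] starts = { 0, 4, 7 };
int[] sizes = { 4, 3, 3 };

double[] initial = Enumerable.Range(0, 10).Select(i => (double)i).ToArray();
double[] reducedGradient = Enumerable.Repeat(1.0, 10).ToArray(); // already reduced across ranks
const double Lr = 0.5;

// Sharded: each "rank" updates only its own slice of a shared copy.
double[] sharded = (double[])initial.Clone();
for (int r = 0; r < starts.Length; r++)
{
    double[] gradShard = reducedGradient.Skip(starts[r]).Take(sizes[r]).ToArray();
    UpdateShard(sharded, gradShard, starts[r], Lr);
}

// Unsharded reference: one full SGD step over the whole vector.
double[] full = initial.Select((p, i) => p - Lr * reducedGradient[i]).ToArray();

Console.WriteLine(sharded.SequenceEqual(full)); // True
```

The same argument covers Adam-style optimizers, whose per-parameter state can live alongside the shard. Update rules that need a global quantity, such as clipping by the full gradient norm, would need an extra collective before step 3.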
Methods
AllGatherParameterShards()
Reconstructs full parameters by gathering parameter shards from all ranks.
public Vector<T> AllGatherParameterShards()
Returns
- Vector<T>
Full parameter vector reconstructed from all ranks' shards
Remarks
In ZeRO-2, after the optimizer updates each rank's parameter shard, we need to AllGather all shards to reconstruct the full parameter vector for the next forward pass. This ensures all ranks have identical synchronized parameters.
This method is primarily used when integrating with ZeRO2Optimizer, where each rank updates only its portion of the parameter vector. AllGather collects these disjoint updated shards and concatenates them to form the complete parameter vector.
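A minimal in-process sketch of what AllGather produces: each rank contributes its updated shard, possibly of a different size, and every rank receives the shards concatenated in rank order. Because shards are contiguous ranges assigned in rank order, concatenation puts every parameter back at its original index. The `AllGather` helper is hypothetical and stands in for the communication backend; it does not show the library's actual call.

```csharp
using System;
using System.Linq;

// Stand-in for AllGather over parameter shards: concatenate every rank's
// shard in rank order. In a real run each rank would receive this result.
static double[] AllGather(double[][] shardsByRank)
{
    var full = new double[shardsByRank.Sum(s => s.Length)];
    int offset = 0;
    foreach (double[] shard in shardsByRank) // rank 0, 1, 2, ...
    {
        Array.Copy(shard, 0, full, offset, shard.Length);
        offset += shard.Length;
    }
    return full;
}

// Updated shards from 3 ranks for a 10-element parameter vector (sizes 4, 3, 3).
double[][] shards =
{
    new[] { 0.0, 0.1, 0.2, 0.3 },
    new[] { 0.4, 0.5, 0.6 },
    new[] { 0.7, 0.8, 0.9 },
};

double[] parameters = AllGather(shards);
Console.WriteLine(parameters.Length); // 10
```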
Clone()
Creates a shallow copy of this object.
public override IFullModel<T, TInput, TOutput> Clone()
Returns
- IFullModel<T, TInput, TOutput>
Deserialize(byte[])
Loads a previously serialized model from binary data.
public override void Deserialize(byte[] data)
Parameters
data (byte[]): The byte array containing the serialized model data.
Remarks
This method takes binary data created by the Serialize method and uses it to restore a model to its previous state.
For Beginners: This is like opening a saved file to continue your work.
When you call this method:
- You provide the binary data (bytes) that was previously created by Serialize
- The model rebuilds itself using this data
- After deserializing, the model is exactly as it was when serialized
- It's ready to make predictions without needing to be trained again
For example:
- You download a pre-trained model file for detecting spam emails
- You deserialize this file into your application
- Immediately, your application can detect spam without any training
- The model has all the knowledge that was built into it by its original creator
This is particularly useful when:
- You want to use a model that took days to train
- You need to deploy the same model across multiple devices
- You're creating an application that non-technical users will use
Think of it like installing the brain of a trained expert directly into your application.
GetModelMetadata()
Retrieves metadata and performance metrics about the trained model.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
An object containing metadata and performance metrics about the trained model.
Remarks
This method provides information about the model's structure, parameters, and performance metrics.
For Beginners: Model metadata is like a report card for your machine learning model.
Just as a report card shows how well a student is performing in different subjects, model metadata shows how well your model is performing and provides details about its structure.
This information typically includes:
- Accuracy measures: How well do the model's predictions match actual values?
- Error metrics: How far off are the model's predictions on average?
- Model parameters: What patterns did the model learn from the data?
- Training information: How long did training take? How many iterations were needed?
For example, in a house price prediction model, metadata might include:
- Average prediction error (e.g., off by $15,000 on average)
- How strongly each feature (bedrooms, location) influences the prediction
- How well the model fits the training data
This information helps you understand your model's strengths and weaknesses, and decide if it's ready to use or needs more training.
InitializeSharding()
Initializes parameter sharding by dividing parameters across processes.
protected override void InitializeSharding()
Remarks
This method calculates how to distribute parameters evenly across all processes, with remainder parameters distributed to the first few processes. Derived classes can override this to implement different sharding strategies.
For Beginners: This splits the model's parameters across all processes.
Think of it like dividing a deck of cards among players. If you have 10 parameters and 3 processes:
- Process 0 gets parameters 0-3 (4 parameters)
- Process 1 gets parameters 4-6 (3 parameters)
- Process 2 gets parameters 7-9 (3 parameters)
We try to split evenly, but if there's a remainder, the first processes get one extra parameter each.
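The remainder-first split can be computed in closed form. This sketch uses a hypothetical `ShardRange` helper (not an AiDotNet API) and reproduces the 10-parameter, 3-process layout above; the library's own InitializeSharding may compute its ranges differently internally.

```csharp
using System;

// Contiguous shard for `rank` when `total` parameters are split across
// `worldSize` processes; the first (total % worldSize) ranks get one extra.
static (int Start, int Count) ShardRange(int total, int worldSize, int rank)
{
    int baseSize = total / worldSize;
    int remainder = total % worldSize;
    int count = baseSize + (rank < remainder ? 1 : 0);
    int start = rank * baseSize + Math.Min(rank, remainder);
    return (start, count);
}

for (int r = 0; r < 3; r++)
{
    var (first, size) = ShardRange(10, 3, r);
    Console.WriteLine($"Process {r}: parameters {first}-{first + size - 1} ({size} parameters)");
}
```

This prints "Process 0: parameters 0-3 (4 parameters)", then 4-6 and 7-9 with 3 parameters each. Rank r starts after r base-size shards plus one extra element for every earlier rank that received a remainder element, which is what Math.Min(rank, remainder) accounts for.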
LoadModel(string)
Loads the model from a file.
public override void LoadModel(string filePath)
Parameters
filePath (string): The path to the file containing the saved model.
Remarks
This method provides a convenient way to load a model directly from disk. It combines file I/O operations with deserialization.
For Beginners: This is like clicking "Open" in a document editor. Instead of manually reading from a file and then calling Deserialize(), this method does both steps for you.
Exceptions
- FileNotFoundException
Thrown when the specified file does not exist.
- IOException
Thrown when an I/O error occurs while reading from the file or when the file contains corrupted or invalid model data.
Predict(TInput)
Uses the trained model to make predictions for new input data.
public override TOutput Predict(TInput input)
Parameters
input (TInput): The new input data to predict on (for example, a matrix where each row is an example and each column is a feature).
Returns
- TOutput
The model's predictions for the input, of type TOutput (for example, a vector with one predicted value per input example).
Remarks
After training, this method applies the learned patterns to new data to predict outcomes.
For Beginners: Prediction is when the model uses what it learned to make educated guesses about new information.
Using the fruit identification example from Train(TInput, TOutput):
- After learning from many examples, the child (model) can now identify new fruits they haven't seen before
- They look at the color, shape, and size to make their best guess
In machine learning:
- You give the model new data it hasn't seen during training
- The model applies the patterns it learned to make predictions
- The output is the model's best estimate based on its training
For example, in a house price prediction model:
- You provide features of a new house (square footage, bedrooms, location)
- The model predicts what price that house might sell for
This method is used after training is complete, when you want to apply your model to real-world data.
SaveModel(string)
Saves the model to a file.
public override void SaveModel(string filePath)
Parameters
filePath (string): The path where the model should be saved.
Remarks
This method provides a convenient way to save the model directly to disk. It combines serialization with file I/O operations.
For Beginners: This is like clicking "Save As" in a document editor. Instead of manually calling Serialize() and then writing to a file, this method does both steps for you.
Exceptions
- IOException
Thrown when an I/O error occurs while writing to the file.
- UnauthorizedAccessException
Thrown when the caller does not have the required permission to write to the specified file path.
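The SaveModel/LoadModel pattern (serialize, then write the bytes; read the bytes, then deserialize) can be sketched with standard .NET file APIs. The `Serialize` and `Deserialize` functions below are toy stand-ins over a plain double[] using a hypothetical length-prefixed layout; they do not reproduce ZeRO2Model's serialization format.

```csharp
using System;
using System.IO;

// Toy Serialize: int length prefix followed by each value as a double.
static byte[] Serialize(double[] parameters)
{
    var stream = new MemoryStream();
    using (var writer = new BinaryWriter(stream))
    {
        writer.Write(parameters.Length);
        foreach (double p in parameters) writer.Write(p);
    }
    return stream.ToArray(); // ToArray still works after the writer closes the stream
}

// Toy Deserialize: reads back the same layout.
static double[] Deserialize(byte[] data)
{
    using var reader = new BinaryReader(new MemoryStream(data));
    var parameters = new double[reader.ReadInt32()];
    for (int i = 0; i < parameters.Length; i++) parameters[i] = reader.ReadDouble();
    return parameters;
}

// Save = serialize + write; File.WriteAllBytes can throw IOException or
// UnauthorizedAccessException, as listed for SaveModel.
static void SaveModel(string filePath, double[] parameters) =>
    File.WriteAllBytes(filePath, Serialize(parameters));

// Load = read + deserialize; File.ReadAllBytes throws FileNotFoundException
// for a missing file, as listed for LoadModel.
static double[] LoadModel(string filePath) =>
    Deserialize(File.ReadAllBytes(filePath));

string path = Path.Combine(Path.GetTempPath(), "zero2-demo.bin");
double[] original = { 1.5, -2.0, 3.25 };
SaveModel(path, original);
double[] restored = LoadModel(path);
File.Delete(path);

Console.WriteLine(restored.Length == original.Length); // True
```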
Serialize()
Converts the current state of a machine learning model into a binary format.
public override byte[] Serialize()
Returns
- byte[]
A byte array containing the serialized model data.
Remarks
This method captures all the essential information about a trained model and converts it into a sequence of bytes that can be stored or transmitted.
For Beginners: This is like exporting your work to a file.
When you call this method:
- The model's current state (all its learned patterns and parameters) is captured
- This information is converted into a compact binary format (bytes)
- You can then save these bytes to a file, database, or send them over a network
For example:
- After training a model to recognize cats vs. dogs in images
- You can serialize the model to save all its learned knowledge
- Later, you can use this saved data to recreate the model exactly as it was
- The recreated model will make the same predictions as the original
Think of it like taking a snapshot of your model's brain at a specific moment in time.
SynchronizeGradients()
Synchronizes gradients using ReduceScatter - each process gets its shard of reduced gradients.
public override void SynchronizeGradients()
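The result of ReduceScatter can be reproduced in a single process: sum the ranks' full-length local gradients element-wise, then give each rank only its contiguous slice of the sum. This sketch, built around a hypothetical `ReduceScatter` helper, shows the semantics only; a real backend typically pipelines the reduction (for example around a ring) so that no rank materializes the whole sum. This documentation does not say whether ZeRO2Model sums or averages across ranks, so the sketch simply sums.

```csharp
using System;

// Stand-in for ReduceScatter: element-wise sum of every rank's local
// gradient, after which rank r keeps only [starts[r], starts[r] + sizes[r]).
static double[][] ReduceScatter(double[][] localGradients, int[] starts, int[] sizes)
{
    int length = localGradients[0].Length;
    var reduced = new double[length];
    foreach (double[] g in localGradients)
        for (int i = 0; i < length; i++) reduced[i] += g[i];

    var shards = new double[starts.Length][];
    for (int r = 0; r < starts.Length; r++)
    {
        shards[r] = new double[sizes[r]];
        Array.Copy(reduced, starts[r], shards[r], 0, sizes[r]);
    }
    return shards;
}

// 3 ranks, 10 gradient elements; rank r's local gradient is all (r + 1).
double[][] perRank = new double[3][];
for (int r = 0; r < 3; r++)
{
    perRank[r] = new double[10];
    Array.Fill(perRank[r], r + 1.0);
}

double[][] myShards = ReduceScatter(perRank, new[] { 0, 4, 7 }, new[] { 4, 3, 3 });
Console.WriteLine(myShards[1].Length); // 3
```

Every element of every shard equals 1 + 2 + 3 = 6, but rank 1 stores only 3 of the 10 summed values. With AllReduce each rank would hold all 10, which is the gradient-memory difference between ZeRO-2 and ZeRO-1.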
Train(TInput, TOutput)
Trains the model using input features and their corresponding target values.
public override void Train(TInput input, TOutput expectedOutput)
Parameters
input (TInput): The input features.
expectedOutput (TOutput): The target values the model should learn to predict.
Remarks
This method takes training data and adjusts the model's internal parameters to learn patterns in the data.
For Beginners: Training is like teaching the model by showing it examples.
Imagine teaching a child to identify fruits:
- You show them many examples of apples, oranges, and bananas (the input features)
- You tell them the correct name for each fruit (the expected outputs)
- Over time, they learn to recognize the patterns that distinguish each fruit
In machine learning:
- The input parameter contains features (characteristics) of your data
- The expectedOutput parameter contains the correct answers you want the model to learn
- During training, the model adjusts its internal calculations to get better at predicting expectedOutput from input
For example, in a house price prediction model:
- input would contain features like square footage, number of bedrooms, and location
- expectedOutput would contain the actual sale prices of those houses
WithParameters(Vector<T>)
Creates a new instance with the specified parameters.
public override IFullModel<T, TInput, TOutput> WithParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): The parameter vector for the new instance.
Returns
- IFullModel<T, TInput, TOutput>