Class ZeRO3Model<T, TInput, TOutput>
- Namespace
- AiDotNet.DistributedTraining
- Assembly
- AiDotNet.dll
Implements the ZeRO Stage 3 model wrapper: full sharding of parameters, gradients, and optimizer states.
public class ZeRO3Model<T, TInput, TOutput> : FSDPModel<T, TInput, TOutput>, IShardedModel<T, TInput, TOutput>, IFullModel<T, TInput, TOutput>, IModel<TInput, TOutput, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, TInput, TOutput>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, TInput, TOutput>>, IGradientComputable<T, TInput, TOutput>, IJitCompilable<T>
Type Parameters
T: The numeric type
TInput: The input type for the model
TOutput: The output type for the model
- Inheritance
- ShardedModelBase<T, TInput, TOutput> → FSDPModel<T, TInput, TOutput> → ZeRO3Model<T, TInput, TOutput>
- Implements
- IShardedModel<T, TInput, TOutput>, IFullModel<T, TInput, TOutput>, IModel<TInput, TOutput, ModelMetadata<T>>, IParameterizable<T, TInput, TOutput>, ICloneable<IFullModel<T, TInput, TOutput>>, IGradientComputable<T, TInput, TOutput>
Remarks
Strategy Overview: ZeRO Stage 3 is the full implementation of the ZeRO optimization, sharding parameters, gradients, AND optimizer states across all processes. This is equivalent to PyTorch's FSDP (Fully Sharded Data Parallel). Parameters are gathered just-in-time for forward/backward passes and immediately released, maximizing memory efficiency.
For Beginners: ZeRO-3 is identical to FSDP - it's the ultimate memory-saving strategy. Everything is sharded: parameters, gradients, and optimizer states. Each process only holds a small piece of the model, and pieces are gathered only when absolutely needed, then immediately released.
This class is essentially an alias/wrapper for FSDPModel to maintain ZeRO naming consistency.
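The shard-then-gather idea described above can be sketched without any AiDotNet types. The following is a minimal illustration, not the library's implementation: `ShardingSketch`, `GetShard`, and `AllGather` are hypothetical names, and a plain `double[]` stands in for the model's parameter vector. It assumes the vector is split as evenly as possible, with the first ranks taking any leftover elements.

```csharp
using System;
using System.Linq;

// Hypothetical helper, not part of AiDotNet: illustrates shard-then-gather.
public static class ShardingSketch
{
    // Returns the slice of 'parameters' owned by 'rank' when the vector is
    // split as evenly as possible across 'worldSize' processes.
    public static double[] GetShard(double[] parameters, int rank, int worldSize)
    {
        if (parameters is null) throw new ArgumentNullException(nameof(parameters));
        if (worldSize <= 0) throw new ArgumentOutOfRangeException(nameof(worldSize));
        if (rank < 0 || rank >= worldSize) throw new ArgumentOutOfRangeException(nameof(rank));

        int baseSize = parameters.Length / worldSize;
        int remainder = parameters.Length % worldSize;

        // The first 'remainder' ranks each take one extra element.
        int start = rank * baseSize + Math.Min(rank, remainder);
        int length = baseSize + (rank < remainder ? 1 : 0);

        var shard = new double[length];
        Array.Copy(parameters, start, shard, 0, length);
        return shard;
    }

    // Simulates an AllGather in a single process: concatenates every rank's
    // shard, in rank order, to rebuild the full parameter vector.
    public static double[] AllGather(double[][] shards)
    {
        if (shards is null) throw new ArgumentNullException(nameof(shards));
        return shards.SelectMany(s => s).ToArray();
    }
}
```

In a real distributed run each process would hold only its own shard, and the gathered vector would be discarded as soon as the forward or backward step that needed it finished.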
Use Cases:
- Same as FSDP: training very large models
- When you prefer ZeRO terminology over FSDP
- Maximum memory efficiency
Trade-offs:
- Same as FSDP
- Memory: Excellent, everything is sharded
- Communication: Higher, with an AllGather for each forward/backward pass
- Complexity: Moderate
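To make the memory side of this trade-off concrete, here is a rough back-of-the-envelope sketch. It is an estimate under stated assumptions, not a measurement of AiDotNet: `MemoryEstimateSketch` and its methods are hypothetical, each parameter is assumed to carry one gradient value plus a caller-supplied number of optimizer values (for example, 2 for an Adam-style optimizer), and the temporary buffers used while parameters are gathered are not counted.

```csharp
using System;

// Hypothetical estimator, not part of AiDotNet.
public static class MemoryEstimateSketch
{
    // Bytes each process holds when parameters, gradients, and optimizer
    // states are all fully replicated on every process.
    public static long ReplicatedBytes(long parameterCount, int bytesPerValue, int optimizerValuesPerParameter)
    {
        // One parameter value + one gradient value + optimizer values.
        long valuesPerParameter = 1 + 1 + optimizerValuesPerParameter;
        return parameterCount * valuesPerParameter * bytesPerValue;
    }

    // Bytes each process holds persistently when all three are sharded
    // evenly across 'worldSize' processes (shard size rounded up).
    public static long ShardedBytes(long parameterCount, int bytesPerValue, int optimizerValuesPerParameter, int worldSize)
    {
        if (worldSize <= 0) throw new ArgumentOutOfRangeException(nameof(worldSize));

        long shardCount = (parameterCount + worldSize - 1) / worldSize; // ceiling division
        return ReplicatedBytes(shardCount, bytesPerValue, optimizerValuesPerParameter);
    }
}
```

For 1,000,000 `double` parameters (8 bytes each) with 2 optimizer values per parameter, this gives 32,000,000 bytes per process when replicated and 8,000,000 bytes per process when sharded across 4 processes.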
Example:
var model = new NeuralNetworkModel<double>(...);
var backend = new InMemoryCommunicationBackend<double>(rank: 0, worldSize: 4);
var config = new ShardingConfiguration<double>(backend);
// ZeRO-3 and FSDP are equivalent
var zero3Model = new ZeRO3Model<double, Tensor<double>, Tensor<double>>(model, config);
// Or equivalently:
// var fsdpModel = new FSDPModel<double, Tensor<double>, Tensor<double>>(model, config);
Constructors
ZeRO3Model(IFullModel<T, TInput, TOutput>, IShardingConfiguration<T>)
Creates a new ZeRO-3 model wrapping an existing model.
public ZeRO3Model(IFullModel<T, TInput, TOutput> wrappedModel, IShardingConfiguration<T> config)
Parameters
wrappedModel (IFullModel<T, TInput, TOutput>): The model to wrap with ZeRO-3 capabilities
config (IShardingConfiguration<T>): Configuration for sharding and communication
Remarks
For Beginners: ZeRO-3 is the same as FSDP, just different terminology. Use whichever name you prefer. This constructor delegates to FSDPModel for all functionality.
Exceptions
- ArgumentNullException
Thrown if wrappedModel or config is null
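The alias-by-inheritance pattern and the null checks described for this constructor might look roughly like the sketch below. The types `BaseWrapperSketch` and `AliasWrapperSketch` are hypothetical stand-ins, and `object` replaces the real model and configuration interfaces. Where the actual library performs its null checks (base class or derived class) is not stated in this page, so their placement in the base constructor is an assumption.

```csharp
using System;

// Hypothetical stand-in for the base wrapper; not the AiDotNet FSDPModel.
public class BaseWrapperSketch
{
    public object WrappedModel { get; }
    public object Config { get; }

    public BaseWrapperSketch(object wrappedModel, object config)
    {
        // Reject null arguments up front and report which one was null.
        WrappedModel = wrappedModel ?? throw new ArgumentNullException(nameof(wrappedModel));
        Config = config ?? throw new ArgumentNullException(nameof(config));
    }
}

// Hypothetical alias type: adds a second name, no new behavior.
public class AliasWrapperSketch : BaseWrapperSketch
{
    public AliasWrapperSketch(object wrappedModel, object config)
        : base(wrappedModel, config) // all work is delegated to the base constructor
    {
    }
}
```

Because the derived constructor only forwards its arguments, constructing either type with the same inputs yields the same state, which is the sense in which one name is an alias for the other.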
Methods
Clone()
Creates a shallow copy of this object.
public override IFullModel<T, TInput, TOutput> Clone()
Returns
- IFullModel<T, TInput, TOutput>
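For readers unfamiliar with the term, the sketch below shows what a shallow copy means in .NET in general. `ShallowCopySketch` is a hypothetical class, not an AiDotNet type, and this page does not say which members of ZeRO3Model end up shared between the original and its clone.

```csharp
using System;

// Hypothetical type used only to demonstrate shallow-copy semantics.
public class ShallowCopySketch
{
    public double[] Parameters { get; set; } = Array.Empty<double>();
    public int Rank { get; set; }

    // MemberwiseClone copies value-type fields and copies references
    // (not the objects they point to) for reference-type fields.
    public ShallowCopySketch ShallowClone() => (ShallowCopySketch)MemberwiseClone();
}
```

After a shallow clone, changing `Rank` on the copy leaves the original untouched, while writing into the copy's `Parameters` array is visible through the original, because both objects reference the same array.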
GetModelMetadata()
Retrieves metadata and performance metrics about the trained model.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
An object containing metadata and performance metrics about the trained model.
Remarks
This method provides information about the model's structure, parameters, and performance metrics.
For Beginners: Model metadata is like a report card for your machine learning model.
Just as a report card shows how well a student is performing in different subjects, model metadata shows how well your model is performing and provides details about its structure.
This information typically includes:
- Accuracy measures: How well do the model's predictions match actual values?
- Error metrics: How far off are the model's predictions on average?
- Model parameters: What patterns did the model learn from the data?
- Training information: How long did training take? How many iterations were needed?
For example, in a house price prediction model, metadata might include:
- Average prediction error (e.g., off by $15,000 on average)
- How strongly each feature (bedrooms, location) influences the prediction
- How well the model fits the training data
This information helps you understand your model's strengths and weaknesses, and decide if it's ready to use or needs more training.
WithParameters(Vector<T>)
Creates a new instance with the specified parameters.
public override IFullModel<T, TInput, TOutput> WithParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>)
Returns
- IFullModel<T, TInput, TOutput>
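The "new instance with the specified parameters" pattern can be sketched as follows. `LinearModelSketch` is a hypothetical class and `double[]` stands in for `Vector<T>`. That the original instance is left unchanged is an assumption of this sketch, not something this page states about ZeRO3Model.

```csharp
using System;

// Hypothetical model type illustrating a non-mutating WithParameters.
public sealed class LinearModelSketch
{
    private readonly double[] _parameters;

    public LinearModelSketch(double[] parameters)
    {
        if (parameters is null) throw new ArgumentNullException(nameof(parameters));
        // Defensive copy so later changes to the caller's array do not leak in.
        _parameters = (double[])parameters.Clone();
    }

    // Returns a copy so callers cannot modify internal state.
    public double[] GetParameters() => (double[])_parameters.Clone();

    // Builds a new instance; the current instance keeps its own parameters.
    public LinearModelSketch WithParameters(double[] parameters) => new LinearModelSketch(parameters);
}
```

This style is convenient when an optimizer wants to evaluate several candidate parameter vectors without disturbing the model it started from.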