Interface IShardedModel<T, TInput, TOutput>
Namespace: AiDotNet.DistributedTraining
Assembly: AiDotNet.dll
Defines the contract for models that support distributed training with parameter sharding.
public interface IShardedModel<T, TInput, TOutput> : IFullModel<T, TInput, TOutput>, IModel<TInput, TOutput, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, TInput, TOutput>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, TInput, TOutput>>, IGradientComputable<T, TInput, TOutput>, IJitCompilable<T>
Type Parameters
T: The numeric type for operations.
TInput: The input type for the model.
TOutput: The output type for the model.
Remarks
For Beginners: A sharded model is like having a team working on a large puzzle together. Instead of one person holding all the puzzle pieces (parameters), each person holds only a portion. When someone needs to see the full picture, everyone shares their pieces (AllGather). When the team learns something new, everyone combines their learnings (AllReduce).
This allows training models that are too large to fit on a single GPU or machine.
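Example
The sketch below shows the typical shape of a training step against this interface, assuming a simple data-parallel loop. Only the IShardedModel members documented on this page are real; the gradient-computation and update steps are left as comments because their exact signatures are not shown here.
using System.Collections.Generic;
using AiDotNet.DistributedTraining;

public static class ShardedTrainingSketch
{
    public static void TrainEpoch<T, TInput, TOutput>(
        IShardedModel<T, TInput, TOutput> model,
        IEnumerable<(TInput Input, TOutput Target)> localBatches)
    {
        foreach (var (input, target) in localBatches)
        {
            // 1. Compute gradients on this process's slice of the data
            //    (via the inherited IGradientComputable contract; call omitted).

            // 2. AllReduce: average gradients so every process agrees.
            model.SynchronizeGradients();

            // 3. Update only the locally owned parameters; each process is
            //    responsible for its own shard (optimizer step omitted).
        }
    }
}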
Properties
LocalParameterShard
Gets the portion of parameters owned by this process.
Vector<T> LocalParameterShard { get; }
Property Value
- Vector<T>
Remarks
For Beginners: This is "your piece of the puzzle" - the parameters that this particular process is responsible for storing and updating.
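Example
A minimal sketch of reading the local shard, given an existing IShardedModel<double, TInput, TOutput> instance named model; Vector<T> is assumed to expose a Length property.
Vector<double> shard = model.LocalParameterShard;
Console.WriteLine($"Rank {model.Rank} stores {shard.Length} parameters locally.");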
Rank
Gets the rank of this process in the distributed group.
int Rank { get; }
Property Value
- int
Remarks
For Beginners: Each process has a unique ID (rank). This tells you which process you are. Rank 0 is typically the "coordinator" process.
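Example
A common pattern is to gate one-off work such as logging or saving on rank 0 so it happens exactly once across the group; model is an existing IShardedModel instance.
if (model.Rank == 0)
{
    Console.WriteLine($"Coordinator: {model.WorldSize} processes in this group.");
}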
ShardingConfiguration
Gets the configuration for this sharded model.
IShardingConfiguration<T> ShardingConfiguration { get; }
Property Value
- IShardingConfiguration<T>
WorldSize
Gets the total number of processes in the distributed group.
int WorldSize { get; }
Property Value
- int
Remarks
For Beginners: This is how many processes are working together to train the model. For example, if you have 4 GPUs, WorldSize would be 4.
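Example
A sketch of the arithmetic, assuming parameters are split evenly across processes (model is an existing instance): each rank holds roughly total / WorldSize parameters.
long totalParameters = 1_000_000;
long approximateShardSize = totalParameters / model.WorldSize; // ~250,000 when WorldSize == 4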
WrappedModel
Gets the underlying wrapped model.
IFullModel<T, TInput, TOutput> WrappedModel { get; }
Property Value
- IFullModel<T, TInput, TOutput>
Remarks
For Beginners: This is the original model that we're adding distributed training capabilities to. Think of it as the "core brain" that we're helping to work in a distributed way.
Methods
GatherFullParameters()
Gets the full set of parameters by gathering from all processes.
Vector<T> GatherFullParameters()
Returns
- Vector<T>
The complete set of parameters gathered from all processes
Remarks
This operation performs an AllGather, which requires communication across all processes.
For Beginners: This is like asking everyone to share their puzzle pieces so you can see the complete picture. It requires communication between all processes, so it's more expensive than just accessing LocalParameterShard.
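Example
A sketch of gathering the full parameters for export. Because the call communicates across all processes, it is assumed to be collective, so every rank invokes it together; here only rank 0 uses the result, and ExportModel is a hypothetical helper, not part of this interface.
Vector<double> fullParameters = model.GatherFullParameters();
if (model.Rank == 0)
{
    // ExportModel(fullParameters); // hypothetical helper
}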
SynchronizeGradients()
Synchronizes gradients across all processes using AllReduce.
void SynchronizeGradients()
Remarks
After this operation, all processes have the same averaged gradients.
For Beginners: During training, each process calculates gradients based on its portion of the data. This method combines (averages) those gradients so that everyone is learning from everyone else's experiences. It's like a team meeting where everyone shares what they learned.
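Example
A worked illustration of what the averaging means numerically, assuming WorldSize == 2.
// Suppose the two ranks computed these gradients for one parameter:
//   rank 0: 0.30    rank 1: 0.10
// After the call, both ranks hold the average, (0.30 + 0.10) / 2 = 0.20,
// so their next update steps are identical.
model.SynchronizeGradients();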