This document outlines all concrete implementations that should be created for the distributed training framework, based on industry standards and real-world scenarios.
ICommunicationBackend<T>
↓
CommunicationBackendBase<T> (abstract)
↓
├── InMemoryCommunicationBackend<T> (for testing)
├── MPICommunicationBackend<T> (MPI.NET for production)
├── NCCLCommunicationBackend<T> (NVIDIA GPUs)
└── GlooCommunicationBackend<T> (CPU-based)
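For reference, a minimal sketch of the contract these backends would implement, inferred from the calls used elsewhere in this document (AllReduce and ReduceScatter with ReductionOperation); the Rank/WorldSize properties and the AllGather, Broadcast, and Barrier methods are assumptions, not the existing interface definition.
// Sketch only: operations inferred from usage shown later in this document;
// members not appearing there (Rank, WorldSize, AllGather, Broadcast, Barrier)
// are assumptions.
public interface ICommunicationBackend<T>
{
    int Rank { get; }          // This process's index within the group
    int WorldSize { get; }     // Total number of participating processes

    // Element-wise reduction across all ranks; every rank sees the result (in place).
    void AllReduce(Vector<T> buffer, ReductionOperation operation);

    // Reduce across ranks, then scatter so each rank keeps only its slice.
    Vector<T> ReduceScatter(Vector<T> buffer, ReductionOperation operation);

    // Concatenate each rank's shard into the full vector on every rank.
    Vector<T> AllGather(Vector<T> localShard);

    // Send a buffer from the root rank to all other ranks.
    void Broadcast(Vector<T> buffer, int rootRank);

    // Block until every rank reaches this point.
    void Barrier();
}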
IShardedModel<T, TInput, TOutput>
↓
ShardedModelBase<T, TInput, TOutput> (abstract)
↓
├── FSDPModel<T, TInput, TOutput> (Fully Sharded Data Parallel - PyTorch style)
├── ZeRO1Model<T, TInput, TOutput> (ZeRO Stage 1 - optimizer state sharding only)
├── ZeRO2Model<T, TInput, TOutput> (ZeRO Stage 2 - optimizer + gradient sharding)
├── ZeRO3Model<T, TInput, TOutput> (ZeRO Stage 3 - full parameter sharding)
├── DDPModel<T, TInput, TOutput> (Distributed Data Parallel - parameter replication)
├── PipelineParallelModel<T, TInput, TOutput> (GPipe-style pipeline parallelism)
├── TensorParallelModel<T, TInput, TOutput> (Megatron-LM style tensor parallelism)
└── HybridShardedModel<T, TInput, TOutput> (3D parallelism: data + tensor + pipeline)
IShardedOptimizer<T, TInput, TOutput>
↓
ShardedOptimizerBase<T, TInput, TOutput> (abstract)
↓
├── ZeRO1Optimizer<T, TInput, TOutput> (Shards optimizer state only)
├── ZeRO2Optimizer<T, TInput, TOutput> (Shards optimizer state + gradients)
├── ZeRO3Optimizer<T, TInput, TOutput> (Full sharding with parameter partitioning)
├── DDPOptimizer<T, TInput, TOutput> (Standard data parallel - AllReduce gradients)
├── GradientCompressionOptimizer<T, TInput, TOutput> (Compressed gradient communication)
├── AsyncSGDOptimizer<T, TInput, TOutput> (Asynchronous parameter updates)
└── ElasticOptimizer<T, TInput, TOutput> (Supports dynamic scaling of workers)
FSDPModel<T, TInput, TOutput>
Status: ✅ Currently implemented as ShardedModel
Description: PyTorch FSDP-inspired implementation that shards model parameters, gradients, and optimizer states across all processes.
Key Features:
Use Case: Training models that don’t fit on a single GPU (e.g., LLMs with 7B+ parameters)
ZeRO1Model<T, TInput, TOutput>
Status: ❌ To be implemented
Description: DeepSpeed ZeRO Stage 1 - only shards optimizer states, keeps parameters and gradients replicated.
Key Features:
Use Case: Medium-sized models where optimizer state is the memory bottleneck (e.g., Adam with 2x model size overhead)
Implementation Notes:
public class ZeRO1Model<T, TInput, TOutput> : ShardedModelBase<T, TInput, TOutput>
{
    // ZeRO-1 keeps a full copy of the parameters on every rank; only the
    // optimizer state is sharded (by ZeRO1Optimizer).
    private Vector<T> _fullParameters;

    protected override void InitializeSharding()
    {
        // Don't shard parameters; keep the full copy locally.
        _fullParameters = WrappedModel.GetParameters();
        LocalShard = _fullParameters; // No actual parameter sharding
    }

    public override void SynchronizeGradients()
    {
        // Standard AllReduce averages gradients across all ranks;
        // optimizer state sharding is handled by ZeRO1Optimizer.
        var gradients = GetGradients();
        Config.CommunicationBackend.AllReduce(gradients, ReductionOperation.Average);
        SetGradients(gradients);
    }
}
ZeRO2Model<T, TInput, TOutput>
Status: ❌ To be implemented
Description: DeepSpeed ZeRO Stage 2 - shards optimizer states AND gradients, keeps parameters replicated.
Key Features:
Use Case: Large models where gradient + optimizer memory is significant (e.g., models with 1B-10B parameters)
Implementation Notes:
public class ZeRO2Model<T, TInput, TOutput> : ShardedModelBase<T, TInput, TOutput>
{
    // Parameters stay replicated; each rank keeps only its shard of the gradients.
    private Vector<T> _localGradientShard;

    public override void SynchronizeGradients()
    {
        // Use ReduceScatter to shard gradients across ranks: every rank receives
        // the averaged slice it is responsible for and discards the rest.
        var fullGradients = GetGradients();
        _localGradientShard = Config.CommunicationBackend.ReduceScatter(
            fullGradients,
            ReductionOperation.Average);
    }
}
ZeRO3Model<T, TInput, TOutput>
Status: ❌ To be implemented (similar to current FSDP)
Description: DeepSpeed ZeRO Stage 3 - full sharding of parameters, gradients, and optimizer states.
Key Features:
Use Case: Extremely large models (10B-175B+ parameters) that require multi-GPU/multi-node training
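Implementation Notes (sketch): one possible shape for ZeRO3Model, reusing members that already appear in this document (WrappedModel, LocalShard, GetGradients, Config.CommunicationBackend, ReductionOperation); ExtractShardForRank, _localGradientShard, and the backend's Rank and AllGather members are assumptions for illustration only.
public class ZeRO3Model<T, TInput, TOutput> : ShardedModelBase<T, TInput, TOutput>
{
    // Each rank also keeps only its slice of the gradients, mirroring its parameter shard.
    private Vector<T> _localGradientShard;

    protected override void InitializeSharding()
    {
        // Partition the flat parameter vector so each rank stores only its slice.
        // ExtractShardForRank is a hypothetical helper that slices by rank and world size.
        var fullParams = WrappedModel.GetParameters();
        LocalShard = ExtractShardForRank(fullParams, Config.CommunicationBackend.Rank);
    }

    public override Vector<T> GatherFullParameters()
    {
        // Materialize the full parameters only when needed for forward/backward,
        // keeping steady-state memory at roughly 1/N of the model per rank.
        return Config.CommunicationBackend.AllGather(LocalShard);
    }

    public override void SynchronizeGradients()
    {
        // ReduceScatter averages gradients and leaves each rank with only the
        // slice matching its parameter shard (same pattern as ZeRO-2 above).
        var fullGradients = GetGradients();
        _localGradientShard = Config.CommunicationBackend.ReduceScatter(
            fullGradients, ReductionOperation.Average);
    }
}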
DDPModel<T, TInput, TOutput>
Status: ❌ To be implemented
Description: Traditional DDP like PyTorch DDP - parameters replicated, gradients synchronized.
Key Features:
Use Case: Training medium-sized models (< 1B parameters) across multiple GPUs for faster training
Implementation Notes:
public class DDPModel<T, TInput, TOutput> : ShardedModelBase<T, TInput, TOutput>
{
protected override void InitializeSharding()
{
// No sharding - each rank has full parameters
var fullParams = WrappedModel.GetParameters();
LocalShard = fullParams;
CachedFullParameters = fullParams;
}
public override Vector<T> GatherFullParameters()
{
// Already have full parameters, no gather needed
return LocalShard;
}
public override void SynchronizeGradients()
{
// AllReduce gradients to average across all ranks
var gradients = GetGradients();
Config.CommunicationBackend.AllReduce(gradients, ReductionOperation.Average);
SetGradients(gradients);
}
}
PipelineParallelModel<T, TInput, TOutput>
Status: ❌ To be implemented
Description: GPipe-style pipeline parallelism - splits model into stages across ranks.
Key Features:
Use Case: Very deep models (transformers with 100+ layers) or when model architecture is easily divisible
Implementation Notes:
public class PipelineParallelModel<T, TInput, TOutput> : ShardedModelBase<T, TInput, TOutput>
{
private int _pipelineStage;
private IFullModel<T, TInput, TOutput>[] _stageModels;
public override void Train(TInput input, TOutput expectedOutput)
{
// Forward pass: send activations to next stage
// Backward pass: send gradients to previous stage
// Use micro-batching to overlap computation
}
}
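One way the Train body above could be filled in with GPipe-style micro-batching. Every helper used here (SplitIntoMicroBatches, RunLocalForward, RunLocalBackward, ComputeLossGradients, SendToNextStage, ReceiveFromNextStage, SendToPreviousStage, IsFirstStage, IsLastStage, Config.MicroBatchCount) is hypothetical and shown only to illustrate the schedule.
public override void Train(TInput input, TOutput expectedOutput)
{
    // Split the mini-batch into micro-batches so stages can overlap work.
    var microBatches = SplitIntoMicroBatches(input, Config.MicroBatchCount);
    var stageActivations = new List<Vector<T>>();

    // Forward pass: process each micro-batch locally and immediately forward
    // the activations so the next stage can start while this stage continues.
    foreach (var microBatch in microBatches)
    {
        var activations = RunLocalForward(microBatch);
        if (!IsLastStage)
            SendToNextStage(activations);
        stageActivations.Add(activations);
    }

    // Backward pass in reverse micro-batch order: receive gradients from the
    // next stage (or compute them from the loss on the last stage), run the
    // local backward pass, and send gradients to the previous stage.
    for (int m = microBatches.Count - 1; m >= 0; m--)
    {
        var upstreamGradients = IsLastStage
            ? ComputeLossGradients(stageActivations[m], expectedOutput)
            : ReceiveFromNextStage();
        var downstreamGradients = RunLocalBackward(stageActivations[m], upstreamGradients);
        if (!IsFirstStage)
            SendToPreviousStage(downstreamGradients);
    }
}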
TensorParallelModel<T, TInput, TOutput>
Status: ❌ To be implemented
Description: Megatron-LM style tensor parallelism - splits individual layers across ranks.
Key Features:
Use Case: Very wide models (large transformers with huge hidden dimensions) or when activation memory is the bottleneck
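Implementation Notes (sketch): Megatron-style tensor parallelism splits individual weight matrices across ranks instead of replicating them. The layer below splits the weight matrix along its output dimension (Megatron's "column-parallel" linear layer); TensorParallelLinear, Matrix<T>, SliceRows, MatrixMultiply, and the backend's Rank, WorldSize, and AllGather members are assumptions used only for illustration.
public class TensorParallelLinear<T>
{
    private Matrix<T> _localWeightRows;           // This rank's slice of the output features
    private ICommunicationBackend<T> _backend;

    public TensorParallelLinear(Matrix<T> fullWeights, ICommunicationBackend<T> backend)
    {
        _backend = backend;
        // Split the output dimension: rank r keeps rows [r * rows / N, (r + 1) * rows / N) of W.
        _localWeightRows = SliceRows(fullWeights, backend.Rank, backend.WorldSize);
    }

    public Vector<T> Forward(Vector<T> input)
    {
        // Each rank computes its slice of the output features (W_local * x);
        // AllGather concatenates the slices into the full output vector.
        var partialOutput = MatrixMultiply(_localWeightRows, input);
        return _backend.AllGather(partialOutput);
    }
}
In Megatron-LM this column-parallel layer is usually paired with a row-parallel layer on the following projection so the gather can be deferred; that detail is omitted here.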
HybridShardedModel<T, TInput, TOutput>
Status: ❌ To be implemented
Description: Combines data parallelism, tensor parallelism, and pipeline parallelism.
Key Features:
Use Case: Training models with 100B-1T+ parameters across hundreds/thousands of GPUs
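Implementation Notes (sketch): 3D parallelism assigns every global rank a coordinate in a (data, tensor, pipeline) grid. The HybridParallelConfig type and its property names below are assumptions; only the rank-decomposition arithmetic is the point.
// Sketch of mapping a flat global rank onto a 3D parallelism grid.
public class HybridParallelConfig
{
    public int DataParallelSize { get; set; }      // e.g. 8 replicas
    public int TensorParallelSize { get; set; }    // e.g. 4-way tensor splits
    public int PipelineParallelSize { get; set; }  // e.g. 8 pipeline stages

    public (int dataRank, int tensorRank, int pipelineRank) DecomposeRank(int globalRank)
    {
        // Total GPUs = data * tensor * pipeline (e.g. 8 * 4 * 8 = 256).
        // Mixed-radix decomposition of the flat rank into three coordinates.
        int tensorRank = globalRank % TensorParallelSize;
        int pipelineRank = (globalRank / TensorParallelSize) % PipelineParallelSize;
        int dataRank = globalRank / (TensorParallelSize * PipelineParallelSize);
        return (dataRank, tensorRank, pipelineRank);
    }
}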
ZeRO1Optimizer<T, TInput, TOutput>
Status: ❌ To be implemented
Description: Shards optimizer states (momentum, variance buffers) across ranks.
Key Features:
Implementation Notes:
public class ZeRO1Optimizer<T, TInput, TOutput> : ShardedOptimizerBase<T, TInput, TOutput>
{
private Dictionary<string, Vector<T>> _shardedOptimizerStates;
protected override void UpdateOptimizerState(Vector<T> gradients)
{
// Only update my shard of optimizer state
// AllGather when needed for full parameter update
}
}
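One possible shape for the UpdateOptimizerState body above, assuming an Adam-style optimizer. ShardStart/ShardLength, the _localMomentum/_localVariance fields, and the ExtractShard, UpdateExponentialAverage, and ElementwiseSquare helpers are assumptions used only for illustration.
protected override void UpdateOptimizerState(Vector<T> gradients)
{
    // Slice out only this rank's portion of the gradient vector.
    var localGradients = ExtractShard(gradients, ShardStart, ShardLength);

    // Adam-style moment updates on the local shard only:
    //   m = beta1 * m + (1 - beta1) * g
    //   v = beta2 * v + (1 - beta2) * g^2
    _localMomentum = UpdateExponentialAverage(_localMomentum, localGradients, _beta1);
    _localVariance = UpdateExponentialAverage(_localVariance, ElementwiseSquare(localGradients), _beta2);

    // The local parameter slice is then updated from these moments, and an
    // AllGather reassembles the full parameter vector on every rank.
}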
ZeRO2Optimizer<T, TInput, TOutput>
Status: ❌ To be implemented
Description: Shards both gradients and optimizer states.
Key Features:
ZeRO3Optimizer<T, TInput, TOutput>
Status: ✅ Currently implemented as ShardedOptimizer
Description: Full parameter, gradient, and optimizer state sharding.
DDPOptimizer<T, TInput, TOutput>
Status: ❌ To be implemented
Description: Standard AllReduce-based gradient synchronization.
Key Features:
GradientCompressionOptimizer<T, TInput, TOutput>
Status: ❌ To be implemented
Description: Compresses gradients before communication.
Key Features:
Implementation Notes:
public class GradientCompressionOptimizer<T, TInput, TOutput> : ShardedOptimizerBase<T, TInput, TOutput>
{
private IGradientCompressor<T> _compressor;
protected override void SynchronizeParameters(IFullModel<T, TInput, TOutput> model)
{
var gradients = model.GetGradients();
var compressed = _compressor.Compress(gradients);
Config.CommunicationBackend.AllReduce(compressed, ReductionOperation.Sum);
var decompressed = _compressor.Decompress(compressed);
model.SetGradients(decompressed);
}
}
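A minimal sketch of one possible IGradientCompressor<T> implementation (top-k sparsification). The Compress/Decompress signatures are inferred from the usage above; the TopKGradientCompressor name, the Vector<T>(int) constructor, indexer, and Length member, and the assumption that T is a standard numeric type are all illustrative, not existing framework API.
public class TopKGradientCompressor<T> : IGradientCompressor<T>
{
    private readonly double _keepRatio;   // e.g. 0.01 keeps the largest 1% of entries

    public TopKGradientCompressor(double keepRatio)
    {
        _keepRatio = keepRatio;
    }

    public Vector<T> Compress(Vector<T> gradients)
    {
        // Find the magnitude threshold for the top-k entries.
        int k = Math.Max(1, (int)(gradients.Length * _keepRatio));
        var magnitudes = new double[gradients.Length];
        for (int i = 0; i < gradients.Length; i++)
            magnitudes[i] = Magnitude(gradients[i]);

        var sorted = (double[])magnitudes.Clone();
        Array.Sort(sorted);
        double threshold = sorted[sorted.Length - k];

        // Zero out everything below the threshold. The result is still a dense,
        // same-length vector, so the AllReduce in the optimizer above works unchanged.
        var sparse = new Vector<T>(gradients.Length);
        for (int i = 0; i < gradients.Length; i++)
        {
            if (magnitudes[i] >= threshold)
                sparse[i] = gradients[i];
        }
        return sparse;
    }

    public Vector<T> Decompress(Vector<T> compressed)
    {
        // Dropped entries were zeroed during compression, so nothing to reconstruct.
        return compressed;
    }

    // Assumes T is a standard numeric type that Convert.ToDouble can handle.
    private static double Magnitude(T value) => Math.Abs(Convert.ToDouble(value));
}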
AsyncSGDOptimizer<T, TInput, TOutput>
Status: ❌ To be implemented
Description: Asynchronous parameter updates without strict synchronization.
Key Features:
ElasticOptimizer<T, TInput, TOutput>
Status: ❌ To be implemented
Description: Supports dynamic addition/removal of workers during training.
Key Features:
InMemoryCommunicationBackend<T>
Status: ✅ Implemented
Use Case: Testing and development without MPI
MPICommunicationBackend<T>
Status: ❌ To be implemented
Description: Production MPI.NET backend for CPU/GPU clusters.
Key Features:
NCCLCommunicationBackend<T>
Status: ❌ To be implemented
Description: NVIDIA NCCL backend for GPU-to-GPU communication.
Key Features:
GlooCommunicationBackend<T>
Status: ❌ To be implemented
Description: Facebook Gloo backend for CPU clusters.
Key Features:
Model implementations must override:
- InitializeSharding() - How to shard/replicate parameters
- Train() - Forward/backward with appropriate sync
- GatherFullParameters() - How to reconstruct full parameters
- SynchronizeGradients() - Gradient communication pattern
- Serialize()/Deserialize() - Save/load with strategy metadata
Naming convention: [Strategy]Model<T, TInput, TOutput>
Optimizer implementations must override:
- Optimize() - Coordinate distributed optimization
- SynchronizeOptimizerState() - Sync momentum/variance buffers
- SynchronizeParameters() - Gradient/parameter communication
- ShouldEarlyStop() - Consensus across ranks
Naming convention: [Strategy]Optimizer<T, TInput, TOutput>
For each implementation: