AiDotNet

Distributed Training Concrete Implementations

This document outlines all concrete implementations that should be created for the distributed training framework, based on industry standards and real-world scenarios.

Architecture Overview

ICommunicationBackend<T>
    ↓
CommunicationBackendBase<T> (abstract)
    ↓
├── InMemoryCommunicationBackend<T> (for testing)
├── MPICommunicationBackend<T> (MPI.NET for production)
├── NCCLCommunicationBackend<T> (NVIDIA GPUs)
└── GlooCommunicationBackend<T> (CPU-based)

IShardedModel<T, TInput, TOutput>
    ↓
ShardedModelBase<T, TInput, TOutput> (abstract)
    ↓
├── FSDPModel<T, TInput, TOutput> (Fully Sharded Data Parallel - PyTorch style)
├── ZeRO1Model<T, TInput, TOutput> (ZeRO Stage 1 - optimizer state sharding only)
├── ZeRO2Model<T, TInput, TOutput> (ZeRO Stage 2 - optimizer + gradient sharding)
├── ZeRO3Model<T, TInput, TOutput> (ZeRO Stage 3 - full parameter sharding)
├── DDPModel<T, TInput, TOutput> (Distributed Data Parallel - parameter replication)
├── PipelineParallelModel<T, TInput, TOutput> (GPipe-style pipeline parallelism)
├── TensorParallelModel<T, TInput, TOutput> (Megatron-LM style tensor parallelism)
└── HybridShardedModel<T, TInput, TOutput> (3D parallelism: data + tensor + pipeline)

IShardedOptimizer<T, TInput, TOutput>
    ↓
ShardedOptimizerBase<T, TInput, TOutput> (abstract)
    ↓
├── ZeRO1Optimizer<T, TInput, TOutput> (Shards optimizer state only)
├── ZeRO2Optimizer<T, TInput, TOutput> (Shards optimizer state + gradients)
├── ZeRO3Optimizer<T, TInput, TOutput> (Full sharding with parameter partitioning)
├── DDPOptimizer<T, TInput, TOutput> (Standard data parallel - AllReduce gradients)
├── GradientCompressionOptimizer<T, TInput, TOutput> (Compressed gradient communication)
├── AsyncSGDOptimizer<T, TInput, TOutput> (Asynchronous parameter updates)
└── ElasticOptimizer<T, TInput, TOutput> (Supports dynamic scaling of workers)

Model Implementations

1. FSDPModel<T, TInput, TOutput> - Fully Sharded Data Parallel

Status: ✅ Currently implemented as ShardedModel

Description: PyTorch FSDP-inspired implementation that shards model parameters, gradients, and optimizer states across all processes.

Key Features:

Use Case: Training models that don’t fit on a single GPU (e.g., LLMs with 7B+ parameters)
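
Implementation Notes:

A simplified sketch of the FSDP gather-compute-discard cycle, not the code of the existing ShardedModel; it assumes IFullModel exposes SetParameters and Train, and otherwise uses only base-class members shown elsewhere in this document.

public class FSDPModel<T, TInput, TOutput> : ShardedModelBase<T, TInput, TOutput>
{
    public override void Train(TInput input, TOutput expectedOutput)
    {
        // Materialize the full parameters just-in-time for compute
        var fullParams = GatherFullParameters();
        WrappedModel.SetParameters(fullParams);    // assumed IFullModel member
        WrappedModel.Train(input, expectedOutput); // assumed IFullModel member

        // Average and re-shard gradients so each rank keeps only its slice
        SynchronizeGradients();

        // Drop the full copy so steady-state memory stays proportional to the shard size
        CachedFullParameters = null;
    }
}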


2. ZeRO1Model<T, TInput, TOutput> - ZeRO Stage 1

Status: ❌ To be implemented

Description: DeepSpeed ZeRO Stage 1 - only shards optimizer states, keeps parameters and gradients replicated.

Key Features:

Use Case: Medium-sized models where optimizer state is the memory bottleneck (e.g., Adam, whose momentum and variance buffers add roughly 2x the model's parameter memory)

Implementation Notes:

public class ZeRO1Model<T, TInput, TOutput> : ShardedModelBase<T, TInput, TOutput>
{
    // Keep a full copy of the parameters locally (Stage 1 does not shard them)
    private Vector<T> _fullParameters;

    protected override void InitializeSharding()
    {
        // Don't shard parameters; every rank keeps the full copy
        _fullParameters = WrappedModel.GetParameters();
        LocalShard = _fullParameters; // No actual sharding
    }

    public override void SynchronizeGradients()
    {
        // Standard AllReduce for gradient averaging;
        // optimizer state sharding is handled by ZeRO1Optimizer
        var gradients = GetGradients();
        Config.CommunicationBackend.AllReduce(gradients, ReductionOperation.Average);
        SetGradients(gradients);
    }
}

3. ZeRO2Model<T, TInput, TOutput> - ZeRO Stage 2

Status: ❌ To be implemented

Description: DeepSpeed ZeRO Stage 2 - shards optimizer states AND gradients, keeps parameters replicated.

Key Features:

Use Case: Large models where gradient + optimizer memory is significant (e.g., models with 1B-10B parameters)

Implementation Notes:

public class ZeRO2Model<T, TInput, TOutput> : ShardedModelBase<T, TInput, TOutput>
{
    // This rank's shard of the averaged gradients (parameters stay fully replicated)
    private Vector<T> _localGradientShard;

    public override void SynchronizeGradients()
    {
        // ReduceScatter averages gradients and leaves each rank with only its shard;
        // parameters are not touched, so LocalShard keeps holding the full replica
        var fullGradients = GetGradients();
        _localGradientShard = Config.CommunicationBackend.ReduceScatter(
            fullGradients,
            ReductionOperation.Average);
    }
}

4. ZeRO3Model<T, TInput, TOutput> - ZeRO Stage 3

Status: ❌ To be implemented (similar to current FSDP)

Description: DeepSpeed ZeRO Stage 3 - full sharding of parameters, gradients, and optimizer states.

Key Features:

Use Case: Extremely large models (10B-175B+ parameters) that require multi-GPU/multi-node training
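
Implementation Notes:

A minimal sketch of parameter partitioning, assuming the communication backend exposes Rank and WorldSize properties and that Vector<T> supports Length, indexing, and construction from an array.

public class ZeRO3Model<T, TInput, TOutput> : ShardedModelBase<T, TInput, TOutput>
{
    protected override void InitializeSharding()
    {
        // Partition the flat parameter vector into contiguous, roughly equal slices
        var fullParams = WrappedModel.GetParameters();
        int worldSize = Config.CommunicationBackend.WorldSize; // assumed property
        int rank = Config.CommunicationBackend.Rank;           // assumed property

        int shardSize = (fullParams.Length + worldSize - 1) / worldSize;
        int start = rank * shardSize;
        int length = Math.Max(0, Math.Min(shardSize, fullParams.Length - start));

        // Copy this rank's slice; the full copy can then be released
        var shard = new T[length];
        for (int i = 0; i < length; i++)
            shard[i] = fullParams[start + i];
        LocalShard = new Vector<T>(shard);                     // assumed constructor
    }
}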


5. DDPModel<T, TInput, TOutput> - Distributed Data Parallel

Status: ❌ To be implemented

Description: Traditional data parallelism in the style of PyTorch DDP - parameters are replicated on every rank and gradients are synchronized after each backward pass.

Key Features:

Use Case: Training medium-sized models (< 1B parameters) across multiple GPUs for faster training

Implementation Notes:

public class DDPModel<T, TInput, TOutput> : ShardedModelBase<T, TInput, TOutput>
{
    protected override void InitializeSharding()
    {
        // No sharding - each rank has full parameters
        var fullParams = WrappedModel.GetParameters();
        LocalShard = fullParams;
        CachedFullParameters = fullParams;
    }

    public override Vector<T> GatherFullParameters()
    {
        // Already have full parameters, no gather needed
        return LocalShard;
    }

    public override void SynchronizeGradients()
    {
        // AllReduce gradients to average across all ranks
        var gradients = GetGradients();
        Config.CommunicationBackend.AllReduce(gradients, ReductionOperation.Average);
        SetGradients(gradients);
    }
}

6. PipelineParallelModel<T, TInput, TOutput> - Pipeline Parallelism

Status: ❌ To be implemented

Description: GPipe-style pipeline parallelism - splits model into stages across ranks.

Key Features:

Use Case: Very deep models (transformers with 100+ layers) or when model architecture is easily divisible

Implementation Notes:

public class PipelineParallelModel<T, TInput, TOutput> : ShardedModelBase<T, TInput, TOutput>
{
    // Index of the pipeline stage this rank executes
    private int _pipelineStage;

    // The model split into per-stage sub-models; this rank runs _stageModels[_pipelineStage]
    private IFullModel<T, TInput, TOutput>[] _stageModels;

    public override void Train(TInput input, TOutput expectedOutput)
    {
        // Forward pass: send activations to the next stage
        // Backward pass: send gradients to the previous stage
        // Use micro-batching to keep all stages busy and overlap computation
    }
}

7. TensorParallelModel<T, TInput, TOutput> - Tensor Parallelism

Status: ❌ To be implemented

Description: Megatron-LM style tensor parallelism - splits individual layers across ranks.

Key Features:

Use Case: Very wide models (large transformers with huge hidden dimensions) or when activation memory is the bottleneck
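
Implementation Notes:

A minimal sketch of the row-parallel linear pattern used in Megatron-LM; MultiplyByLocalWeightShard is a hypothetical placeholder for applying this rank's slice of the weight matrix.

public class TensorParallelModel<T, TInput, TOutput> : ShardedModelBase<T, TInput, TOutput>
{
    // Row-parallel linear layer: each rank multiplies by its weight slice and the
    // partial outputs are summed across the tensor-parallel group.
    private Vector<T> ForwardRowParallel(Vector<T> inputSlice)
    {
        var partialOutput = MultiplyByLocalWeightShard(inputSlice);

        // Summing the partial products reconstructs the full layer output on every rank
        Config.CommunicationBackend.AllReduce(partialOutput, ReductionOperation.Sum);
        return partialOutput;
    }

    // Hypothetical hook for applying this rank's weight shard (e.g., a local matmul)
    private Vector<T> MultiplyByLocalWeightShard(Vector<T> inputSlice)
    {
        throw new NotImplementedException("Local weight-shard multiply goes here.");
    }
}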


8. HybridShardedModel<T, TInput, TOutput> - 3D Parallelism

Status: ❌ To be implemented

Description: Combines data parallelism, tensor parallelism, and pipeline parallelism.

Key Features:

Use Case: Training models with 100B-1T+ parameters across hundreds/thousands of GPUs
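
Implementation Notes:

A sketch of how a global rank could be decomposed into (data, pipeline, tensor) coordinates; the group sizes are illustrative values that would normally come from configuration.

public readonly struct ParallelCoordinates
{
    public int DataRank { get; }
    public int PipelineRank { get; }
    public int TensorRank { get; }

    public ParallelCoordinates(int globalRank, int tensorParallelSize, int pipelineStages)
    {
        // Tensor parallelism varies fastest, then pipeline, then data;
        // assumes worldSize == dataParallelSize * pipelineStages * tensorParallelSize
        TensorRank = globalRank % tensorParallelSize;
        PipelineRank = (globalRank / tensorParallelSize) % pipelineStages;
        DataRank = globalRank / (tensorParallelSize * pipelineStages);
    }
}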


Optimizer Implementations

1. ZeRO1Optimizer<T, TInput, TOutput> - Optimizer State Sharding

Status: ❌ To be implemented

Description: Shards optimizer states (momentum, variance buffers) across ranks.

Key Features:

Implementation Notes:

public class ZeRO1Optimizer<T, TInput, TOutput> : ShardedOptimizerBase<T, TInput, TOutput>
{
    // This rank's shard of each optimizer state buffer (e.g., momentum, variance), keyed by name
    private Dictionary<string, Vector<T>> _shardedOptimizerStates;

    protected override void UpdateOptimizerState(Vector<T> gradients)
    {
        // Only update this rank's shard of the optimizer state;
        // AllGather the updated parameter shards when a full parameter update is needed
    }
}

2. ZeRO2Optimizer<T, TInput, TOutput> - Gradient + State Sharding

Status: ❌ To be implemented

Description: Shards both gradients and optimizer states.

Key Features:
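
Implementation Notes:

A sketch that reuses the backend calls shown elsewhere in this document; the shard-local parameter update and the final AllGather are described only in comments.

public class ZeRO2Optimizer<T, TInput, TOutput> : ShardedOptimizerBase<T, TInput, TOutput>
{
    protected override void SynchronizeParameters(IFullModel<T, TInput, TOutput> model)
    {
        // ReduceScatter averages gradients and leaves each rank with only its shard
        var fullGradients = model.GetGradients();
        var gradientShard = Config.CommunicationBackend.ReduceScatter(
            fullGradients,
            ReductionOperation.Average);

        // Update only the optimizer state and parameters covered by this shard,
        // then AllGather the updated parameter shards when a full copy is required.
    }
}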


3. ZeRO3Optimizer<T, TInput, TOutput> - Full Sharding

Status: ✅ Currently implemented as ShardedOptimizer

Description: Full parameter, gradient, and optimizer state sharding.


4. DDPOptimizer<T, TInput, TOutput> - Standard Data Parallel

Status: ❌ To be implemented

Description: Standard AllReduce-based gradient synchronization.

Key Features:
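
Implementation Notes:

A minimal sketch that mirrors DDPModel.SynchronizeGradients: every rank applies the same averaged gradients, so parameters stay identical without any sharding.

public class DDPOptimizer<T, TInput, TOutput> : ShardedOptimizerBase<T, TInput, TOutput>
{
    protected override void SynchronizeParameters(IFullModel<T, TInput, TOutput> model)
    {
        // Average gradients across all ranks; identical updates keep parameters replicated
        var gradients = model.GetGradients();
        Config.CommunicationBackend.AllReduce(gradients, ReductionOperation.Average);
        model.SetGradients(gradients);
    }
}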


5. GradientCompressionOptimizer<T, TInput, TOutput>

Status: ❌ To be implemented

Description: Compresses gradients before communication.

Key Features:

Implementation Notes:

public class GradientCompressionOptimizer<T, TInput, TOutput> : ShardedOptimizerBase<T, TInput, TOutput>
{
    // Pluggable compression scheme (e.g., top-k sparsification or quantization)
    private IGradientCompressor<T> _compressor;

    protected override void SynchronizeParameters(IFullModel<T, TInput, TOutput> model)
    {
        var gradients = model.GetGradients();
        var compressed = _compressor.Compress(gradients);
        Config.CommunicationBackend.AllReduce(compressed, ReductionOperation.Sum);
        var decompressed = _compressor.Decompress(compressed);
        model.SetGradients(decompressed);
    }
}
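
One possible shape for the IGradientCompressor<T> abstraction referenced above (a minimal sketch of the contract, not an interface that already exists in the codebase):

public interface IGradientCompressor<T>
{
    // Shrink the communication payload, e.g., via top-k sparsification or quantization
    Vector<T> Compress(Vector<T> gradients);

    // Restore a dense vector of the original length (lossy schemes return an approximation)
    Vector<T> Decompress(Vector<T> compressed);
}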

6. AsyncSGDOptimizer<T, TInput, TOutput>

Status: ❌ To be implemented

Description: Asynchronous parameter updates without strict synchronization.

Key Features:
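
Implementation Notes:

Fully asynchronous updates usually require a parameter server or one-sided communication, which this document's backend interface does not cover. The sketch below shows a simpler periodic parameter-averaging (local SGD-style) relaxation built only on the collectives already used here; the sync interval and model.SetParameters are assumptions.

public class AsyncSGDOptimizer<T, TInput, TOutput> : ShardedOptimizerBase<T, TInput, TOutput>
{
    private int _stepsSinceSync;
    private const int SyncInterval = 10; // hypothetical tuning knob

    protected override void SynchronizeParameters(IFullModel<T, TInput, TOutput> model)
    {
        // Ranks apply local updates every step and only periodically average
        // parameters, instead of blocking on an AllReduce each iteration
        if (++_stepsSinceSync < SyncInterval)
            return;
        _stepsSinceSync = 0;

        var parameters = model.GetParameters();
        Config.CommunicationBackend.AllReduce(parameters, ReductionOperation.Average);
        model.SetParameters(parameters); // assumed IFullModel member
    }
}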


7. ElasticOptimizer<T, TInput, TOutput>

Status: ❌ To be implemented

Description: Supports dynamic addition/removal of workers during training.

Key Features:
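
Implementation Notes:

A sketch of the rebalancing hook only; it assumes the backend exposes a WorldSize property that reflects membership changes and leaves the actual re-sharding to the paired model/strategy.

public class ElasticOptimizer<T, TInput, TOutput> : ShardedOptimizerBase<T, TInput, TOutput>
{
    private int _lastKnownWorldSize;

    protected override void SynchronizeParameters(IFullModel<T, TInput, TOutput> model)
    {
        // Detect membership changes and re-partition state before the normal sync
        int worldSize = Config.CommunicationBackend.WorldSize; // assumed property
        if (worldSize != _lastKnownWorldSize)
        {
            _lastKnownWorldSize = worldSize;
            // Re-shard optimizer state and parameters for the new set of ranks
            // (details depend on the sharding strategy paired with this optimizer)
        }

        var gradients = model.GetGradients();
        Config.CommunicationBackend.AllReduce(gradients, ReductionOperation.Average);
        model.SetGradients(gradients);
    }
}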


Communication Backend Implementations

1. InMemoryCommunicationBackend<T>

Status: ✅ Implemented

Use Case: Testing and development without MPI


2. MPICommunicationBackend<T>

Status: ❌ To be implemented

Description: Production MPI.NET backend for CPU/GPU clusters.

Key Features:
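
Implementation Notes:

A sketch of how the wrapper could sit on top of MPI.NET's Communicator.world; the Rank/WorldSize overrides and the AllRanksAgree helper are assumptions about the base class and early-stopping consensus, not existing APIs.

using MPI;

public class MPICommunicationBackend<T> : CommunicationBackendBase<T>
{
    private readonly Intracommunicator _comm;

    public MPICommunicationBackend()
    {
        // Assumes the host application has already initialized MPI.Environment
        _comm = Communicator.world;
    }

    // Assumed overridable members on CommunicationBackendBase<T>
    public override int Rank => _comm.Rank;
    public override int WorldSize => _comm.Size;

    // Example collective built on MPI.NET: stop early only if every rank votes to stop
    public bool AllRanksAgree(bool localVote)
    {
        int votes = _comm.Allreduce(localVote ? 1 : 0, Operation<int>.Add);
        return votes == _comm.Size;
    }
}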


3. NCCLCommunicationBackend<T>

Status: ❌ To be implemented

Description: NVIDIA NCCL backend for GPU-to-GPU communication.

Key Features:


4. GlooCommunicationBackend<T>

Status: ❌ To be implemented

Description: Facebook Gloo backend for CPU clusters.

Key Features:


Priority Implementation Order

Phase 1: Core DDP (Most Common Use Case)

  1. ✅ InMemoryCommunicationBackend (done)
  2. ❌ DDPModel - Standard data parallel
  3. ❌ DDPOptimizer - AllReduce gradients
  4. ❌ MPICommunicationBackend - Production backend

Phase 2: Memory-Efficient ZeRO

  1. ❌ ZeRO1Model + ZeRO1Optimizer - Optimizer state sharding
  2. ❌ ZeRO2Model + ZeRO2Optimizer - Gradient + state sharding
  3. ✅ ZeRO3Model + ZeRO3Optimizer (rename the current ShardedModel/ShardedOptimizer to FSDPModel/FSDPOptimizer; the full-sharding strategy is the same)

Phase 3: Advanced Parallelism

  1. ❌ PipelineParallelModel - Layer-wise parallelism
  2. ❌ TensorParallelModel - Tensor-wise parallelism
  3. ❌ HybridShardedModel - 3D parallelism

Phase 4: Optimizations

  1. ❌ GradientCompressionOptimizer - Reduce communication
  2. ❌ NCCLCommunicationBackend - GPU optimization
  3. ❌ AsyncSGDOptimizer - Async updates
  4. ❌ ElasticOptimizer - Dynamic scaling

Implementation Guidelines

For Each Model Implementation

  1. Inherit from ShardedModelBase<T, TInput, TOutput>
  2. Override required methods (see the skeleton sketch after this list):
    • InitializeSharding() - How to shard/replicate parameters
    • Train() - Forward/backward with appropriate sync
    • GatherFullParameters() - How to reconstruct full parameters
    • SynchronizeGradients() - Gradient communication pattern
    • Serialize()/Deserialize() - Save/load with strategy metadata
  3. Follow naming convention: [Strategy]Model<T, TInput, TOutput>
  4. Add comprehensive documentation with use cases and memory/communication trade-offs
  5. Include example usage in XML docs
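
A bare skeleton matching the checklist above; the class name is a placeholder, and Serialize()/Deserialize() are omitted because their signatures are strategy-specific.

public class MyStrategyModel<T, TInput, TOutput> : ShardedModelBase<T, TInput, TOutput>
{
    protected override void InitializeSharding()
    {
        // Decide how parameters are sharded or replicated across ranks
    }

    public override void Train(TInput input, TOutput expectedOutput)
    {
        // Forward/backward pass with this strategy's synchronization points
    }

    public override Vector<T> GatherFullParameters()
    {
        // Reconstruct the full parameter vector from the shards
        return LocalShard;
    }

    public override void SynchronizeGradients()
    {
        // Strategy-specific gradient communication (AllReduce, ReduceScatter, ...)
    }
}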

For Each Optimizer Implementation

  1. Inherit from ShardedOptimizerBase<T, TInput, TOutput>
  2. Override required methods:
    • Optimize() - Coordinate distributed optimization
    • SynchronizeOptimizerState() - Sync momentum/variance buffers
    • SynchronizeParameters() - Gradient/parameter communication
    • ShouldEarlyStop() - Consensus across ranks
  3. Follow naming convention: [Strategy]Optimizer<T, TInput, TOutput>
  4. Match with corresponding model (e.g., DDPOptimizer works with DDPModel)

Testing Strategy

For each implementation:

  1. Unit tests with InMemoryCommunicationBackend (2-4 ranks)
  2. Integration tests with small models
  3. Performance benchmarks comparing strategies
  4. Memory usage profiling
  5. Communication overhead measurements

Documentation Deliverables

For each implementation:

  1. Class documentation following project standards
  2. Runnable usage examples in the XML docs
  3. Performance characteristics (memory, communication, computation)
  4. A "when to use" decision guide
  5. Limitations and caveats
