Class CommunicationBackendBase<T>

Namespace
AiDotNet.DistributedTraining
Assembly
AiDotNet.dll

Provides base implementation for distributed communication backends.

public abstract class CommunicationBackendBase<T> : ICommunicationBackend<T>

Type Parameters

T

The numeric type for operations

Inheritance
object → CommunicationBackendBase<T>
Implements
ICommunicationBackend<T>

Remarks

This abstract class implements common functionality for all communication backends, including state management, validation, and helper methods for collective operations. Derived classes implement the specific communication mechanisms (MPI, NCCL, in-memory, etc.).

For Beginners: This is the foundation that all communication systems build upon.

Think of this as a template that defines how any communication system should work. It handles common tasks like:

  • Keeping track of whether the system is initialized
  • Validating inputs (checking for null values, correct sizes, etc.)
  • Providing helper methods for common operations

Specific communication backends (like MPI or in-memory) inherit from this and add their own implementation details. This prevents code duplication and ensures all backends work consistently.

Constructors

CommunicationBackendBase()

Initializes a new instance of the CommunicationBackendBase class.

protected CommunicationBackendBase()

Remarks

This constructor sets up the numeric operations provider that will be used for all mathematical operations on type T.

For Beginners: This constructor is called when creating any communication backend.

It sets up the math helper that allows the backend to perform operations like addition, multiplication, and comparison on any numeric type (double, float, etc.) without knowing in advance which type will be used.

Fields

NumOps

Provides numeric operations for type T.

protected readonly INumericOperations<T> NumOps

Field Value

INumericOperations<T>

Properties

IsInitialized

Gets whether this backend is initialized and ready for use.

public bool IsInitialized { get; }

Property Value

bool

Rank

Gets the rank (ID) of the current process in the distributed group.

public abstract int Rank { get; }

Property Value

int

Remarks

Rank 0 is typically the "master" or "coordinator" process.

For Beginners: Think of rank as your process's unique ID number. If you have 4 GPUs, ranks will be 0, 1, 2, and 3. Rank 0 is usually the "boss" that coordinates everything.

WorldSize

Gets the total number of processes in the distributed group.

public abstract int WorldSize { get; }

Property Value

int

Remarks

For Beginners: This is how many processes (or GPUs) are working together. If WorldSize is 4, you have 4 processes sharing the work.

Methods

AllGather(Vector<T>)

AllGather operation - gathers data from all processes and concatenates it.

public abstract Vector<T> AllGather(Vector<T> sendData)

Parameters

sendData Vector<T>

The local data to contribute

Returns

Vector<T>

The gathered data from all processes concatenated together

Remarks

Each process receives the complete concatenated result.

For Beginners: If GPU 0 has [1,2], GPU 1 has [3,4], GPU 2 has [5,6], GPU 3 has [7,8], then AllGather gives everyone [1,2,3,4,5,6,7,8]. This is used to reconstruct the full model parameters from sharded pieces.
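The concatenation semantics above can be simulated in a few lines of plain Python (this is a single-process sketch of what AllGather produces, not the library's C# API):

```python
def all_gather(local_chunks):
    """Simulate AllGather: every rank receives the concatenation
    of all ranks' local data, in rank order."""
    full = [x for chunk in local_chunks for x in chunk]
    # Every rank gets an identical copy of the full vector.
    return [list(full) for _ in local_chunks]

# Four ranks, each holding a shard of the model parameters.
results = all_gather([[1, 2], [3, 4], [5, 6], [7, 8]])
print(results[0])  # [1, 2, 3, 4, 5, 6, 7, 8]
```

Every rank's result is identical, which is what lets each process reconstruct the full parameter vector from its shard.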

AllReduce(Vector<T>, ReductionOperation)

AllReduce operation - combines data from all processes using the specified operation and distributes the result back to all processes.

public abstract void AllReduce(Vector<T> data, ReductionOperation operation)

Parameters

data Vector<T>

The data to reduce. Will be replaced with the reduced result.

operation ReductionOperation

The reduction operation (Sum, Max, Min, etc.)

Remarks

For Beginners: Imagine 4 GPUs each calculated a gradient vector. AllReduce takes all 4 vectors, adds them together (if operation is Sum), and gives the result to all 4 GPUs. This is crucial for averaging gradients across GPUs during training.

Common operations:

  • Sum: Add all values together (used for gradient averaging)
  • Max: Take the maximum value across all processes
  • Min: Take the minimum value across all processes
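A single-process Python sketch of the AllReduce semantics (the real C# API reduces the vector in place; here the result is returned instead):

```python
def all_reduce(rank_data, reduce=sum):
    """Simulate AllReduce: reduce element-wise across ranks, then
    give every rank the same reduced vector."""
    reduced = [reduce(values) for values in zip(*rank_data)]
    return [list(reduced) for _ in rank_data]

# Four ranks, each with a local gradient vector.
gradients = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
print(all_reduce(gradients)[0])  # [16.0, 20.0] -- same on every rank
```

Dividing the summed result by WorldSize afterwards gives the averaged gradient used in data-parallel training.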

ApplyReductionOperation(T, T, ReductionOperation)

Applies a reduction operation to two values.

protected T ApplyReductionOperation(T a, T b, ReductionOperation operation)

Parameters

a T

The first value

b T

The second value

operation ReductionOperation

The reduction operation to apply

Returns

T

The result of applying the operation

Remarks

This helper method performs the specified reduction operation (Sum, Product, Min, Max) on two values. It's used internally by AllReduce and ReduceScatter implementations.

For Beginners: This is a helper that knows how to combine two numbers in different ways.

For example:

  • Sum operation: 3 + 5 = 8
  • Product operation: 3 * 5 = 15
  • Min operation: Min(3, 5) = 3
  • Max operation: Max(3, 5) = 5

We use this when combining values from multiple processes.
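The pairwise combination can be sketched in Python with a small dispatch table (the operation names mirror the Sum/Product/Min/Max cases above; this is illustrative, not the C# implementation):

```python
def apply_reduction(a, b, operation):
    """Combine two values according to the named reduction operation."""
    ops = {
        "Sum": lambda x, y: x + y,
        "Product": lambda x, y: x * y,
        "Min": min,
        "Max": max,
    }
    return ops[operation](a, b)

print(apply_reduction(3, 5, "Sum"))      # 8
print(apply_reduction(3, 5, "Product"))  # 15
```

Reducing a whole vector across ranks is just this pairwise step folded over each rank's contribution, element by element.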

Barrier()

Synchronization barrier - blocks until all processes reach this point.

public abstract void Barrier()

Remarks

For Beginners: This is like a meeting checkpoint. All processes must arrive at this point before any of them can continue. It ensures everyone is synchronized. Example: Before starting training, you want all GPUs to be ready.
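The checkpoint behavior can be demonstrated with Python's standard threading.Barrier, using threads as stand-ins for processes (a simulation of the semantics, not the library's implementation):

```python
import threading

barrier = threading.Barrier(4)   # four simulated processes
events = []                      # list appends are atomic under the GIL

def worker(rank):
    events.append(("before", rank))
    barrier.wait()               # block until all four workers arrive
    events.append(("after", rank))

threads = [threading.Thread(target=worker, args=(r,)) for r in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# The barrier guarantees every "before" entry precedes every "after" entry.
```

No worker records an "after" event until all four have recorded their "before" event, which is exactly the synchronization guarantee described above.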

Broadcast(Vector<T>, int)

Broadcast operation - sends data from one process (root) to all other processes.

public abstract Vector<T> Broadcast(Vector<T> data, int root = 0)

Parameters

data Vector<T>

The data to broadcast (only meaningful on root process)

root int

The rank of the process that is broadcasting

Returns

Vector<T>

The broadcast data (received from root on non-root processes)

Remarks

For Beginners: This is like an announcement from the boss (root process). The root sends data to everyone else. Useful for distributing initial parameters or configurations.
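A single-process Python sketch of what Broadcast produces (illustrative only; the real call runs once per process):

```python
def broadcast(rank_data, root=0):
    """Simulate Broadcast: every rank ends up with a copy of the
    root rank's vector, regardless of what it held before."""
    return [list(rank_data[root]) for _ in rank_data]

# Only the root (rank 0) holds meaningful data before the call.
before = [[9, 9], [0, 0], [0, 0], [0, 0]]
print(broadcast(before, root=0))  # [[9, 9], [9, 9], [9, 9], [9, 9]]
```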

EnsureInitialized()

Ensures the backend is initialized before performing operations.

protected void EnsureInitialized()

Remarks

This method throws an exception if the backend has not been initialized. All communication operations should call this before proceeding.

For Beginners: This is a safety check.

Before doing any communication, we make sure the system is ready. If someone tries to use the backend without initializing it first, this method will throw an error with a helpful message.

Exceptions

InvalidOperationException

Thrown if the backend is not initialized

Initialize()

Initializes the communication backend.

public virtual void Initialize()

Remarks

Must be called before any other operations.

For Beginners: This is like turning on your walkie-talkie system. You need to do this once at the start before any processes can talk to each other.

OnInitialize()

Called during initialization to perform backend-specific setup.

protected virtual void OnInitialize()

Remarks

Derived classes override this method to implement their specific initialization logic, such as connecting to MPI or setting up shared memory structures.

For Beginners: This is where each specific backend does its setup work.

For example:

  • An MPI backend would connect to the MPI environment
  • An in-memory backend would create shared data structures
  • An NCCL backend would initialize GPU communication channels
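The Initialize/OnInitialize split is a template-method pattern: the base class owns the lifecycle flag and delegates the backend-specific work to a hook. A minimal Python sketch of the pattern (names simplified; not the actual C# code):

```python
class BackendBase:
    """Base class owns the initialized flag; subclasses fill in the hook."""
    def __init__(self):
        self.is_initialized = False

    def initialize(self):
        self.on_initialize()       # backend-specific setup hook
        self.is_initialized = True

    def on_initialize(self):
        pass                       # no-op by default; subclasses override

class InMemoryBackend(BackendBase):
    def on_initialize(self):
        self.shared = {}           # e.g. create shared data structures

backend = InMemoryBackend()
backend.initialize()
print(backend.is_initialized)  # True
```

Shutdown/OnShutdown follows the same pattern in reverse: the base class clears the flag while the hook releases backend-specific resources.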

OnShutdown()

Called during shutdown to perform backend-specific cleanup.

protected virtual void OnShutdown()

Remarks

Derived classes override this method to implement their specific cleanup logic, such as disconnecting from MPI or releasing shared memory.

For Beginners: This is where each backend cleans up its resources.

It's like turning off equipment when you're done - releasing memory, closing connections, and ensuring everything shuts down cleanly.

Receive(int, int, int)

Receive operation - receives data from a specific source process.

public abstract Vector<T> Receive(int sourceRank, int count, int tag = 0)

Parameters

sourceRank int

The rank of the process to receive from

count int

The expected number of elements to receive

tag int

Optional message tag to match with Send (default=0)

Returns

Vector<T>

The received data

Remarks

This is a point-to-point communication operation that blocks until data arrives.

For Beginners: This is like waiting for a private message from a specific GPU. The process will wait (block) until the message arrives.

Use cases:

  • Pipeline parallelism: receiving activations from previous stage
  • Ring-based algorithms: receiving data from neighbor
  • Custom communication patterns

Important: Receive must be matched with a corresponding Send from the source process. If the sender never sends, this will deadlock (hang forever). If the sizes don't match, data corruption or errors can occur.
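The Send/Receive matching rule can be sketched with an in-process mailbox, where a queue per (source, destination, tag) triple stands in for the network (signatures are simplified for this single-process simulation and differ from the C# API):

```python
import queue

mailboxes = {}  # (source, destination, tag) -> queue of pending messages

def send(data, source, destination, tag=0):
    mailboxes.setdefault((source, destination, tag), queue.Queue()).put(list(data))

def receive(source, destination, tag=0, timeout=1.0):
    # The real Receive blocks forever if no matching Send arrives (deadlock);
    # this sketch uses a timeout so a missing Send fails fast instead.
    q = mailboxes.setdefault((source, destination, tag), queue.Queue())
    return q.get(timeout=timeout)

send([1.0, 2.0], source=0, destination=1, tag=7)
print(receive(source=0, destination=1, tag=7))  # [1.0, 2.0]
```

Note how the tag is part of the matching key: a Receive with tag 7 only sees messages sent with tag 7, which is how multiple logical message streams between the same pair of processes are kept apart.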

ReduceScatter(Vector<T>, ReductionOperation)

ReduceScatter operation - reduces data and scatters the result.

public abstract Vector<T> ReduceScatter(Vector<T> data, ReductionOperation operation)

Parameters

data Vector<T>

The data to reduce and scatter

operation ReductionOperation

The reduction operation

Returns

Vector<T>

The reduced chunk for this process

Remarks

Combines AllReduce and Scatter in one operation for efficiency.

For Beginners: This is an optimization that combines reduction and scattering. Instead of doing AllReduce (everyone gets everything) then Scatter (split it up), we directly compute and distribute only the needed chunks.
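A single-process Python sketch of the ReduceScatter result (illustrative semantics, not the optimized implementation):

```python
def reduce_scatter(rank_data, reduce=sum):
    """Simulate ReduceScatter: reduce element-wise across ranks,
    then rank r keeps only the r-th chunk of the result."""
    world_size = len(rank_data)
    reduced = [reduce(values) for values in zip(*rank_data)]
    chunk = len(reduced) // world_size
    return [reduced[r * chunk:(r + 1) * chunk] for r in range(world_size)]

# World size 2: each rank contributes a 4-element vector.
data = [[1, 2, 3, 4], [10, 20, 30, 40]]
print(reduce_scatter(data))  # [[11, 22], [33, 44]]
```

Rank 0 ends up with the first half of the summed vector and rank 1 with the second half, without any rank ever materializing chunks it does not need.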

Scatter(Vector<T>, int)

Scatter operation - distributes different chunks of data from root to each process.

public abstract Vector<T> Scatter(Vector<T> sendData, int root = 0)

Parameters

sendData Vector<T>

The data to scatter (only used on root process)

root int

The rank of the process that is scattering

Returns

Vector<T>

The chunk of data received by this process

Remarks

For Beginners: The root has a big array and wants to give each GPU a different piece. If root has [1,2,3,4,5,6,7,8] and WorldSize=4, then GPU 0 gets [1,2], GPU 1 gets [3,4], GPU 2 gets [5,6], and GPU 3 gets [7,8].
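The chunking above can be sketched in a few lines of Python (a single-process simulation of what each rank receives, not the C# API):

```python
def scatter(send_data, world_size, root=0):
    """Simulate Scatter: the root's vector is split into world_size
    equal chunks and rank r receives chunk r."""
    chunk = len(send_data) // world_size
    return [send_data[r * chunk:(r + 1) * chunk] for r in range(world_size)]

print(scatter([1, 2, 3, 4, 5, 6, 7, 8], world_size=4))
# [[1, 2], [3, 4], [5, 6], [7, 8]]
```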

Send(Vector<T>, int, int)

Send operation - sends data from this process to a specific destination process.

public abstract void Send(Vector<T> data, int destinationRank, int tag = 0)

Parameters

data Vector<T>

The data to send

destinationRank int

The rank of the process to send to

tag int

Optional message tag to distinguish different messages (default=0)

Remarks

This is a point-to-point communication operation. Unlike collective operations (AllReduce, Broadcast, etc.), only two processes are involved: sender and receiver.

For Beginners: This is like sending a private message to one specific GPU. Unlike Broadcast (which sends to everyone), Send only sends to one receiver.

Use cases:

  • Pipeline parallelism: sending activations from one stage to the next
  • Ring-based algorithms: sending data to neighbor in a ring
  • Custom communication patterns

Important: Send must be matched with a corresponding Receive on the destination process. The sender and receiver must agree on the message size, otherwise deadlock or incorrect data transfer can occur.

Shutdown()

Shuts down the communication backend and releases resources. Should be called when distributed training is complete.

public virtual void Shutdown()

ValidateData(Vector<T>?, string)

Validates that data is not null.

protected void ValidateData(Vector<T>? data, string paramName)

Parameters

data Vector<T>

The data to validate

paramName string

The parameter name for error messages

Remarks

This method ensures that the data vector is not null before attempting communication operations.

For Beginners: This is a basic safety check to make sure we're not trying to send or receive null data, which would cause errors.

Exceptions

ArgumentNullException

Thrown if data is null

ValidateRank(int, string)

Validates that a rank is within valid bounds.

protected void ValidateRank(int rank, string paramName)

Parameters

rank int

The rank to validate

paramName string

The parameter name for error messages

Remarks

This method ensures that the specified rank is a valid process ID (between 0 and WorldSize - 1) and different from the current process rank.

For Beginners: When sending or receiving from another process, we need to make sure:

  • That process actually exists (valid rank number)
  • We're not trying to send/receive from ourselves (rank != Rank)

For example, if you have 4 processes (ranks 0-3), rank 5 would be invalid. Also, a process shouldn't send to itself.
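The two checks can be sketched in Python (names and exception types are illustrative; the C# method throws ArgumentException):

```python
def validate_rank(rank, world_size, current_rank, param_name="rank"):
    """Check that the rank exists and is not the current process."""
    if rank < 0 or rank >= world_size:
        raise ValueError(
            f"{param_name} must be between 0 and {world_size - 1}, got {rank}")
    if rank == current_rank:
        raise ValueError(
            f"{param_name} must differ from the current rank {current_rank}")

validate_rank(2, world_size=4, current_rank=0)  # passes silently
# validate_rank(5, world_size=4, current_rank=0) would raise ValueError
```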

Exceptions

ArgumentException

Thrown if rank is out of bounds or equals current rank

ValidateRoot(int)

Validates that a root rank is within valid bounds.

protected void ValidateRoot(int root)

Parameters

root int

The root rank to validate

Remarks

This method ensures that the specified root rank is a valid process ID (between 0 and WorldSize - 1).

For Beginners: When one process acts as the "root" or "leader", we need to make sure that process actually exists.

For example, if you have 4 processes (ranks 0-3), specifying rank 5 as the root would be an error. This method catches such mistakes.

Exceptions

ArgumentException

Thrown if root is out of bounds