Class InMemoryCommunicationBackend<T>
- Namespace: AiDotNet.DistributedTraining
- Assembly: AiDotNet.dll
Provides an in-memory implementation of distributed communication for testing and single-machine scenarios.
public class InMemoryCommunicationBackend<T> : CommunicationBackendBase<T>, ICommunicationBackend<T>
Type Parameters
T: The numeric type for operations
- Inheritance: CommunicationBackendBase<T> → InMemoryCommunicationBackend<T>
- Implements: ICommunicationBackend<T>
Remarks
⚠️ WARNING - Static Shared State: This implementation uses STATIC shared dictionaries to simulate cross-process communication. This design has important implications:
- All instances in the same process share the SAME static state
- Unit tests using this backend CANNOT run in parallel without isolation via environmentId
- Multiple training sessions in the same process can interfere unless using unique environmentIds
- NOT suitable for production multi-process scenarios - use MPI/NCCL backends instead
The static state includes: _sharedBuffers, _barrierCounters, _barrierGenerations, _operationCounters, _messageQueues. These are namespaced by environmentId to enable concurrent independent sessions, but tests must ensure unique environmentIds or run serially.
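For example, two independent sessions in the same process can coexist safely by passing unique environment IDs (a minimal sketch; the ID strings are illustrative):
// Unique IDs keep the static state (buffers, barriers, queues) from colliding
var sessionA = new InMemoryCommunicationBackend<double>(rank: 0, worldSize: 2, environmentId: "session-A");
var sessionB = new InMemoryCommunicationBackend<double>(rank: 0, worldSize: 2, environmentId: "session-B");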
This backend simulates multiple processes using static shared memory and locks, with all "processes" running within the same application instance. That makes it possible to test distributed code without actual MPI infrastructure or multiple machines.
For Beginners: This is a "fake" distributed system that runs on a single machine.
It's perfect for testing your distributed code without needing multiple GPUs or machines. Think of it as a practice mode - it simulates distributed behavior but everything runs in one process.
Use this when:
- Testing distributed code locally
- Debugging distributed training logic
- Running unit tests
- Learning how distributed training works
For production with actual multiple GPUs/machines, use an MPI-based backend instead.
Example:
// Create a simulated distributed environment with 4 "processes"
var backend = new InMemoryCommunicationBackend<double>(rank: 0, worldSize: 4);
backend.Initialize();
// Now you can test distributed operations locally
var data = new Vector<double>(new[] { 1.0, 2.0, 3.0 });
backend.AllReduce(data, ReductionOperation.Sum);
// data now contains the sum from all 4 simulated processes (once every rank has called AllReduce)
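Collective operations only complete once every rank participates, so a fuller single-process demo typically drives each rank from its own thread. A minimal sketch (assuming System.Threading and System.Collections.Generic are in scope):
// Run all 4 simulated ranks concurrently so the collective can complete
int worldSize = 4;
var threads = new List<Thread>();
for (int rank = 0; rank < worldSize; rank++)
{
    int r = rank; // capture the loop variable for the thread body
    threads.Add(new Thread(() =>
    {
        var backend = new InMemoryCommunicationBackend<double>(r, worldSize);
        backend.Initialize();
        var data = new Vector<double>(new[] { (double)r, (double)r });
        backend.AllReduce(data, ReductionOperation.Sum);
        // Every rank now holds [6, 6] (0 + 1 + 2 + 3 in each slot)
    }));
}
threads.ForEach(t => t.Start());
threads.ForEach(t => t.Join());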
Constructors
InMemoryCommunicationBackend(int, int, string)
Creates a new in-memory communication backend.
public InMemoryCommunicationBackend(int rank, int worldSize, string environmentId = "default")
Parameters
rank (int): The rank (ID) of this simulated process (0-based)
worldSize (int): The total number of simulated processes
environmentId (string): Optional environment ID for isolation (defaults to "default" for backwards compatibility)
Remarks
You create one of these for each simulated "process". If you want to simulate 4 GPUs, you create 4 instances with ranks 0, 1, 2, 3, all with worldSize=4.
For Beginners: This creates one simulated process in your fake distributed system.
Parameters:
- rank: The ID of this process (0-based). Each process needs a unique rank.
- worldSize: How many processes total are in your simulated system.
Example: To simulate 4 GPUs, create 4 backends:
var process0 = new InMemoryCommunicationBackend<double>(rank: 0, worldSize: 4);
var process1 = new InMemoryCommunicationBackend<double>(rank: 1, worldSize: 4);
var process2 = new InMemoryCommunicationBackend<double>(rank: 2, worldSize: 4);
var process3 = new InMemoryCommunicationBackend<double>(rank: 3, worldSize: 4);
Exceptions
- ArgumentException
Thrown if rank or worldSize is invalid
Properties
Rank
Gets the rank (ID) of the current process in the distributed group.
public override int Rank { get; }
Property Value
- int
Remarks
Rank 0 is typically the "master" or "coordinator" process.
For Beginners: Think of rank as your process's unique ID number. If you have 4 GPUs, ranks will be 0, 1, 2, and 3. Rank 0 is usually the "boss" that coordinates everything.
WorldSize
Gets the total number of processes in the distributed group.
public override int WorldSize { get; }
Property Value
- int
Remarks
For Beginners: This is how many processes (or GPUs) are working together. If WorldSize is 4, you have 4 processes sharing the work.
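A common pattern uses Rank and WorldSize together: rank-conditional work on the coordinator plus an even split of the workload. A sketch (totalItems is a hypothetical workload size):
if (backend.Rank == 0)
    Console.WriteLine($"Coordinator starting with {backend.WorldSize} processes");
// Give each process a contiguous slice of the workload
int totalItems = 1000;
int itemsPerProcess = totalItems / backend.WorldSize;
int start = backend.Rank * itemsPerProcess;
// The last rank picks up any remainder
int end = backend.Rank == backend.WorldSize - 1 ? totalItems : start + itemsPerProcess;
// This process handles items in [start, end)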
Methods
AllGather(Vector<T>)
AllGather operation - gathers data from all processes and concatenates it.
public override Vector<T> AllGather(Vector<T> sendData)
Parameters
sendData (Vector<T>): The local data to contribute
Returns
- Vector<T>
The gathered data from all processes concatenated together
Remarks
Each process receives the complete concatenated result.
For Beginners: If GPU 0 has [1,2], GPU 1 has [3,4], GPU 2 has [5,6], GPU 3 has [7,8], then AllGather gives everyone [1,2,3,4,5,6,7,8]. This is used to reconstruct the full model parameters from sharded pieces.
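For example, reconstructing the full parameter vector from per-rank shards (a sketch; the shard values match the example above):
// Each rank contributes its 2-element shard; with worldSize = 4,
// rank 0 holds [1,2], rank 1 holds [3,4], and so on
var localShard = new Vector<double>(new[] { backend.Rank * 2.0 + 1, backend.Rank * 2.0 + 2 });
var full = backend.AllGather(localShard);
// full is [1,2,3,4,5,6,7,8] on every rank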
AllReduce(Vector<T>, ReductionOperation)
AllReduce operation - combines data from all processes using the specified operation and distributes the result back to all processes.
public override void AllReduce(Vector<T> data, ReductionOperation operation)
Parameters
data (Vector<T>): The data to reduce. Will be replaced with the reduced result.
operation (ReductionOperation): The reduction operation (Sum, Max, Min, etc.)
Remarks
For Beginners: Imagine 4 GPUs each calculated a gradient vector. AllReduce takes all 4 vectors, adds them together (if operation is Sum), and gives the result to all 4 GPUs. This is crucial for averaging gradients across GPUs during training.
Common operations:
- Sum: Add all values together (used for gradient averaging)
- Max: Take the maximum value across all processes
- Min: Take the minimum value across all processes
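Gradient averaging is the canonical Sum use case: reduce, then divide by WorldSize. A sketch (it assumes Vector<T> exposes a Length property and an indexer, which may differ in the actual API):
// localGradients holds this rank's gradients for its mini-batch
backend.AllReduce(localGradients, ReductionOperation.Sum);
// Divide the element-wise sum by the process count to get the average
for (int i = 0; i < localGradients.Length; i++)
    localGradients[i] /= backend.WorldSize;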
Barrier()
Synchronization barrier - blocks until all processes reach this point.
public override void Barrier()
Remarks
For Beginners: This is like a meeting checkpoint. All processes must arrive at this point before any of them can continue. It ensures everyone is synchronized. Example: Before starting training, you want all GPUs to be ready.
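For example, making sure every rank has finished its setup before training starts (a sketch; LoadShard is a hypothetical per-rank helper):
LoadShard(backend.Rank); // hypothetical per-rank setup work
backend.Barrier();       // block until all ranks have finished loading
// Every process has passed the barrier: safe to start training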
Broadcast(Vector<T>, int)
Broadcast operation - sends data from one process (root) to all other processes.
public override Vector<T> Broadcast(Vector<T> data, int root = 0)
Parameters
data (Vector<T>): The data to broadcast (only meaningful on the root process)
root (int): The rank of the process that is broadcasting
Returns
- Vector<T>
The broadcast data (received from root on non-root processes)
Remarks
For Beginners: This is like an announcement from the boss (root process). The root sends data to everyone else. Useful for distributing initial parameters or configurations.
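For example, distributing rank 0's freshly initialized weights so every replica starts from identical parameters (a sketch; initialParams is a hypothetical Vector<T> whose contents only matter on rank 0):
// Rank 0's initialParams are sent to everyone; other ranks' inputs are ignored
var parameters = backend.Broadcast(initialParams, root: 0);
// All ranks now hold the same parameters as rank 0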
ClearEnvironment(string)
Clears all shared state for a specific environment. Useful for test cleanup and isolation.
public static void ClearEnvironment(string environmentId)
Parameters
environmentId (string): The environment ID to clear
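A typical use is test teardown, so parallel test classes cannot see each other's state. A sketch using xUnit's IDisposable convention (the test-framework wiring is illustrative):
public class DistributedTrainingTests : IDisposable
{
    // A unique ID per test class keeps the static dictionaries isolated
    private readonly string _envId = Guid.NewGuid().ToString();

    // ... tests create backends with environmentId: _envId ...

    public void Dispose()
    {
        // Remove this class's entries from the shared static state
        InMemoryCommunicationBackend<double>.ClearEnvironment(_envId);
    }
}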
OnInitialize()
Called during initialization to perform backend-specific setup.
protected override void OnInitialize()
Remarks
Derived classes override this method to implement their specific initialization logic, such as connecting to MPI or setting up shared memory structures.
For Beginners: This is where each specific backend does its setup work.
For example:
- An MPI backend would connect to the MPI environment
- An in-memory backend would create shared data structures
- An NCCL backend would initialize GPU communication channels
OnShutdown()
Called during shutdown to perform backend-specific cleanup.
protected override void OnShutdown()
Remarks
Derived classes override this method to implement their specific cleanup logic, such as disconnecting from MPI or releasing shared memory.
For Beginners: This is where each backend cleans up its resources.
It's like turning off equipment when you're done - releasing memory, closing connections, and ensuring everything shuts down cleanly.
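As an illustration of the override pattern, a hypothetical backend might hook these two methods like this (a sketch; LoggingBackend is invented here, and the abstract members of CommunicationBackendBase<T> that a real subclass must also override are omitted):
public class LoggingBackend<T> : CommunicationBackendBase<T>
{
    protected override void OnInitialize()
    {
        // Backend-specific setup: open connections, allocate buffers, etc.
        Console.WriteLine($"Rank {Rank}/{WorldSize} initializing");
    }

    protected override void OnShutdown()
    {
        // Backend-specific cleanup: release buffers, close connections
        Console.WriteLine($"Rank {Rank} shutting down");
    }

    // ... collective and point-to-point overrides omitted ...
}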
Receive(int, int, int)
Receive operation - receives data from a specific source process.
public override Vector<T> Receive(int sourceRank, int count, int tag = 0)
Parameters
sourceRank (int): The rank of the process to receive from
count (int): The expected number of elements to receive
tag (int): Optional message tag to match with Send (default = 0)
Returns
- Vector<T>
The received data
Remarks
This is a point-to-point communication operation that blocks until data arrives.
For Beginners: This is like waiting for a private message from a specific GPU. The process will wait (block) until the message arrives.
Use cases:
- Pipeline parallelism: receiving activations from previous stage
- Ring-based algorithms: receiving data from neighbor
- Custom communication patterns
Important: Receive must be matched with a corresponding Send from the source process. If the sender never sends, this will deadlock (hang forever). If the sizes don't match, data corruption or errors can occur.
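A matched pair in a two-stage pipeline might look like the sketch below (both ranks run concurrently, e.g. on separate threads):
if (backend.Rank == 0)
{
    // First pipeline stage: send activations forward
    var activations = new Vector<double>(new[] { 0.5, 0.25 });
    backend.Send(activations, destinationRank: 1, tag: 0);
}
else if (backend.Rank == 1)
{
    // Second stage: block until the matching 2-element message arrives
    var received = backend.Receive(sourceRank: 0, count: 2, tag: 0);
}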
ReduceScatter(Vector<T>, ReductionOperation)
ReduceScatter operation - reduces data and scatters the result.
public override Vector<T> ReduceScatter(Vector<T> data, ReductionOperation operation)
Parameters
data (Vector<T>): The data to reduce and scatter
operation (ReductionOperation): The reduction operation
Returns
- Vector<T>
The reduced chunk for this process
Remarks
Combines AllReduce and Scatter in one operation for efficiency.
For Beginners: This is an optimization that combines reduction and scattering. Instead of doing AllReduce (everyone gets everything) then Scatter (split it up), we directly compute and distribute only the needed chunks.
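For example, with worldSize = 4 and an 8-element vector on every rank, each rank gets back only its own reduced 2-element chunk (a sketch; fullGradients is a hypothetical 8-element Vector<double>):
// Every rank passes in the full 8-element vector
var myChunk = backend.ReduceScatter(fullGradients, ReductionOperation.Sum);
// myChunk has 8 / 4 = 2 elements: the summed slice this rank owns
// (rank 0 owns elements 0-1, rank 1 owns 2-3, and so on)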
Scatter(Vector<T>, int)
Scatter operation - distributes different chunks of data from root to each process.
public override Vector<T> Scatter(Vector<T> sendData, int root = 0)
Parameters
sendData (Vector<T>): The data to scatter (only used on the root process)
root (int): The rank of the process that is scattering
Returns
- Vector<T>
The chunk of data received by this process
Remarks
For Beginners: The root has a big array and wants to give each GPU a different piece. If root has [1,2,3,4,5,6,7,8] and WorldSize=4, it gives:
- GPU 0 gets [1,2]
- GPU 1 gets [3,4]
- GPU 2 gets [5,6]
- GPU 3 gets [7,8]
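In code, that example looks like the sketch below (the sendData argument is ignored on non-root ranks):
// Only rank 0's sendData matters; every rank receives its own chunk
var sendData = new Vector<double>(new[] { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0 });
var myChunk = backend.Scatter(sendData, root: 0);
// With worldSize = 4: rank 0 gets [1,2], rank 1 gets [3,4],
// rank 2 gets [5,6], rank 3 gets [7,8]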
Send(Vector<T>, int, int)
Send operation - sends data from this process to a specific destination process.
public override void Send(Vector<T> data, int destinationRank, int tag = 0)
Parameters
data (Vector<T>): The data to send
destinationRank (int): The rank of the process to send to
tag (int): Optional message tag to distinguish different messages (default = 0)
Remarks
This is a point-to-point communication operation. Unlike collective operations (AllReduce, Broadcast, etc.), only two processes are involved: sender and receiver.
For Beginners: This is like sending a private message to one specific GPU. Unlike Broadcast (which sends to everyone), Send only sends to one receiver.
Use cases:
- Pipeline parallelism: sending activations from one stage to the next
- Ring-based algorithms: sending data to neighbor in a ring
- Custom communication patterns
Important: Send must be matched with a corresponding Receive on the destination process. The sender and receiver must agree on the message size, otherwise deadlock or incorrect data transfer can occur.
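For instance, a ring shift where every rank sends its value to the right neighbor and receives from the left (a sketch; sending before receiving is safe here only if the backend buffers messages, as the _messageQueues mentioned above suggest this one does, so a fully synchronous backend might need a different ordering to avoid deadlock):
int right = (backend.Rank + 1) % backend.WorldSize;
int left = (backend.Rank - 1 + backend.WorldSize) % backend.WorldSize;

var outgoing = new Vector<double>(new[] { (double)backend.Rank });
backend.Send(outgoing, destinationRank: right, tag: 7);
var incoming = backend.Receive(sourceRank: left, count: 1, tag: 7);
// incoming now holds the left neighbor's rank value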