Class TensorParallelModel<T, TInput, TOutput>

Namespace
AiDotNet.DistributedTraining
Assembly
AiDotNet.dll

Tensor-parallel model wrapper that splits individual layers across ranks (Megatron-LM style).

public class TensorParallelModel<T, TInput, TOutput> : ShardedModelBase<T, TInput, TOutput>, IShardedModel<T, TInput, TOutput>, IFullModel<T, TInput, TOutput>, IModel<TInput, TOutput, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, TInput, TOutput>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, TInput, TOutput>>, IGradientComputable<T, TInput, TOutput>, IJitCompilable<T>

Type Parameters

T

The numeric type

TInput

The input type for the model

TOutput

The output type for the model

Inheritance
ShardedModelBase<T, TInput, TOutput>
TensorParallelModel<T, TInput, TOutput>
Implements
IShardedModel<T, TInput, TOutput>
IFullModel<T, TInput, TOutput>
IModel<TInput, TOutput, ModelMetadata<T>>
IParameterizable<T, TInput, TOutput>
ICloneable<IFullModel<T, TInput, TOutput>>
IGradientComputable<T, TInput, TOutput>

Remarks

Strategy Overview: Tensor Parallelism (Megatron-LM style) partitions individual layers horizontally across processes. For example, a large matrix multiplication is split so each GPU computes only a portion of the output, then results are combined. This is particularly effective for transformer models where attention and feed-forward layers can be partitioned along specific dimensions (column-parallel and row-parallel).

For Beginners: Tensor parallelism is like splitting a single large calculation across multiple workers. Imagine a huge spreadsheet calculation - instead of one person doing all the math, we divide the spreadsheet columns across multiple people, each computing their portion simultaneously.

For example, in a neural network layer with a 10000x10000 weight matrix:

  • GPU 0 handles columns 0-2499
  • GPU 1 handles columns 2500-4999
  • GPU 2 handles columns 5000-7499
  • GPU 3 handles columns 7500-9999

They compute in parallel, then combine results.
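
A minimal C# sketch of this column split (sizes shrunk from the 10000x10000 example; the final concatenation stands in for the AllGather a real backend would perform):

using System.Linq;

// Each "rank" computes y = x * W using only its own slice of W's columns.
int hidden = 8, worldSize = 4;
int colsPerRank = hidden / worldSize;     // 2 columns per rank here; 2500 at full scale
double[,] W = new double[hidden, hidden]; // full weight matrix (conceptual)
double[] x = new double[hidden];          // one input row

double[] PartialOutput(int rank)
{
    var y = new double[colsPerRank];
    int start = rank * colsPerRank;       // e.g. rank 1 owns columns 2-3 (2500-4999 at full scale)
    for (int j = 0; j < colsPerRank; j++)
        for (int i = 0; i < hidden; i++)
            y[j] += x[i] * W[i, start + j];
    return y;
}

// Combine the partial outputs; real tensor parallelism does this with AllGather across GPUs.
double[] full = Enumerable.Range(0, worldSize).SelectMany(r => PartialOutput(r)).ToArray();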

Use Cases:

  • Very wide models (large hidden dimensions)
  • Transformer models (BERT, GPT) with large attention/FFN layers
  • When individual layers are too large for a single GPU
  • Often combined with pipeline parallelism for maximum scalability

Trade-offs:

  • Memory: Excellent for wide layers - each rank stores only a portion of the weights
  • Communication: High - requires AllReduce or AllGather within each layer
  • Complexity: Very high - requires model-aware partitioning specific to layer types
  • Best for: Transformer models, very wide layers, fast interconnects (NVLink)
  • Limitation: Requires fast communication (high overhead on slow networks)

Implementation Note: This is a production-ready framework implementation. Full tensor parallelism requires model-specific layer partitioning (column-parallel vs. row-parallel strategies for different layer types). This implementation provides the infrastructure; for production use with specific models (e.g., transformers), extend this class with layer-aware partitioning.

⚠️ IMPORTANT LIMITATION - Memory Efficiency: This implementation gathers the full parameter vector on every Train() and Predict() call (via GatherFullParameters and SetParameters), which defeats the memory-saving purpose of true tensor parallelism. While parameters are sharded across ranks for storage, they are reconstructed into the full vector for each forward/backward pass. This means:

  • Memory savings are minimal compared to data-parallel training
  • Communication overhead is high (AllGather on every forward pass)
  • This wrapper primarily provides gradient synchronization, not memory-efficient tensor parallelism (see the sketch below)
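
Conceptually, each Train() or Predict() call follows the pattern below. This is only a sketch of the behavior just described, not the actual source; 'wrappedModel' is a hypothetical reference to the inner model:

// Sketch of the gather-then-compute pattern (illustrative, not the real implementation).
public override TOutput Predict(TInput input)
{
    Vector<T> full = GatherFullParameters(); // AllGather: every rank rebuilds the FULL vector
    wrappedModel.SetParameters(full);        // sharded storage saves no memory at this point
    return wrappedModel.Predict(input);      // ordinary full-model forward pass
}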

For true memory-efficient tensor parallelism, you would need layer-aware implementations where each rank only loads its parameter shard and performs partial matrix multiplications without ever reconstructing the full parameter vector. This simplified implementation is suitable for:

  • Testing and development of distributed training infrastructure
  • Scenarios where gradient synchronization is more important than memory efficiency
  • Models where memory is not the primary constraint

If memory efficiency is critical, consider using FSDP (Fully Sharded Data Parallel) or ZeRO-3 instead, which shard parameters more aggressively and avoid full parameter reconstruction.

Example:

var model = new TransformerModel<double>(...); // Large transformer
var backend = new InMemoryCommunicationBackend<double>(rank: 0, worldSize: 4);
var config = new ShardingConfiguration<double>(backend);

// Each rank handles 1/4 of each layer's width
var tensorParallelModel = new TensorParallelModel<double, Tensor<double>, Tensor<double>>(
    model, config);

Constructors

TensorParallelModel(IFullModel<T, TInput, TOutput>, IShardingConfiguration<T>)

Creates a new Tensor Parallel model.

public TensorParallelModel(IFullModel<T, TInput, TOutput> wrappedModel, IShardingConfiguration<T> config)

Parameters

wrappedModel IFullModel<T, TInput, TOutput>

The model to partition with tensor parallelism

config IShardingConfiguration<T>

Configuration for sharding and communication

Methods

Clone()

Creates a shallow copy of this object.

public override IFullModel<T, TInput, TOutput> Clone()

Returns

IFullModel<T, TInput, TOutput>

Deserialize(byte[])

Loads a previously serialized model from binary data.

public override void Deserialize(byte[] data)

Parameters

data byte[]

The byte array containing the serialized model data.

Remarks

This method takes binary data created by the Serialize method and uses it to restore a model to its previous state.

For Beginners: This is like opening a saved file to continue your work.

When you call this method:

  • You provide the binary data (bytes) that was previously created by Serialize
  • The model rebuilds itself using this data
  • After deserializing, the model is exactly as it was when serialized
  • It's ready to make predictions without needing to be trained again

For example:

  • You download a pre-trained model file for detecting spam emails
  • You deserialize this file into your application
  • Immediately, your application can detect spam without any training
  • The model has all the knowledge that was built into it by its original creator

This is particularly useful when:

  • You want to use a model that took days to train
  • You need to deploy the same model across multiple devices
  • You're creating an application that non-technical users will use

Think of it like installing the brain of a trained expert directly into your application.
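
For example, restoring a model from bytes on disk (the file name and 'newEmailFeatures' are placeholders):

byte[] data = File.ReadAllBytes("spam-detector.bin"); // bytes produced earlier by Serialize()
tensorParallelModel.Deserialize(data);
var result = tensorParallelModel.Predict(newEmailFeatures); // ready immediately, no retraining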

GetModelMetadata()

Retrieves metadata and performance metrics about the trained model.

public override ModelMetadata<T> GetModelMetadata()

Returns

ModelMetadata<T>

An object containing metadata and performance metrics about the trained model.

Remarks

This method provides information about the model's structure, parameters, and performance metrics.

For Beginners: Model metadata is like a report card for your machine learning model.

Just as a report card shows how well a student is performing in different subjects, model metadata shows how well your model is performing and provides details about its structure.

This information typically includes:

  • Accuracy measures: How well does the model's predictions match actual values?
  • Error metrics: How far off are the model's predictions on average?
  • Model parameters: What patterns did the model learn from the data?
  • Training information: How long did training take? How many iterations were needed?

For example, in a house price prediction model, metadata might include:

  • Average prediction error (e.g., off by $15,000 on average)
  • How strongly each feature (bedrooms, location) influences the prediction
  • How well the model fits the training data

This information helps you understand your model's strengths and weaknesses, and decide if it's ready to use or needs more training.
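
A typical call is a one-liner; which metrics are populated depends on what ModelMetadata<T> exposes:

var metadata = tensorParallelModel.GetModelMetadata();
// Inspect its properties (accuracy, error metrics, parameter info) to decide
// whether the model is ready for use or needs more training.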

InitializeSharding()

Initializes tensor parallelism by partitioning layer weights.

protected override void InitializeSharding()

LoadModel(string)

Loads the model from a file.

public override void LoadModel(string filePath)

Parameters

filePath string

The path to the file containing the saved model.

Remarks

This method provides a convenient way to load a model directly from disk. It combines file I/O operations with deserialization.

For Beginners: This is like clicking "Open" in a document editor. Instead of manually reading from a file and then calling Deserialize(), this method does both steps for you.
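
In code, the single call replaces the manual two-step version (file name is a placeholder):

tensorParallelModel.LoadModel("model.bin");

// Equivalent manual version:
byte[] bytes = File.ReadAllBytes("model.bin");
tensorParallelModel.Deserialize(bytes);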

Exceptions

FileNotFoundException

Thrown when the specified file does not exist.

IOException

Thrown when an I/O error occurs while reading from the file or when the file contains corrupted or invalid model data.

OnBeforeInitializeSharding()

Called before InitializeSharding to set up derived class state.

protected override void OnBeforeInitializeSharding()

Predict(TInput)

Uses the trained model to make predictions for new input data.

public override TOutput Predict(TInput input)

Parameters

input TInput

The input data to predict. For matrix-like inputs, each row represents a new example and each column a feature.

Returns

TOutput

The predicted values for each input example.

Remarks

After training, this method applies the learned patterns to new data to predict outcomes.

For Beginners: Prediction is when the model uses what it learned to make educated guesses about new information.

Continuing the fruit identification example:

  • After learning from many examples, the child (model) can now identify new fruits they haven't seen before
  • They look at the color, shape, and size to make their best guess

In machine learning:

  • You give the model new data it hasn't seen during training
  • The model applies the patterns it learned to make predictions
  • The output is the model's best estimate based on its training

For example, in a house price prediction model:

  • You provide features of a new house (square footage, bedrooms, location)
  • The model predicts what price that house might sell for

This method is used after training is complete, when you want to apply your model to real-world data.
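
Usage follows the types from the class example above; 'newInput' is a placeholder for data prepared the same way as the training input:

Tensor<double> prediction = tensorParallelModel.Predict(newInput);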

SaveModel(string)

Saves the model to a file.

public override void SaveModel(string filePath)

Parameters

filePath string

The path where the model should be saved.

Remarks

This method provides a convenient way to save the model directly to disk. It combines serialization with file I/O operations.

For Beginners: This is like clicking "Save As" in a document editor. Instead of manually calling Serialize() and then writing to a file, this method does both steps for you.
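
As with LoadModel, one call wraps the two manual steps (file name is a placeholder):

tensorParallelModel.SaveModel("model.bin");

// Equivalent manual version:
byte[] bytes = tensorParallelModel.Serialize();
File.WriteAllBytes("model.bin", bytes);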

Exceptions

IOException

Thrown when an I/O error occurs while writing to the file.

UnauthorizedAccessException

Thrown when the caller does not have the required permission to write to the specified file path.

Serialize()

Converts the current state of a machine learning model into a binary format.

public override byte[] Serialize()

Returns

byte[]

A byte array containing the serialized model data.

Remarks

This method captures all the essential information about a trained model and converts it into a sequence of bytes that can be stored or transmitted.

For Beginners: This is like exporting your work to a file.

When you call this method:

  • The model's current state (all its learned patterns and parameters) is captured
  • This information is converted into a compact binary format (bytes)
  • You can then save these bytes to a file, database, or send them over a network

For example:

  • After training a model to recognize cats vs. dogs in images
  • You can serialize the model to save all its learned knowledge
  • Later, you can use this saved data to recreate the model exactly as it was
  • The recreated model will make the same predictions as the original

Think of it like taking a snapshot of your model's brain at a specific moment in time.
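
A minimal round trip; 'anotherModel' is a placeholder for a second, identically constructed instance:

byte[] snapshot = tensorParallelModel.Serialize();
// Store the bytes in a file, database, or network message, then later:
anotherModel.Deserialize(snapshot); // now predicts exactly like the original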

SynchronizeGradients()

Synchronizes tensor-parallel computation results.

public override void SynchronizeGradients()

Remarks

In tensor parallelism, different layers require different synchronization patterns:

  • Column-parallel layers: AllReduce after computation
  • Row-parallel layers: AllGather before computation

This implementation uses subgroup AllReduce within the tensor-parallel group.
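
In a hand-rolled training loop, synchronization would sit right after each local step. This is a sketch; whether Train() already synchronizes internally should be verified before adding an explicit call:

foreach (var (batchInput, batchTarget) in trainingBatches)
{
    tensorParallelModel.Train(batchInput, batchTarget); // local forward/backward pass
    tensorParallelModel.SynchronizeGradients();         // subgroup AllReduce within the tensor-parallel group
}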

Train(TInput, TOutput)

Trains the model using input features and their corresponding target values.

public override void Train(TInput input, TOutput expectedOutput)

Parameters

input TInput

The input features used for training.

expectedOutput TOutput

The target values the model should learn to predict.

Remarks

This method takes training data and adjusts the model's internal parameters to learn patterns in the data.

For Beginners: Training is like teaching the model by showing it examples.

Imagine teaching a child to identify fruits:

  • You show them many examples of apples, oranges, and bananas (the input)
  • You tell them the correct name for each fruit (the expected output)
  • Over time, they learn to recognize the patterns that distinguish each fruit

In machine learning:

  • The input parameter contains features (characteristics) of your data
  • The expectedOutput parameter contains the correct answers you want the model to learn
  • During training, the model adjusts its internal calculations to get better at predicting the expected output from the input

For example, in a house price prediction model:

  • input would contain features like square footage, number of bedrooms, location
  • expectedOutput would contain the actual sale prices of those houses (see the sketch below)
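
A single training call for the house-price example; 'houseFeatures' and 'salePrices' are placeholder Tensor<double> values:

tensorParallelModel.Train(houseFeatures, salePrices); // adjusts internal parameters toward the targets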

WithParameters(Vector<T>)

Creates a new instance with the specified parameters.

public override IFullModel<T, TInput, TOutput> WithParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

The parameter vector to apply to the new model instance.

Returns

IFullModel<T, TInput, TOutput>
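
A sketch of the intended usage. GetParameters() is assumed to come from IParameterizable; verify the exact member name against that interface:

Vector<double> parameters = tensorParallelModel.GetParameters(); // assumed accessor
// ... adjust 'parameters' as needed ...
var tweaked = tensorParallelModel.WithParameters(parameters);    // original model is unchanged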