
Class ShardedModelBase<T, TInput, TOutput>

Namespace
AiDotNet.DistributedTraining
Assembly
AiDotNet.dll

Provides base implementation for distributed models with parameter sharding.

public abstract class ShardedModelBase<T, TInput, TOutput> : IShardedModel<T, TInput, TOutput>, IFullModel<T, TInput, TOutput>, IModel<TInput, TOutput, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, TInput, TOutput>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, TInput, TOutput>>, IGradientComputable<T, TInput, TOutput>, IJitCompilable<T>

Type Parameters

T

The numeric type for operations

TInput

The input type for the model

TOutput

The output type for the model

Inheritance
ShardedModelBase<T, TInput, TOutput>
Implements
IShardedModel<T, TInput, TOutput>
IFullModel<T, TInput, TOutput>
IModel<TInput, TOutput, ModelMetadata<T>>
IModelSerializer
ICheckpointableModel
IParameterizable<T, TInput, TOutput>
IFeatureAware
IFeatureImportance<T>
ICloneable<IFullModel<T, TInput, TOutput>>
IGradientComputable<T, TInput, TOutput>
IJitCompilable<T>

Remarks

This abstract class implements common functionality for all sharded models, including parameter management, sharding logic, gradient synchronization, and integration with the model serialization system. Derived classes can customize the sharding strategy, communication pattern, or add optimization-specific features.

For Beginners: This is the foundation that all distributed models build upon.

Think of this as a template for splitting a big model across multiple computers or GPUs. It handles common tasks like:

  • Dividing model parameters into chunks (sharding)
  • Collecting all chunks when needed (gathering)
  • Sharing learning updates across all processes (gradient sync)
  • Saving and loading distributed models

Specific types of distributed models (like fully sharded or hybrid sharded) inherit from this and add their own strategies. This prevents code duplication and ensures all distributed models work consistently.

Constructors

ShardedModelBase(IFullModel<T, TInput, TOutput>, IShardingConfiguration<T>)

Initializes a new instance of the ShardedModelBase class.

protected ShardedModelBase(IFullModel<T, TInput, TOutput> wrappedModel, IShardingConfiguration<T> config)

Parameters

wrappedModel IFullModel<T, TInput, TOutput>

The model to wrap with distributed capabilities

config IShardingConfiguration<T>

Configuration for sharding and communication

Remarks

This constructor wraps an existing model with distributed training capabilities. It initializes the communication backend if needed and sets up parameter sharding.

For Beginners: This constructor takes your regular model and makes it distributed.

You provide:

  1. The model you want to distribute
  2. Configuration that tells us how to distribute it

The constructor automatically:

  • Sets up communication if not already done
  • Splits the model's parameters across processes
  • Prepares everything for distributed training

Exceptions

ArgumentNullException

Thrown if wrappedModel or config is null.

Fields

CachedFullParameters

Cached full parameters to avoid repeated gathering.

protected Vector<T>? CachedFullParameters

Field Value

Vector<T>

Config

The sharding configuration containing communication backend and settings.

protected readonly IShardingConfiguration<T> Config

Field Value

IShardingConfiguration<T>

LocalShard

The local parameter shard owned by this process.

protected Vector<T> LocalShard

Field Value

Vector<T>

NumOps

Provides numeric operations for type T.

protected readonly INumericOperations<T> NumOps

Field Value

INumericOperations<T>

ShardSize

Size of this process's parameter shard.

protected int ShardSize

Field Value

int

ShardStartIndex

Starting index of this process's shard in the full parameter vector.

protected int ShardStartIndex

Field Value

int

Properties

DefaultLossFunction

Gets the default loss function used by this model for gradient computation.

public virtual ILossFunction<T> DefaultLossFunction { get; }

Property Value

ILossFunction<T>

Remarks

This loss function is used when calling ComputeGradients(TInput, TOutput, ILossFunction<T>?) without explicitly providing a loss function. It represents the model's primary training objective.

For Beginners: The loss function tells the model "what counts as a mistake". For example:

  • For regression (predicting numbers): Mean Squared Error measures how far predictions are from the actual values
  • For classification (predicting categories): Cross Entropy measures how much probability the model assigns to the correct category

This property provides a sensible default so you don't have to specify the loss function every time, but you can still override it if needed for special cases.

Distributed Training: In distributed training, all workers use the same loss function to ensure consistent gradient computation. The default loss function is automatically used when workers compute local gradients.

Exceptions

InvalidOperationException

Thrown if accessed before the model has been configured with a loss function.

LocalParameterShard

Gets the portion of parameters owned by this process.

public Vector<T> LocalParameterShard { get; }

Property Value

Vector<T>

Remarks

For Beginners: This is "your piece of the puzzle" - the parameters that this particular process is responsible for storing and updating.

ParameterCount

Gets the number of parameters in the model.

public int ParameterCount { get; }

Property Value

int

Remarks

This property returns the total count of trainable parameters in the model. It's useful for understanding model complexity and memory requirements.

Rank

Gets the rank of this process in the distributed group.

public int Rank { get; }

Property Value

int

Remarks

For Beginners: Each process has a unique ID (rank). This tells you which process you are. Rank 0 is typically the "coordinator" process.

ShardingConfiguration

Gets the configuration for this sharded model.

public IShardingConfiguration<T> ShardingConfiguration { get; }

Property Value

IShardingConfiguration<T>

SupportsJitCompilation

Gets whether this model currently supports JIT compilation.

public virtual bool SupportsJitCompilation { get; }

Property Value

bool

True if the wrapped model supports JIT compilation, false otherwise.

Remarks

Sharded models delegate JIT compilation support to their wrapped model. JIT compilation is performed on the full model representation, not on individual shards.

For Beginners: Distributed models can be JIT compiled if the underlying model supports it.

The sharding strategy (splitting parameters across processes) doesn't prevent JIT compilation. The JIT compiler works with the full computation graph, which is the same across all processes. Individual processes execute the same compiled code but operate on different parameter shards.

WorldSize

Gets the total number of processes in the distributed group.

public int WorldSize { get; }

Property Value

int

Remarks

For Beginners: This is how many processes are working together to train the model. For example, if you have 4 GPUs, WorldSize would be 4.

WrappedModel

Gets the underlying wrapped model.

public IFullModel<T, TInput, TOutput> WrappedModel { get; }

Property Value

IFullModel<T, TInput, TOutput>

Remarks

For Beginners: This is the original model that we're adding distributed training capabilities to. Think of it as the "core brain" that we're helping to work in a distributed way.

WrappedModelInternal

Protected access to wrapped model for derived classes.

protected IFullModel<T, TInput, TOutput> WrappedModelInternal { get; }

Property Value

IFullModel<T, TInput, TOutput>

Methods

ApplyGradients(Vector<T>, T)

Applies pre-computed gradients to update the model parameters.

public virtual void ApplyGradients(Vector<T> gradients, T learningRate)

Parameters

gradients Vector<T>

The gradient vector to apply.

learningRate T

The learning rate for the update.

Remarks

Updates parameters using: θ = θ - learningRate * gradients

For Beginners: After computing gradients (seeing which direction to move), this method actually moves the model in that direction. The learning rate controls how big of a step to take.

Distributed Training: In DDP/ZeRO-2, this applies the synchronized (averaged) gradients after communication across workers. Each worker applies the same averaged gradients to keep parameters consistent.
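A minimal sketch of the update rule θ = θ - learningRate * gradients follows. It uses plain double[] arrays in place of Vector<T>, and ApplySgdStep is a hypothetical helper for illustration, not part of the AiDotNet API:

```csharp
using System;

// Hypothetical helper (not part of AiDotNet): one plain gradient-descent step,
// theta[i] = theta[i] - learningRate * gradients[i], applied in place.
void ApplySgdStep(double[] theta, double[] gradients, double learningRate)
{
    if (theta.Length != gradients.Length)
        throw new ArgumentException("Gradient length must match parameter length.");

    for (int i = 0; i < theta.Length; i++)
    {
        theta[i] -= learningRate * gradients[i];
    }
}

// Every worker applies the same averaged gradients to the same starting values.
double[] parameters = { 1.0, -2.0, 0.5 };
double[] averagedGradients = { 0.5, -1.0, 0.25 };
ApplySgdStep(parameters, averagedGradients, 0.1);
Console.WriteLine(string.Join(", ", parameters));
```

Because the step is deterministic, workers that start from identical parameters and apply identical averaged gradients end up with identical parameters.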

Clone()

Creates a shallow copy of this object.

public abstract IFullModel<T, TInput, TOutput> Clone()

Returns

IFullModel<T, TInput, TOutput>

ComputeGradients(TInput, TOutput, ILossFunction<T>?)

Computes gradients of the loss function with respect to model parameters for the given data, WITHOUT updating the model parameters.

public virtual Vector<T> ComputeGradients(TInput input, TOutput target, ILossFunction<T>? lossFunction = null)

Parameters

input TInput

The input data.

target TOutput

The target/expected output.

lossFunction ILossFunction<T>

The loss function to use for gradient computation. If null, uses the model's default loss function.

Returns

Vector<T>

A vector containing gradients with respect to all model parameters.

Remarks

This method performs a forward pass, computes the loss, and back-propagates to compute gradients, but does NOT update the model's parameters. The parameters remain unchanged after this call.

Distributed Training: In DDP/ZeRO-2, each worker calls this to compute local gradients on its data batch. These gradients are then synchronized (averaged) across workers before applying updates. This ensures all workers compute the same parameter updates despite having different data.

For Meta-Learning: After adapting a model on a support set, you can use this method to compute gradients on the query set. These gradients become the meta-gradients for updating the meta-parameters.

For Beginners: Think of this as "dry run" training:

  • The model sees what direction it should move (the gradients)
  • But it doesn't actually move (parameters stay the same)
  • You get to decide what to do with this information (average with others, inspect, modify, etc.)
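The sketch below illustrates a gradient "dry run". It assumes a linear model (prediction = w · x) trained with mean squared error. ComputeMseGradients is hypothetical and does not reflect how the library computes gradients for arbitrary models. It only shows that gradients can be returned while the weights stay untouched:

```csharp
using System;

// Hypothetical sketch (not AiDotNet code): gradients of mean squared error
// for a linear model prediction = w · x, computed WITHOUT modifying w.
double[] ComputeMseGradients(double[] w, double[][] inputs, double[] targets)
{
    var gradients = new double[w.Length];
    int n = inputs.Length;

    for (int i = 0; i < n; i++)
    {
        // Forward pass: prediction for example i.
        double prediction = 0.0;
        for (int j = 0; j < w.Length; j++) prediction += w[j] * inputs[i][j];

        // d/dw_j of (1/n) * (prediction - target)^2 is (2/n) * error * x_ij.
        double error = prediction - targets[i];
        for (int j = 0; j < w.Length; j++) gradients[j] += 2.0 * error * inputs[i][j] / n;
    }

    return gradients; // w was only read, never written
}

double[] weights = { 1.0, 0.0 };
double[][] x = { new[] { 1.0, 2.0 }, new[] { 3.0, 1.0 } };
double[] y = { 2.0, 1.0 };
double[] grads = ComputeMseGradients(weights, x, y);
Console.WriteLine(string.Join(", ", grads));
```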

Exceptions

InvalidOperationException

Thrown if lossFunction is null and the model has no default loss function.

DeepCopy()

Creates a deep copy of this object.

public virtual IFullModel<T, TInput, TOutput> DeepCopy()

Returns

IFullModel<T, TInput, TOutput>

Deserialize(byte[])

Loads a previously serialized model from binary data.

public abstract void Deserialize(byte[] data)

Parameters

data byte[]

The byte array containing the serialized model data.

Remarks

This method takes binary data created by the Serialize method and uses it to restore a model to its previous state.

For Beginners: This is like opening a saved file to continue your work.

When you call this method:

  • You provide the binary data (bytes) that was previously created by Serialize
  • The model rebuilds itself using this data
  • After deserializing, the model is exactly as it was when serialized
  • It's ready to make predictions without needing to be trained again

For example:

  • You download a pre-trained model file for detecting spam emails
  • You deserialize this file into your application
  • Immediately, your application can detect spam without any training
  • The model has all the knowledge that was built into it by its original creator

This is particularly useful when:

  • You want to use a model that took days to train
  • You need to deploy the same model across multiple devices
  • You're creating an application that non-technical users will use

Think of it like installing the brain of a trained expert directly into your application.

ExportComputationGraph(List<ComputationNode<T>>)

Exports the computation graph for JIT compilation by delegating to the wrapped model.

public virtual ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)

Parameters

inputNodes List<ComputationNode<T>>

List to populate with input computation nodes.

Returns

ComputationNode<T>

The output computation node representing the model's prediction.

Remarks

Sharded models delegate graph export to their wrapped model. The computation graph represents the full model's forward pass, independent of parameter sharding.

For Beginners: This creates a computation graph from the wrapped model.

Even though parameters are distributed (sharded) across multiple processes:

  • The computation graph structure is the same for all processes
  • Each process compiles the same graph into fast code
  • The only difference is which parameter values each process uses

This allows distributed models to benefit from JIT compilation while maintaining their distributed training capabilities.

Exceptions

ArgumentNullException

Thrown when inputNodes is null.

NotSupportedException

Thrown when the wrapped model does not support JIT compilation.

GatherFullParameters()

Gets the full set of parameters by gathering from all processes.

public virtual Vector<T> GatherFullParameters()

Returns

Vector<T>

The complete set of parameters gathered from all processes

Remarks

This operation involves communication across all processes.

For Beginners: This is like asking everyone to share their puzzle pieces so you can see the complete picture. It requires communication between all processes, so it's more expensive than just accessing LocalParameterShard.
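Conceptually, gathering concatenates every rank's shard in rank order. The sketch below simulates this inside one process. It assumes contiguous shards laid out like the 4/3/3 split described under InitializeSharding. GatherFull is a hypothetical name, and a real implementation would exchange shards through the communication backend rather than read in-memory arrays:

```csharp
using System;
using System.Linq;

// Hypothetical single-process simulation of an all-gather: each rank's shard
// is concatenated in rank order to rebuild the full parameter vector.
double[] GatherFull(double[][] shardsByRank)
{
    return shardsByRank.SelectMany(shard => shard).ToArray();
}

double[][] shards =
{
    new[] { 0.0, 1.0, 2.0, 3.0 }, // rank 0
    new[] { 4.0, 5.0, 6.0 },      // rank 1
    new[] { 7.0, 8.0, 9.0 },      // rank 2
};
double[] full = GatherFull(shards);
Console.WriteLine(full.Length);
```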

GetActiveFeatureIndices()

Gets the indices of features that are actively used by this model.

public virtual IEnumerable<int> GetActiveFeatureIndices()

Returns

IEnumerable<int>

GetFeatureImportance()

Gets the feature importance scores.

public virtual Dictionary<string, T> GetFeatureImportance()

Returns

Dictionary<string, T>

GetModelMetadata()

Retrieves metadata and performance metrics about the trained model.

public abstract ModelMetadata<T> GetModelMetadata()

Returns

ModelMetadata<T>

An object containing metadata and performance metrics about the trained model.

Remarks

This method provides information about the model's structure, parameters, and performance metrics.

For Beginners: Model metadata is like a report card for your machine learning model.

Just as a report card shows how well a student is performing in different subjects, model metadata shows how well your model is performing and provides details about its structure.

This information typically includes:

  • Accuracy measures: How well do the model's predictions match actual values?
  • Error metrics: How far off are the model's predictions on average?
  • Model parameters: What patterns did the model learn from the data?
  • Training information: How long did training take? How many iterations were needed?

For example, in a house price prediction model, metadata might include:

  • Average prediction error (e.g., off by $15,000 on average)
  • How strongly each feature (bedrooms, location) influences the prediction
  • How well the model fits the training data

This information helps you understand your model's strengths and weaknesses, and decide if it's ready to use or needs more training.

GetParameters()

Gets the parameters that can be optimized.

public virtual Vector<T> GetParameters()

Returns

Vector<T>

InitializeSharding()

Initializes parameter sharding by dividing parameters across processes.

protected virtual void InitializeSharding()

Remarks

This method calculates how to distribute parameters evenly across all processes, with remainder parameters distributed to the first few processes. Derived classes can override this to implement different sharding strategies.

For Beginners: This splits the model's parameters across all processes.

Think of it like dividing a deck of cards among players. If you have 10 parameters and 3 processes:

  • Process 0 gets parameters 0-3 (4 parameters)
  • Process 1 gets parameters 4-6 (3 parameters)
  • Process 2 gets parameters 7-9 (3 parameters)

We try to split evenly, but if there's a remainder, the first processes get one extra parameter each.
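The split described above can be computed from the total parameter count, WorldSize, and Rank alone. The following is a minimal sketch that assumes contiguous shards. ComputeShard is a hypothetical helper, and derived classes may shard differently:

```csharp
using System;

// Hypothetical helper: even split with remainder, where the first
// (totalParameters % worldSize) ranks each receive one extra parameter.
(int StartIndex, int Size) ComputeShard(int totalParameters, int worldSize, int rank)
{
    int baseSize = totalParameters / worldSize;
    int remainder = totalParameters % worldSize;

    // Ranks below `remainder` hold baseSize + 1 parameters.
    int size = baseSize + (rank < remainder ? 1 : 0);

    // Each earlier rank contributes baseSize, plus one for every earlier
    // rank that received an extra parameter.
    int start = rank * baseSize + Math.Min(rank, remainder);

    return (start, size);
}

for (int r = 0; r < 3; r++)
{
    var (first, count) = ComputeShard(10, 3, r);
    Console.WriteLine($"Rank {r}: indices {first}-{first + count - 1} ({count} parameters)");
}
```

Running this reproduces the 10-parameter, 3-process example. Rank 0 owns indices 0-3, rank 1 owns 4-6, and rank 2 owns 7-9.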

InvalidateCache()

Invalidates the cached full parameters, forcing a re-gather on next access.

protected void InvalidateCache()

Remarks

This method should be called whenever local parameters change to ensure the cache is refreshed on the next GatherFullParameters call.

For Beginners: When parameters change, we need to throw away the old cached full parameters.

It's like when you update a document - you need to discard the old saved copy so that next time you need it, you get the updated version.
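The following is a sketch of the gather-once, invalidate-on-change pattern. It assumes a nullable cache field similar to CachedFullParameters. The local functions and the counter are hypothetical stand-ins, and the "gather" here simply copies a local array instead of communicating:

```csharp
using System;

int gatherCount = 0;
double[] localShard = { 1.0, 2.0 };
double[]? cachedFull = null;

// Stand-in for the expensive cross-process gather.
double[] ExpensiveGather()
{
    gatherCount++;
    return (double[])localShard.Clone();
}

// Gather lazily and reuse the result until it is invalidated.
double[] GetFullParameters() => cachedFull ??= ExpensiveGather();

// Drop the stale copy so the next access re-gathers.
void InvalidateCache() => cachedFull = null;

GetFullParameters();
GetFullParameters();   // served from the cache, no new gather
localShard[0] = 10.0;  // local parameters changed...
InvalidateCache();     // ...so the cached full vector is discarded
double[] refreshed = GetFullParameters();
Console.WriteLine($"{gatherCount} gathers, first value {refreshed[0]}");
```

The second GetFullParameters call is served from the cache, so only two gathers occur in total.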

IsFeatureUsed(int)

Checks if a specific feature is used by this model.

public virtual bool IsFeatureUsed(int featureIndex)

Parameters

featureIndex int

Returns

bool

LoadModel(string)

Loads the model from a file.

public abstract void LoadModel(string filePath)

Parameters

filePath string

The path to the file containing the saved model.

Remarks

This method provides a convenient way to load a model directly from disk. It combines file I/O operations with deserialization.

For Beginners: This is like clicking "Open" in a document editor. Instead of manually reading from a file and then calling Deserialize(), this method does both steps for you.

Exceptions

FileNotFoundException

Thrown when the specified file does not exist.

IOException

Thrown when an I/O error occurs while reading from the file or when the file contains corrupted or invalid model data.

LoadState(Stream)

Loads the model's state from a stream.

public virtual void LoadState(Stream stream)

Parameters

stream Stream

OnBeforeInitializeSharding()

Called before InitializeSharding to allow derived classes to set up state.

protected virtual void OnBeforeInitializeSharding()

Remarks

Override this method in derived classes to initialize fields that are needed by InitializeSharding but cannot be set before the base constructor call.

Predict(TInput)

Uses the trained model to make predictions for new input data.

public abstract TOutput Predict(TInput input)

Parameters

input TInput

The input data to predict on. For tabular models this is typically a matrix where each row represents a new example and each column represents a feature.

Returns

TOutput

The predictions for the input, for example a vector containing one predicted value per input example.

Remarks

After training, this method applies the learned patterns to new data to predict outcomes.

For Beginners: Prediction is when the model uses what it learned to make educated guesses about new information.

Consider the fruit identification example described under Train:

  • After learning from many examples, the child (model) can now identify new fruits they haven't seen before
  • They look at the color, shape, and size to make their best guess

In machine learning:

  • You give the model new data it hasn't seen during training
  • The model applies the patterns it learned to make predictions
  • The output is the model's best estimate based on its training

For example, in a house price prediction model:

  • You provide features of a new house (square footage, bedrooms, location)
  • The model predicts what price that house might sell for

This method is used after training is complete, when you want to apply your model to real-world data.

SaveModel(string)

Saves the model to a file.

public abstract void SaveModel(string filePath)

Parameters

filePath string

The path where the model should be saved.

Remarks

This method provides a convenient way to save the model directly to disk. It combines serialization with file I/O operations.

For Beginners: This is like clicking "Save As" in a document editor. Instead of manually calling Serialize() and then writing to a file, this method does both steps for you.

Exceptions

IOException

Thrown when an I/O error occurs while writing to the file.

UnauthorizedAccessException

Thrown when the caller does not have the required permission to write to the specified file path.

SaveState(Stream)

Saves the model's current state to a stream.

public virtual void SaveState(Stream stream)

Parameters

stream Stream

Serialize()

Converts the current state of a machine learning model into a binary format.

public abstract byte[] Serialize()

Returns

byte[]

A byte array containing the serialized model data.

Remarks

This method captures all the essential information about a trained model and converts it into a sequence of bytes that can be stored or transmitted.

For Beginners: This is like exporting your work to a file.

When you call this method:

  • The model's current state (all its learned patterns and parameters) is captured
  • This information is converted into a compact binary format (bytes)
  • You can then save these bytes to a file, database, or send them over a network

For example:

  • After training a model to recognize cats vs. dogs in images
  • You can serialize the model to save all its learned knowledge
  • Later, you can use this saved data to recreate the model exactly as it was
  • The recreated model will make the same predictions as the original

Think of it like taking a snapshot of your model's brain at a specific moment in time.
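The sketch below round-trips a parameter vector through a byte array with BinaryWriter and BinaryReader. The layout (an int count followed by doubles) is a hypothetical format chosen for illustration. It is not the format Serialize actually produces, which may include additional model and sharding information:

```csharp
using System;
using System.IO;
using System.Linq;

// Hypothetical layout: [int count][double p0][double p1]...
byte[] SerializeParameters(double[] parameters)
{
    using var stream = new MemoryStream();
    using var writer = new BinaryWriter(stream);
    writer.Write(parameters.Length);
    foreach (double p in parameters) writer.Write(p);
    writer.Flush();
    return stream.ToArray();
}

// Reads the same layout back into a new array.
double[] DeserializeParameters(byte[] data)
{
    using var reader = new BinaryReader(new MemoryStream(data));
    int count = reader.ReadInt32();
    var parameters = new double[count];
    for (int i = 0; i < count; i++) parameters[i] = reader.ReadDouble();
    return parameters;
}

double[] original = { 0.25, -1.5, 3.0 };
byte[] bytes = SerializeParameters(original);
double[] restored = DeserializeParameters(bytes);
Console.WriteLine($"{bytes.Length} bytes, round trip equal: {original.SequenceEqual(restored)}");
```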

SetActiveFeatureIndices(IEnumerable<int>)

Sets the active feature indices for this model.

public virtual void SetActiveFeatureIndices(IEnumerable<int> featureIndices)

Parameters

featureIndices IEnumerable<int>

SetParameters(Vector<T>)

Sets the model parameters.

public virtual void SetParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

The parameter vector to set.

Remarks

This method allows direct modification of the model's internal parameters. This is useful for optimization algorithms that need to update parameters iteratively. If the length of parameters does not match ParameterCount, an ArgumentException should be thrown.

Exceptions

ArgumentException

Thrown when the length of parameters does not match ParameterCount.

SynchronizeGradients()

Synchronizes gradients across all processes using AllReduce.

public virtual void SynchronizeGradients()

Remarks

After this operation, all processes have the same averaged gradients.

For Beginners: During training, each process calculates gradients based on its portion of the data. This method combines (averages) those gradients so that everyone is learning from everyone else's experiences. It's like a team meeting where everyone shares what they learned.
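The averaging can be simulated in a single process to show what every rank ends up holding. The following is a minimal sketch that assumes an element-wise mean. AllReduceAverage is hypothetical, and a real AllReduce exchanges data over the communication backend rather than looping over in-memory arrays:

```csharp
using System;

// Hypothetical single-process simulation of an AllReduce-average:
// sum every rank's gradients element-wise, divide by the number of ranks,
// then write the result back to every rank.
void AllReduceAverage(double[][] gradientsByRank)
{
    int worldSize = gradientsByRank.Length;
    int length = gradientsByRank[0].Length;
    var average = new double[length];

    foreach (double[] g in gradientsByRank)
    {
        for (int i = 0; i < length; i++) average[i] += g[i];
    }

    for (int i = 0; i < length; i++) average[i] /= worldSize;

    // After the collective, every rank holds an identical copy of the average.
    foreach (double[] g in gradientsByRank) Array.Copy(average, g, length);
}

double[][] localGradients =
{
    new[] { 1.0, 4.0 }, // rank 0, computed on its own mini-batch
    new[] { 3.0, 0.0 }, // rank 1, computed on a different mini-batch
};
AllReduceAverage(localGradients);
Console.WriteLine(string.Join(", ", localGradients[1]));
```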

Train(TInput, TOutput)

Trains the model using input features and their corresponding target values.

public abstract void Train(TInput input, TOutput expectedOutput)

Parameters

input TInput

The input features to learn from.

expectedOutput TOutput

The target values the model should learn to predict.

Remarks

This method takes training data and adjusts the model's internal parameters to learn patterns in the data.

For Beginners: Training is like teaching the model by showing it examples.

Imagine teaching a child to identify fruits:

  • You show them many examples of apples, oranges, and bananas (the input features)
  • You tell them the correct name for each fruit (the expected outputs)
  • Over time, they learn to recognize the patterns that distinguish each fruit

In machine learning:

  • The input parameter contains the features (characteristics) of your data
  • The expectedOutput parameter contains the correct answers you want the model to learn
  • During training, the model adjusts its internal calculations to get better at predicting expectedOutput from input

For example, in a house price prediction model:

  • input would contain features like square footage, number of bedrooms, and location
  • expectedOutput would contain the actual sale prices of those houses

UpdateLocalShardFromFull(Vector<T>)

Updates the local parameter shard from the full parameter vector.

protected void UpdateLocalShardFromFull(Vector<T> fullParameters)

Parameters

fullParameters Vector<T>

The full parameter vector

Remarks

This method extracts this process's shard from a full parameter vector. Used after training updates or when setting parameters.

For Beginners: After the full model is updated, we need to extract our piece of it.

It's like taking your slice of a pizza after it's been prepared - you get the portion that belongs to you from the whole.
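The following is a minimal sketch of the extraction. It assumes a contiguous shard described by a start index and size, as ShardStartIndex and ShardSize suggest. ExtractLocalShard is a hypothetical helper operating on double[] rather than Vector<T>:

```csharp
using System;

// Hypothetical helper: copy this rank's contiguous slice out of the full vector.
double[] ExtractLocalShard(double[] fullParameters, int shardStartIndex, int shardSize)
{
    if (shardStartIndex < 0 || shardSize < 0 || shardStartIndex + shardSize > fullParameters.Length)
        throw new ArgumentOutOfRangeException(nameof(shardStartIndex), "Shard lies outside the full vector.");

    var shard = new double[shardSize];
    Array.Copy(fullParameters, shardStartIndex, shard, 0, shardSize);
    return shard;
}

double[] updatedFull = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
// Under the 4/3/3 split of 10 parameters, rank 1 owns indices 4-6.
double[] myShard = ExtractLocalShard(updatedFull, 4, 3);
Console.WriteLine(string.Join(", ", myShard));
```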

WithParameters(Vector<T>)

Creates a new instance with the specified parameters.

public abstract IFullModel<T, TInput, TOutput> WithParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

Returns

IFullModel<T, TInput, TOutput>