Class ModelShard<T>
Enables model sharding across multiple devices for large model inference.
public class ModelShard<T>
Type Parameters
T: The numeric type used for calculations (for example, float or double).
Remarks
Model sharding (also known as model parallelism or pipeline parallelism) splits a large model across multiple devices (GPUs/CPUs) when it's too large to fit on a single device.
For Beginners: Large AI models such as Stable Diffusion XL may not fit in a single GPU's memory.
Model sharding solves this by:
- Splitting the model layers across devices
- Running each device's layers in sequence
- Passing intermediate results between devices
Example: UNet with 24 layers on 4 GPUs:
- GPU 0: Layers 1-6
- GPU 1: Layers 7-12
- GPU 2: Layers 13-18
- GPU 3: Layers 19-24
Data flows: Input -> GPU0 -> GPU1 -> GPU2 -> GPU3 -> Output
Usage:
var shard = new ModelShard<float>(layers, numDevices: 4);
var output = shard.Forward(input);
Constructors
ModelShard(IEnumerable<ILayer<T>>, int, ShardingConfig?)
Initializes model sharding across specified number of devices.
public ModelShard(IEnumerable<ILayer<T>> layers, int numDevices, ShardingConfig? config = null)
Parameters
layers (IEnumerable<ILayer<T>>): The layers to shard.
numDevices (int): The number of devices to use.
config (ShardingConfig?): Optional sharding configuration.
Remarks
Layers are distributed evenly by default. Use ShardingConfig for custom distribution based on memory requirements or compute costs.
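For illustration, a sketch of constructing a shard with a custom distribution. The ShardingConfig member shown here (DeviceLayerCounts) is an assumption, not confirmed API; consult the ShardingConfig documentation for the actual options.

```csharp
// Hypothetical sketch: give more layers to devices with more memory.
// DeviceLayerCounts is an assumed property name, not confirmed API.
var config = new ShardingConfig
{
    DeviceLayerCounts = new[] { 8, 8, 4, 4 } // 24 layers, front-loaded onto larger GPUs
};
var shard = new ModelShard<float>(layers, numDevices: 4, config);
```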
Properties
Config
Sharding configuration.
public ShardingConfig Config { get; }
Property Value
- ShardingConfig
Methods
Backward(Tensor<T>)
Performs backward pass through all sharded layers.
public Tensor<T> Backward(Tensor<T> outputGradient)
Parameters
outputGradient (Tensor<T>): Gradient from the subsequent layer.
Returns
- Tensor<T>
Gradient with respect to input.
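A minimal sketch of a forward/backward round trip. The loss-gradient helper is a placeholder, not part of this API, and tensor shapes are assumed compatible:

```csharp
// Sketch only: ComputeLossGradient stands in for your loss function's
// gradient and is a hypothetical helper, not part of ModelShard<T>.
var output = shard.Forward(input);
Tensor<float> lossGrad = ComputeLossGradient(output, target); // hypothetical
Tensor<float> inputGrad = shard.Backward(lossGrad);
// inputGrad holds the gradient with respect to the input, propagated
// back through every device's layers in reverse order.
```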
Forward(Tensor<T>)
Performs forward pass through all sharded layers.
public Tensor<T> Forward(Tensor<T> input)
Parameters
input (Tensor<T>): Input tensor.
Returns
- Tensor<T>
Output tensor after all layers.
Remarks
Data flows sequentially through devices. Each device processes its assigned layers before passing results to the next device.
Forward(Tensor<T>, Tensor<T>?)
Performs forward pass with context (for conditional generation).
public Tensor<T> Forward(Tensor<T> input, Tensor<T>? context)
Parameters
input (Tensor<T>): Input tensor.
context (Tensor<T>?): Context tensor (e.g., timestep embedding or conditioning).
Returns
- Tensor<T>
Output tensor.
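A hedged sketch of a conditioned forward pass, as used in diffusion-style models. The way the conditioning tensor is built (TimestepEmbedding) is hypothetical:

```csharp
// Sketch: pass a conditioning tensor alongside the input.
// TimestepEmbedding is a placeholder helper, not part of this API;
// noisyLatents is assumed to be a Tensor<float> prepared by the caller.
Tensor<float> context = TimestepEmbedding(step: 500); // hypothetical
var denoised = shard.Forward(noisyLatents, context);
```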
GetDeviceLayers(int)
Gets layers assigned to a specific device.
public IReadOnlyList<ILayer<T>> GetDeviceLayers(int device)
Parameters
device (int): The device index.
Returns
- IReadOnlyList<ILayer<T>>
GetDeviceMemoryUsage()
Gets memory usage per device.
public IReadOnlyDictionary<int, long> GetDeviceMemoryUsage()
Returns
- IReadOnlyDictionary<int, long>
Memory usage per device, keyed by device index.
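A short sketch of inspecting the shard layout with GetDeviceMemoryUsage and GetDeviceLayers; interpreting the long values as bytes is an assumption:

```csharp
// Sketch: print how layers and memory are distributed across devices.
// Assumes the long values are byte counts.
foreach (var (device, bytes) in shard.GetDeviceMemoryUsage())
{
    int layerCount = shard.GetDeviceLayers(device).Count;
    Console.WriteLine($"Device {device}: {layerCount} layers, ~{bytes / (1024.0 * 1024.0):F1} MB");
}
```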
GetLayerDevice(ILayer<T>)
Gets the device assignment for a layer.
public int GetLayerDevice(ILayer<T> layer)
Parameters
layer (ILayer<T>): The layer to look up.
Returns
- int
The index of the device the layer is assigned to.
ToString()
Gets a summary of the sharding distribution.
public override string ToString()
Returns
- string
A human-readable summary of the sharding distribution.
UpdateParameters(T)
Updates parameters on all devices.
public void UpdateParameters(T learningRate)
Parameters
learningRate (T): Learning rate for the update.
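Putting the pieces together, a hedged sketch of one training step; the loss-gradient helper is hypothetical, and the exact update rule applied per device is not specified by this API:

```csharp
// Sketch of a single training step across all devices.
// ComputeLossGradient is a placeholder for your loss function's gradient.
var output = shard.Forward(input);
var lossGrad = ComputeLossGradient(output, target); // hypothetical
shard.Backward(lossGrad);                    // propagate gradients through every device
shard.UpdateParameters(learningRate: 1e-4f); // apply the update on each device
```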