Class ParallelBatchLoader<TBatch>
Provides parallel batch loading with multiple workers for improved throughput.
public class ParallelBatchLoader<TBatch> : IDisposable
Type Parameters
TBatch — The type of batch produced.
Inheritance
object → ParallelBatchLoader<TBatch>
Implements
IDisposable
Remarks
ParallelBatchLoader uses multiple worker threads to prepare batches in parallel, similar to PyTorch's DataLoader with num_workers > 0. This can significantly improve training throughput when batch preparation is CPU-bound.
For Beginners: Training neural networks often involves:
1. Loading data from disk
2. Preprocessing (augmentation, normalization)
3. GPU training
With single-threaded loading, the GPU waits while data is prepared. With parallel loading, multiple workers prepare batches simultaneously, keeping the GPU constantly fed with data.
Example:
var parallelLoader = new ParallelBatchLoader<(Matrix<float>, Vector<float>)>(
    indexProvider: () => Enumerable.Range(0, dataset.Count),
    batchFactory: indices => CreateBatch(indices), // your function that builds one batch from sample indices
    batchSize: 32,
    numWorkers: 4,
    prefetchCount: 2
);
await foreach (var batch in parallelLoader.GetBatchesAsync())
{
    await model.TrainOnBatchAsync(batch);
}
Constructors
ParallelBatchLoader(Func<IEnumerable<int>>, Func<int[], TBatch>, int, int?, int?)
Initializes a new instance of the ParallelBatchLoader class.
public ParallelBatchLoader(Func<IEnumerable<int>> indexProvider, Func<int[], TBatch> batchFactory, int batchSize, int? numWorkers = null, int? prefetchCount = null)
Parameters
indexProvider (Func<IEnumerable<int>>): Function that provides the sample indices for each epoch.
batchFactory (Func<int[], TBatch>): Function that creates a batch from an array of sample indices. The factory receives all indices for a single batch and should aggregate them into a single batch (e.g., by stacking samples into a matrix).
batchSize (int): Number of samples per batch.
numWorkers (int?): Number of parallel workers. Defaults to the processor count.
prefetchCount (int?): Number of batches to prefetch. Defaults to 2 * numWorkers.
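As a sketch of the documented defaults (an assumption about the internals, not the confirmed implementation), the optional parameters can be resolved like this, with Environment.ProcessorCount supplying the worker count:

```csharp
// Sketch only: how the documented defaults could be resolved.
// Assumption: the real constructor does something equivalent.
static int ResolveWorkers(int? numWorkers) =>
    numWorkers ?? Environment.ProcessorCount; // default: processor count

static int ResolvePrefetch(int? prefetchCount, int workers) =>
    prefetchCount ?? 2 * workers;             // default: 2 * numWorkers
```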
Remarks
The batchFactory function is responsible for:
- Loading samples at the given indices from the dataset
- Aggregating them into a single batch structure
- Applying any preprocessing or augmentation
Example batchFactory for Matrix/Vector data:
batchFactory: indices =>
{
    // Allocate one batch: one row per sample, one column per feature.
    var xBatch = new Matrix<float>(indices.Length, numFeatures);
    var yBatch = new Vector<float>(indices.Length);
    for (int i = 0; i < indices.Length; i++)
    {
        int idx = indices[i];
        // Copy the sample's features into row i of the batch.
        for (int j = 0; j < numFeatures; j++)
            xBatch[i, j] = dataset.X[idx, j];
        yBatch[i] = dataset.Y[idx];
    }
    return (xBatch, yBatch);
}
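Since indexProvider supplies the indices for each epoch, it is also a natural place to reshuffle between epochs. A hedged sketch, assuming the provider is invoked once per epoch and using placeholder names (dataset, CreateBatch) that are not part of this API:

```csharp
var rng = new Random(42);
var loader = new ParallelBatchLoader<(Matrix<float>, Vector<float>)>(
    // Assumption: called at the start of each epoch, so returning a
    // freshly shuffled order gives per-epoch shuffling.
    indexProvider: () => Enumerable.Range(0, dataset.Count)
                                   .OrderBy(_ => rng.Next()),
    batchFactory: indices => CreateBatch(indices), // e.g., the factory shown above
    batchSize: 64);
// numWorkers and prefetchCount fall back to their documented defaults.
```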
Properties
NumWorkers
Gets the number of parallel workers.
public int NumWorkers { get; }
Property Value
int
PrefetchCount
Gets the prefetch count.
public int PrefetchCount { get; }
Property Value
int
Methods
Dispose()
Disposes the parallel batch loader.
public void Dispose()
GetBatchesAsync(CancellationToken)
Iterates through batches using parallel workers.
public IAsyncEnumerable<TBatch> GetBatchesAsync(CancellationToken cancellationToken = default)
Parameters
cancellationToken (CancellationToken): Cancellation token.
Returns
IAsyncEnumerable<TBatch> — Async enumerable of batches.
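A sketch of cancelling the stream mid-training with a CancellationTokenSource, reusing the loader and model from the earlier example. Whether cancellation surfaces as an OperationCanceledException from the enumeration is an assumption about this loader, not documented above:

```csharp
// Stop consuming batches after a fixed time budget.
using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(30));
try
{
    await foreach (var batch in parallelLoader.GetBatchesAsync(cts.Token))
    {
        await model.TrainOnBatchAsync(batch);
    }
}
catch (OperationCanceledException)
{
    // Assumption: the token firing ends the enumeration with this
    // exception and the workers shut down cooperatively.
}
```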