Class GraphDataLoaderBase<T>
Abstract base class for graph data loaders providing common graph-related functionality.
public abstract class GraphDataLoaderBase<T> : DataLoaderBase<T>, IGraphDataLoader<T>, IDataLoader<T>, IResettable, ICountable, IBatchIterable<GraphData<T>>
Type Parameters
T: The numeric type used for calculations, typically float or double.
Inheritance
- GraphDataLoaderBase<T>
Implements
- IDataLoader<T>
Remarks
GraphDataLoaderBase provides shared implementation for all graph data loaders, including:
- Node feature and adjacency matrix management
- Task creation (node classification, graph classification, link prediction)
- Train/validation/test mask generation
- Batch iteration for multiple graphs
For Beginners: This base class handles common graph operations:
- Storing node features and edge connections
- Creating different types of tasks (node classification, link prediction)
- Splitting data for training and evaluation
Concrete implementations (CitationNetworkLoader, MolecularDatasetLoader) extend this to load specific graph datasets.
Fields
LoadedGraphData
Storage for loaded graph data.
protected GraphData<T>? LoadedGraphData
Field Value
- GraphData<T>
LoadedGraphs
Storage for multiple graphs (for graph classification datasets).
protected List<GraphData<T>>? LoadedGraphs
Field Value
- List<GraphData<T>>
NumOps
Numeric operations helper for type T.
protected static readonly INumericOperations<T> NumOps
Field Value
- INumericOperations<T>
Properties
AdjacencyMatrix
Gets the adjacency matrix of shape [numNodes, numNodes].
public Tensor<T> AdjacencyMatrix { get; }
Property Value
- Tensor<T>
BatchSize
Gets or sets the batch size for iteration.
public override int BatchSize { get; set; }
Property Value
- int
EdgeIndex
Gets the edge index tensor in COO format [numEdges, 2].
public Tensor<T> EdgeIndex { get; }
Property Value
- Tensor<T>
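The relationship between the COO edge index and the dense adjacency matrix can be sketched with plain arrays. This is only an illustration, not the library's implementation: the helper name ToAdjacency is hypothetical, plain arrays stand in for Tensor<T>, and treating each row as a directed (source, target) pair is an assumption.

```csharp
using System;

// Hypothetical helper (not part of the documented API): builds a dense
// [numNodes, numNodes] adjacency matrix from a COO edge list of shape
// [numEdges, 2], assuming each row holds (source, target).
static double[,] ToAdjacency(int[,] edgeIndex, int numNodes)
{
    var adjacency = new double[numNodes, numNodes];
    for (int e = 0; e < edgeIndex.GetLength(0); e++)
    {
        int source = edgeIndex[e, 0];
        int target = edgeIndex[e, 1];
        adjacency[source, target] = 1.0; // mark the directed edge source -> target
    }
    return adjacency;
}

// Three nodes with edges 0->1, 1->2, and 2->0.
int[,] edges = { { 0, 1 }, { 1, 2 }, { 2, 0 } };
double[,] adj = ToAdjacency(edges, numNodes: 3);
Console.WriteLine(adj[0, 1]); // 1
Console.WriteLine(adj[1, 0]); // 0
```

The COO form stores one row per edge, so it stays compact for sparse graphs, while the dense form grows with the square of the node count.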
GraphLabels
Gets graph labels for graph classification tasks, or null if not available.
public Tensor<T>? GraphLabels { get; protected set; }
Property Value
- Tensor<T>
HasNext
Gets whether there are more batches available in the current iteration.
public bool HasNext { get; }
Property Value
- bool
NodeFeatures
Gets the node feature tensor of shape [numNodes, numFeatures].
public Tensor<T> NodeFeatures { get; }
Property Value
- Tensor<T>
NodeLabels
Gets node labels for node classification tasks, or null if not available.
public Tensor<T>? NodeLabels { get; }
Property Value
- Tensor<T>
NumClasses
Gets the number of classes for classification tasks.
public abstract int NumClasses { get; }
Property Value
- int
NumEdges
Gets the number of edges in the graph (or total across all graphs).
public int NumEdges { get; }
Property Value
- int
NumGraphs
Gets the number of graphs in the dataset (1 for single-graph datasets like citation networks).
public virtual int NumGraphs { get; }
Property Value
- int
NumNodeFeatures
Gets the number of node features.
public int NumNodeFeatures { get; }
Property Value
- int
NumNodes
Gets the number of nodes in the graph (or total across all graphs).
public int NumNodes { get; }
Property Value
- int
TotalCount
Gets the total number of samples in the dataset.
public override int TotalCount { get; }
Property Value
- int
Methods
CreateGraphClassificationTask(double, double, int?)
Creates a graph classification task for datasets with multiple graphs.
public virtual GraphClassificationTask<T> CreateGraphClassificationTask(double trainRatio = 0.8, double valRatio = 0.1, int? seed = null)
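Parameters
trainRatio (double): Default is 0.8.
valRatio (double): Default is 0.1.
seed (int?): Default is null.
Returns
- GraphClassificationTask<T>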
CreateLinkPredictionTask(double, double, int?)
Creates a link prediction task for predicting missing edges.
public virtual LinkPredictionTask<T> CreateLinkPredictionTask(double trainRatio = 0.85, double negativeRatio = 1, int? seed = null)
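Parameters
trainRatio (double): Default is 0.85.
negativeRatio (double): Default is 1.
seed (int?): Default is null.
Returns
- LinkPredictionTask<T>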
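Link prediction tasks typically pair the existing (positive) edges with sampled non-edges (negatives). The sketch below shows one common way a negativeRatio could be interpreted, as the number of negatives per positive edge. That interpretation, the helper name SampleNegativeEdges, and the use of plain tuples instead of tensors are all assumptions made for illustration, not the documented behavior.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper: samples node pairs that are NOT existing edges.
// Assumes directed edges, no self-loops in existingEdges, and that
// negativeRatio means "negatives per positive edge".
static List<(int Source, int Target)> SampleNegativeEdges(
    HashSet<(int, int)> existingEdges, int numNodes, double negativeRatio = 1.0, int? seed = null)
{
    var rng = seed.HasValue ? new Random(seed.Value) : new Random();
    // Cap the request at the number of non-edges so the loop always terminates.
    int available = numNodes * (numNodes - 1) - existingEdges.Count;
    int wanted = Math.Min((int)(existingEdges.Count * negativeRatio), available);

    var negatives = new HashSet<(int, int)>();
    while (negatives.Count < wanted)
    {
        int source = rng.Next(numNodes);
        int target = rng.Next(numNodes);
        if (source == target) continue;                         // skip self-loops
        if (existingEdges.Contains((source, target))) continue; // skip real edges
        negatives.Add((source, target));                        // the set ignores duplicates
    }
    return new List<(int Source, int Target)>(negatives);
}

var positives = new HashSet<(int, int)> { (0, 1), (1, 2), (2, 3) };
var sampled = SampleNegativeEdges(positives, numNodes: 5, negativeRatio: 1.0, seed: 7);
Console.WriteLine(sampled.Count); // 3
```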
CreateNodeClassificationTask(double, double, int?)
Creates a node classification task with train/val/test split.
public virtual NodeClassificationTask<T> CreateNodeClassificationTask(double trainRatio = 0.1, double valRatio = 0.1, int? seed = null)
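Parameters
trainRatio (double): Default is 0.1.
valRatio (double): Default is 0.1.
seed (int?): Default is null.
Returns
- NodeClassificationTask<T>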
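A train/val/test split over nodes is usually expressed as three boolean masks. The following is a minimal sketch of that idea, assuming the ratios are fractions of the node count and that the remaining nodes form the test set. The helper name SplitMasks is hypothetical, and the library's actual splitting strategy may differ.

```csharp
using System;
using System.Linq;

// Hypothetical helper: shuffles the node indices, then cuts the permutation
// at the requested ratios. Nodes left over after train and val go to test.
static (bool[] Train, bool[] Val, bool[] Test) SplitMasks(
    int numNodes, double trainRatio = 0.1, double valRatio = 0.1, int? seed = null)
{
    var rng = seed.HasValue ? new Random(seed.Value) : new Random();
    int[] order = Enumerable.Range(0, numNodes).ToArray();
    for (int i = numNodes - 1; i > 0; i--) // Fisher-Yates shuffle
    {
        int j = rng.Next(i + 1);
        (order[i], order[j]) = (order[j], order[i]);
    }

    int trainCount = (int)(numNodes * trainRatio);
    int valCount = (int)(numNodes * valRatio);
    var train = new bool[numNodes];
    var val = new bool[numNodes];
    var test = new bool[numNodes];
    for (int k = 0; k < numNodes; k++)
    {
        int node = order[k];
        if (k < trainCount) train[node] = true;
        else if (k < trainCount + valCount) val[node] = true;
        else test[node] = true;
    }
    return (train, val, test);
}

var masks = SplitMasks(numNodes: 100, seed: 42);
int trainTotal = masks.Train.Count(m => m);
int valTotal = masks.Val.Count(m => m);
int testTotal = masks.Test.Count(m => m);
Console.WriteLine($"{trainTotal} / {valTotal} / {testTotal}"); // 10 / 10 / 80
```

Passing the same seed reproduces the same masks, which keeps evaluation comparable across runs.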
GetBatches(int?, bool, bool, int?)
Iterates through all batches in the dataset using lazy evaluation.
public virtual IEnumerable<GraphData<T>> GetBatches(int? batchSize = null, bool shuffle = true, bool dropLast = false, int? seed = null)
Parameters
batchSize (int?): Optional batch size override. Uses default BatchSize if null.
shuffle (bool): Whether to shuffle data before batching. Default is true.
dropLast (bool): Whether to drop the last incomplete batch. Default is false.
seed (int?): Optional random seed for reproducible shuffling.
Returns
- IEnumerable<GraphData<T>>
An enumerable sequence of batches using yield return for memory efficiency.
Remarks
This method provides a PyTorch-style iteration pattern using IEnumerable and yield return for memory-efficient lazy evaluation. Each call creates a fresh iteration, automatically handling reset and shuffle operations.
For Beginners: This is the recommended way to iterate through your data:
foreach (var batch in dataLoader.GetBatches(batchSize: 32, shuffle: true))
{
    // Train on this batch (model.TrainOnBatch is a placeholder for your own training step)
    model.TrainOnBatch(batch);
}
Unlike GetNextBatch(), you don't need to call Reset() - each GetBatches() call starts fresh. The yield return pattern means batches are generated on-demand, not all loaded into memory at once.
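The lazy, restartable behavior described above can be sketched with a small iterator. This is an illustration under stated assumptions, not the library's code: the helper name BatchIndices is hypothetical, and it yields arrays of sample indices rather than GraphData<T> batches.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical helper: yields index batches one at a time. Nothing runs until
// the caller enumerates, and every call starts a fresh pass over the data.
static IEnumerable<int[]> BatchIndices(
    int totalCount, int batchSize, bool shuffle = true, bool dropLast = false, int? seed = null)
{
    int[] order = Enumerable.Range(0, totalCount).ToArray();
    if (shuffle)
    {
        var rng = seed.HasValue ? new Random(seed.Value) : new Random();
        for (int i = totalCount - 1; i > 0; i--) // Fisher-Yates shuffle
        {
            int j = rng.Next(i + 1);
            (order[i], order[j]) = (order[j], order[i]);
        }
    }

    for (int start = 0; start < totalCount; start += batchSize)
    {
        int size = Math.Min(batchSize, totalCount - start);
        if (dropLast && size < batchSize) yield break; // skip the incomplete tail
        yield return order.Skip(start).Take(size).ToArray();
    }
}

int[] sizes = BatchIndices(10, 4, shuffle: false).Select(b => b.Length).ToArray();
int[] sizesDropped = BatchIndices(10, 4, shuffle: false, dropLast: true).Select(b => b.Length).ToArray();
Console.WriteLine(string.Join(", ", sizes));        // 4, 4, 2
Console.WriteLine(string.Join(", ", sizesDropped)); // 4, 4
```

Because the method body only runs during enumeration, a caller that breaks out of the loop early never pays for the batches it did not request.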
GetBatchesAsync(int?, bool, bool, int?, int, CancellationToken)
Asynchronously iterates through all batches with prefetching support.
public virtual IAsyncEnumerable<GraphData<T>> GetBatchesAsync(int? batchSize = null, bool shuffle = true, bool dropLast = false, int? seed = null, int prefetchCount = 2, CancellationToken cancellationToken = default)
Parameters
batchSize (int?): Optional batch size override. Uses default BatchSize if null.
shuffle (bool): Whether to shuffle data before batching. Default is true.
dropLast (bool): Whether to drop the last incomplete batch. Default is false.
seed (int?): Optional random seed for reproducible shuffling.
prefetchCount (int): Number of batches to prefetch ahead. Default is 2.
cancellationToken (CancellationToken): Token to cancel the iteration.
Returns
- IAsyncEnumerable<GraphData<T>>
An async enumerable sequence of batches.
Remarks
This method enables async batch iteration with configurable prefetching, similar to PyTorch's num_workers or TensorFlow's prefetch(). Batches are prepared in the background while the current batch is being processed.
For Beginners: Use this for large datasets or when batch preparation is slow:
await foreach (var batch in dataLoader.GetBatchesAsync(prefetchCount: 2))
{
    // While training on this batch, up to 2 further batches are being prepared.
    // model.TrainOnBatchAsync is a placeholder for your own training step.
    await model.TrainOnBatchAsync(batch);
}
Prefetching helps hide data loading latency, especially useful for:
- Large images that need decoding
- Data that requires preprocessing
- Slow storage (network drives, cloud storage)
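One way to picture prefetching is a bounded producer/consumer queue: a background task prepares batches while the consumer works, and the queue capacity limits how far ahead it runs. The sketch below uses System.Threading.Channels for this. It is an assumption about how such a feature could be built, not a description of the library's internals, and the helper name Prefetch is hypothetical.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;

// Hypothetical helper: wraps a synchronous batch sequence so that up to
// prefetchCount batches are buffered ahead of the consumer.
static async IAsyncEnumerable<TBatch> Prefetch<TBatch>(
    IEnumerable<TBatch> source,
    int prefetchCount = 2,
    [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    using var linked = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
    CancellationToken token = linked.Token;
    var channel = Channel.CreateBounded<TBatch>(prefetchCount);

    // Producer: enumerates the source on a thread-pool thread and waits
    // whenever the bounded channel is full.
    _ = Task.Run(async () =>
    {
        try
        {
            foreach (var item in source)
                await channel.Writer.WriteAsync(item, token);
            channel.Writer.TryComplete();
        }
        catch (Exception ex)
        {
            channel.Writer.TryComplete(ex); // surface failures to the reader
        }
    });

    try
    {
        await foreach (var ready in channel.Reader.ReadAllAsync(token))
            yield return ready;
    }
    finally
    {
        linked.Cancel(); // stop the producer if the consumer exits early
    }
}

var received = new List<int>();
await foreach (int value in Prefetch(Enumerable.Range(0, 5), prefetchCount: 2))
{
    received.Add(value);
}
Console.WriteLine(string.Join(", ", received)); // 0, 1, 2, 3, 4
```

A bounded channel gives back-pressure for free: a slow consumer caps memory use at prefetchCount buffered batches instead of letting the producer run through the whole dataset.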
GetNextBatch()
Gets the next batch of data.
public GraphData<T> GetNextBatch()
Returns
- GraphData<T>
The next batch of data.
Exceptions
- InvalidOperationException
Thrown when no more batches are available.
OnReset()
Called after Reset() to allow derived classes to perform additional reset operations.
protected override void OnReset()
Remarks
Override this to reset any domain-specific state. The base indices are already reset when this is called.
TryGetNextBatch(out GraphData<T>)
Attempts to get the next batch without throwing if unavailable.
public bool TryGetNextBatch(out GraphData<T> batch)
Parameters
batch (GraphData<T>): The batch if available, default otherwise.
Returns
- bool
True if a batch was available, false if iteration is complete.
Remarks
When false is returned, batch contains the default value for GraphData<T>. Callers should check the return value before using batch.
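The Try pattern lends itself to a simple manual loop. The sketch below uses a stand-in so that it runs without the library: a local function with the same shape as TryGetNextBatch, backed by a queue of int arrays instead of GraphData<T> batches. Both the stand-in and its queue are assumptions made for illustration.

```csharp
using System;
using System.Collections.Generic;

// Stand-in with the same shape as TryGetNextBatch: returns false and leaves
// the out parameter at its default once the queue is exhausted.
static bool TryGetNextBatch(Queue<int[]> queue, out int[] next)
{
    if (queue.Count > 0)
    {
        next = queue.Dequeue();
        return true;
    }
    next = default!; // callers must not use this value when false is returned
    return false;
}

var pending = new Queue<int[]>(new[] { new[] { 0, 1 }, new[] { 2, 3 }, new[] { 4 } });
int processed = 0;
while (TryGetNextBatch(pending, out int[] current))
{
    processed += current.Length; // only touch the batch inside the loop body
}
Console.WriteLine(processed); // 5
```

With a real loader, the same loop shape avoids the InvalidOperationException that GetNextBatch() throws once iteration is complete.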