Class GraphDataLoaderBase<T>

Namespace
AiDotNet.Data.Loaders
Assembly
AiDotNet.dll

Abstract base class for graph data loaders providing common graph-related functionality.

public abstract class GraphDataLoaderBase<T> : DataLoaderBase<T>, IGraphDataLoader<T>, IDataLoader<T>, IResettable, ICountable, IBatchIterable<GraphData<T>>

Type Parameters

T

The numeric type used for calculations, typically float or double.

Inheritance
GraphDataLoaderBase<T>

Remarks

GraphDataLoaderBase provides the shared implementation for all graph data loaders, including:

  • Node feature and adjacency matrix management
  • Task creation (node classification, graph classification, link prediction)
  • Train/validation/test mask generation
  • Batch iteration for multiple graphs

For Beginners: This base class handles common graph operations:

  • Storing node features and edge connections
  • Creating different types of tasks (node classification, link prediction)
  • Splitting data for training and evaluation

Concrete implementations (CitationNetworkLoader, MolecularDatasetLoader) extend this to load specific graph datasets.

Fields

LoadedGraphData

Storage for loaded graph data.

protected GraphData<T>? LoadedGraphData

Field Value

GraphData<T>

LoadedGraphs

Storage for multiple graphs (for graph classification datasets).

protected List<GraphData<T>>? LoadedGraphs

Field Value

List<GraphData<T>>

NumOps

Numeric operations helper for type T.

protected static readonly INumericOperations<T> NumOps

Field Value

INumericOperations<T>

Properties

AdjacencyMatrix

Gets the adjacency matrix of shape [numNodes, numNodes].

public Tensor<T> AdjacencyMatrix { get; }

Property Value

Tensor<T>

BatchSize

Gets or sets the batch size for iteration.

public override int BatchSize { get; set; }

Property Value

int

EdgeIndex

Gets the edge index tensor in COO format [numEdges, 2].

public Tensor<T> EdgeIndex { get; }

Property Value

Tensor<T>
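
As a rough illustration of how the COO layout relates to AdjacencyMatrix, the sketch below converts a dense adjacency matrix into an edge list of shape [numEdges, 2]. It uses plain arrays rather than Tensor<T>, and the helper name ToEdgeIndex is hypothetical; it is not part of AiDotNet and may differ from how the library builds EdgeIndex.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper (not part of AiDotNet): emits one COO row
// (source, target) for every nonzero entry of a square adjacency matrix.
static int[,] ToEdgeIndex(double[,] matrix)
{
    int n = matrix.GetLength(0);
    var edges = new List<(int Source, int Target)>();
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            if (matrix[i, j] != 0.0)
                edges.Add((i, j));

    var edgeIndex = new int[edges.Count, 2];
    for (int e = 0; e < edges.Count; e++)
    {
        edgeIndex[e, 0] = edges[e].Source;
        edgeIndex[e, 1] = edges[e].Target;
    }
    return edgeIndex;
}

// A 3-node path graph 0 - 1 - 2, stored symmetrically.
double[,] adjacency =
{
    { 0, 1, 0 },
    { 1, 0, 1 },
    { 0, 1, 0 },
};

int[,] coo = ToEdgeIndex(adjacency);
Console.WriteLine($"numEdges = {coo.GetLength(0)}");
```

Because the matrix is symmetric, each undirected edge appears twice in the result, once per direction.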

GraphLabels

Gets graph labels for graph classification tasks, or null if not available.

public Tensor<T>? GraphLabels { get; protected set; }

Property Value

Tensor<T>

HasNext

Gets whether there are more batches available in the current iteration.

public bool HasNext { get; }

Property Value

bool

NodeFeatures

Gets the node feature tensor of shape [numNodes, numFeatures].

public Tensor<T> NodeFeatures { get; }

Property Value

Tensor<T>

NodeLabels

Gets node labels for node classification tasks, or null if not available.

public Tensor<T>? NodeLabels { get; }

Property Value

Tensor<T>

NumClasses

Gets the number of classes for classification tasks.

public abstract int NumClasses { get; }

Property Value

int

NumEdges

Gets the number of edges in the graph (or total across all graphs).

public int NumEdges { get; }

Property Value

int

NumGraphs

Gets the number of graphs in the dataset (1 for single-graph datasets like citation networks).

public virtual int NumGraphs { get; }

Property Value

int

NumNodeFeatures

Gets the number of node features.

public int NumNodeFeatures { get; }

Property Value

int

NumNodes

Gets the number of nodes in the graph (or total across all graphs).

public int NumNodes { get; }

Property Value

int

TotalCount

Gets the total number of samples in the dataset.

public override int TotalCount { get; }

Property Value

int

Methods

CreateGraphClassificationTask(double, double, int?)

Creates a graph classification task for datasets with multiple graphs.

public virtual GraphClassificationTask<T> CreateGraphClassificationTask(double trainRatio = 0.8, double valRatio = 0.1, int? seed = null)

Parameters

trainRatio double
valRatio double
seed int?

Returns

GraphClassificationTask<T>

CreateLinkPredictionTask(double, double, int?)

Creates a link prediction task for predicting missing edges.

public virtual LinkPredictionTask<T> CreateLinkPredictionTask(double trainRatio = 0.85, double negativeRatio = 1, int? seed = null)

Parameters

trainRatio double
negativeRatio double
seed int?

Returns

LinkPredictionTask<T>
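
The page does not spell out how negativeRatio is applied. One common reading, assumed here, is that it sets the number of sampled non-edges relative to the number of existing edges. The sketch below illustrates that idea with plain collections; SampleNegativeEdges is a hypothetical helper, not an AiDotNet API.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper (not part of AiDotNet). Assumption: negativeRatio is the
// number of negative (absent) edges to sample per positive (existing) edge.
static List<(int Source, int Target)> SampleNegativeEdges(
    int numNodes, HashSet<(int, int)> existingEdges, double negativeRatio, int seed)
{
    int wanted = (int)Math.Round(existingEdges.Count * negativeRatio);

    // Conservative bound on available directed non-edges, so the loop terminates.
    long maxNonEdges = (long)numNodes * (numNodes - 1) - existingEdges.Count;
    if (wanted > maxNonEdges)
        throw new ArgumentException("Not enough non-edges to sample from.");

    var rng = new Random(seed);
    var sampled = new HashSet<(int, int)>();
    while (sampled.Count < wanted)
    {
        int u = rng.Next(numNodes);
        int v = rng.Next(numNodes);
        // Reject self-loops and pairs that are real edges; the set drops duplicates.
        if (u != v && !existingEdges.Contains((u, v)))
            sampled.Add((u, v));
    }
    return new List<(int Source, int Target)>(sampled);
}

var positives = new HashSet<(int, int)> { (0, 1), (1, 2), (2, 3) };
var negatives = SampleNegativeEdges(4, positives, negativeRatio: 1.0, seed: 0);
Console.WriteLine($"{positives.Count} positive edges, {negatives.Count} negative edges");
```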

CreateNodeClassificationTask(double, double, int?)

Creates a node classification task with train/val/test split.

public virtual NodeClassificationTask<T> CreateNodeClassificationTask(double trainRatio = 0.1, double valRatio = 0.1, int? seed = null)

Parameters

trainRatio double
valRatio double
seed int?

Returns

NodeClassificationTask<T>
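
To make the ratio parameters concrete, here is a minimal sketch of how train/validation/test masks could be derived from trainRatio and valRatio, with the remaining nodes going to the test set. SplitMasks is a hypothetical helper written against plain bool arrays; the actual splitting logic inside AiDotNet may differ.

```csharp
using System;
using System.Linq;

// Hypothetical helper (not part of AiDotNet): shuffles node indices, then assigns
// the first trainRatio share to train, the next valRatio share to validation,
// and everything left over to test.
static (bool[] Train, bool[] Val, bool[] Test) SplitMasks(
    int numNodes, double trainRatio, double valRatio, int? seed = null)
{
    if (trainRatio < 0 || valRatio < 0 || trainRatio + valRatio > 1)
        throw new ArgumentException("Ratios must be non-negative and sum to at most 1.");

    var rng = seed.HasValue ? new Random(seed.Value) : new Random();
    int[] order = Enumerable.Range(0, numNodes).ToArray();

    // Fisher-Yates shuffle.
    for (int i = order.Length - 1; i > 0; i--)
    {
        int j = rng.Next(i + 1);
        (order[i], order[j]) = (order[j], order[i]);
    }

    int trainCount = (int)(numNodes * trainRatio);
    int valCount = (int)(numNodes * valRatio);

    var train = new bool[numNodes];
    var val = new bool[numNodes];
    var test = new bool[numNodes];
    for (int k = 0; k < numNodes; k++)
    {
        int node = order[k];
        if (k < trainCount) train[node] = true;
        else if (k < trainCount + valCount) val[node] = true;
        else test[node] = true;
    }
    return (train, val, test);
}

var masks = SplitMasks(numNodes: 100, trainRatio: 0.1, valRatio: 0.1, seed: 42);
Console.WriteLine(
    $"train={masks.Train.Count(b => b)}, val={masks.Val.Count(b => b)}, test={masks.Test.Count(b => b)}");
```

Passing the same seed reproduces the same split, which is the usual reason for exposing a seed parameter on methods like this one.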

GetBatches(int?, bool, bool, int?)

Iterates through all batches in the dataset using lazy evaluation.

public virtual IEnumerable<GraphData<T>> GetBatches(int? batchSize = null, bool shuffle = true, bool dropLast = false, int? seed = null)

Parameters

batchSize int?

Optional batch size override. Uses default BatchSize if null.

shuffle bool

Whether to shuffle data before batching. Default is true.

dropLast bool

Whether to drop the last incomplete batch. Default is false.

seed int?

Optional random seed for reproducible shuffling.

Returns

IEnumerable<GraphData<T>>

An enumerable sequence of batches using yield return for memory efficiency.

Remarks

This method provides a PyTorch-style iteration pattern using IEnumerable and yield return for memory-efficient lazy evaluation. Each call creates a fresh iteration, automatically handling reset and shuffle operations.

For Beginners: This is the recommended way to iterate through your data:

foreach (var batch in dataLoader.GetBatches(batchSize: 32, shuffle: true))
{
    // Each batch is a GraphData<T>; TrainOnBatch stands in for your own training step
    model.TrainOnBatch(batch);
}

Unlike GetNextBatch(), you don't need to call Reset(); each GetBatches() call starts fresh. Because the method uses yield return, batches are generated on demand rather than all loaded into memory at once.
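
The interaction of batchSize, shuffle, dropLast, and seed can be sketched with a small stand-alone iterator. This is an illustration over integer indices, not the AiDotNet implementation; the Batches function and its behavior on edge cases are assumptions.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical stand-in for a lazy batching iterator (not the AiDotNet implementation).
// Every call builds a fresh index order, so no explicit reset is needed.
static IEnumerable<int[]> Batches(int count, int batchSize, bool shuffle = true,
                                  bool dropLast = false, int? seed = null)
{
    if (batchSize <= 0)
        throw new ArgumentOutOfRangeException(nameof(batchSize));

    int[] order = Enumerable.Range(0, count).ToArray();
    if (shuffle)
    {
        var rng = seed.HasValue ? new Random(seed.Value) : new Random();
        for (int i = order.Length - 1; i > 0; i--)
        {
            int j = rng.Next(i + 1);
            (order[i], order[j]) = (order[j], order[i]);
        }
    }

    for (int start = 0; start < count; start += batchSize)
    {
        int size = Math.Min(batchSize, count - start);
        if (dropLast && size < batchSize)
            yield break; // skip the trailing incomplete batch

        // The batch is materialized only when the consumer asks for it.
        yield return order.Skip(start).Take(size).ToArray();
    }
}

// 10 items in batches of 4 give sizes 4, 4, 2 (or 4, 4 with dropLast: true).
foreach (int[] batch in Batches(count: 10, batchSize: 4, shuffle: false))
    Console.WriteLine(string.Join(", ", batch));
```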

GetBatchesAsync(int?, bool, bool, int?, int, CancellationToken)

Asynchronously iterates through all batches with prefetching support.

public virtual IAsyncEnumerable<GraphData<T>> GetBatchesAsync(int? batchSize = null, bool shuffle = true, bool dropLast = false, int? seed = null, int prefetchCount = 2, CancellationToken cancellationToken = default)

Parameters

batchSize int?

Optional batch size override. Uses default BatchSize if null.

shuffle bool

Whether to shuffle data before batching. Default is true.

dropLast bool

Whether to drop the last incomplete batch. Default is false.

seed int?

Optional random seed for reproducible shuffling.

prefetchCount int

Number of batches to prefetch ahead. Default is 2.

cancellationToken CancellationToken

Token to cancel the iteration.

Returns

IAsyncEnumerable<GraphData<T>>

An async enumerable sequence of batches.

Remarks

This method enables async batch iteration with configurable prefetching, similar to PyTorch's num_workers or TensorFlow's prefetch(). Batches are prepared in the background while the current batch is being processed.

For Beginners: Use this for large datasets or when batch preparation is slow:

await foreach (var batch in dataLoader.GetBatchesAsync(prefetchCount: 2))
{
    // While training on this batch, the next 2 batches are being prepared.
    // Each batch is a GraphData<T>; TrainOnBatchAsync stands in for your own training step
    await model.TrainOnBatchAsync(batch);
}

Prefetching helps hide data loading latency, especially useful for:

  • Large images that need decoding
  • Data that requires preprocessing
  • Slow storage (network drives, cloud storage)
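
One way to picture prefetching is a bounded producer/consumer queue: a background task prepares up to prefetchCount items ahead while the consumer works on the current one. The sketch below uses System.Threading.Channels for this; the Prefetch function is hypothetical and only illustrates the idea, it is not how AiDotNet is known to implement GetBatchesAsync.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;

// Hypothetical prefetching wrapper (not the AiDotNet implementation): a background
// producer fills a bounded channel while the consumer processes earlier items.
static async IAsyncEnumerable<TItem> Prefetch<TItem>(
    IEnumerable<TItem> source, int prefetchCount,
    [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    // The channel capacity caps how far the producer can run ahead.
    var channel = Channel.CreateBounded<TItem>(prefetchCount);
    var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);

    Task producer = Task.Run(async () =>
    {
        try
        {
            foreach (TItem produced in source)
                await channel.Writer.WriteAsync(produced, cts.Token);
            channel.Writer.TryComplete();
        }
        catch (Exception ex)
        {
            channel.Writer.TryComplete(ex); // surface source errors to the reader
        }
    });

    try
    {
        await foreach (TItem item in channel.Reader.ReadAllAsync(cts.Token))
            yield return item;
    }
    finally
    {
        cts.Cancel();   // stop the producer if the consumer exits early
        await producer; // the producer handles its own exceptions
        cts.Dispose();
    }
}

await foreach (int batchId in Prefetch(Enumerable.Range(0, 5), prefetchCount: 2))
    Console.WriteLine($"processing batch {batchId}");
```

The bounded capacity is the design choice that matters here: an unbounded queue would let the producer load the whole dataset into memory if the consumer is slow.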

GetNextBatch()

Gets the next batch of data.

public GraphData<T> GetNextBatch()

Returns

GraphData<T>

The next batch of data.

Exceptions

InvalidOperationException

Thrown when no more batches are available.

OnReset()

Called after Reset() to allow derived classes to perform additional reset operations.

protected override void OnReset()

Remarks

Override this to reset any domain-specific state. The base indices are already reset when this is called.

TryGetNextBatch(out GraphData<T>)

Attempts to get the next batch without throwing if unavailable.

public bool TryGetNextBatch(out GraphData<T> batch)

Parameters

batch GraphData<T>

The batch if available, default otherwise.

Returns

bool

True if a batch was available, false if iteration is complete.

Remarks

When false is returned, batch contains the default value for GraphData<T>. Callers should check the return value before using batch.
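
The calling convention can be sketched with a small stand-in that follows the same Try pattern. TryGetNext below is a hypothetical local function over an in-memory array, used only to show the loop shape; it is not the loader itself.

```csharp
#nullable enable
using System;

int[][] batches = { new[] { 0, 1 }, new[] { 2, 3 }, new[] { 4 } };
int cursor = 0;

// Hypothetical stand-in mirroring the Try pattern: returns false instead of
// throwing once the data is exhausted.
bool TryGetNext(out int[]? batch)
{
    if (cursor < batches.Length)
    {
        batch = batches[cursor++];
        return true;
    }
    batch = default; // null for a reference type
    return false;
}

int consumed = 0;
while (TryGetNext(out int[]? current))
{
    // current is only meaningful here, where the call returned true.
    Console.WriteLine(string.Join(", ", current!));
    consumed++;
}
Console.WriteLine($"consumed {consumed} batches");
```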