Class GraphSAGENetwork<T>

Namespace
AiDotNet.NeuralNetworks
Assembly
AiDotNet.dll

Represents a GraphSAGE (Graph Sample and Aggregate) Network for inductive learning on graphs.

public class GraphSAGENetwork<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable

Type Parameters

T

The numeric type used for calculations, typically float or double.

Inheritance
object
NeuralNetworkBase<T>
GraphSAGENetwork<T>
Implements
IFullModel<T, Tensor<T>, Tensor<T>>
IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>
IParameterizable<T, Tensor<T>, Tensor<T>>
ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>
IGradientComputable<T, Tensor<T>, Tensor<T>>

Remarks

GraphSAGE, introduced by Hamilton et al. (2017), is designed for inductive learning on graphs. Transductive methods need every node, including the nodes to be predicted, to be present in the graph during training. GraphSAGE instead learns aggregator functions that generalize to completely unseen nodes and even to entirely new graphs.

For Beginners: GraphSAGE learns how to combine neighbor information.

How it works:

  • For each node, sample a fixed-size set of its neighbors
  • Aggregate the neighbor features using a learnable function
  • Combine the result with the node's own features
  • Result: a new representation that captures the local graph structure
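The steps above can be sketched for a single node in plain C#. This is a minimal illustration, not AiDotNet's implementation: it assumes the Mean aggregator, a single weight matrix applied to the concatenation of the node's own features and the neighbor mean (as in Hamilton et al.), no bias term, and a ReLU activation. `SageSketch.UpdateNode` is a hypothetical helper.

```csharp
using System;
using System.Collections.Generic;

public static class SageSketch
{
    // One GraphSAGE update for a single node (hypothetical helper, Mean aggregator).
    // features[v] is the feature vector of node v; weights has shape [outDim, 2 * inDim].
    public static double[] UpdateNode(
        double[][] features, IList<int> sampledNeighbors, int node, double[,] weights)
    {
        int inDim = features[node].Length;
        int outDim = weights.GetLength(0);

        // Aggregate the sampled neighbors' features with an element-wise mean.
        var agg = new double[inDim];
        foreach (int u in sampledNeighbors)
            for (int j = 0; j < inDim; j++)
                agg[j] += features[u][j];
        if (sampledNeighbors.Count > 0)
            for (int j = 0; j < inDim; j++)
                agg[j] /= sampledNeighbors.Count;

        // Combine with the node's own features by concatenation: [self || agg].
        var combined = new double[2 * inDim];
        Array.Copy(features[node], 0, combined, 0, inDim);
        Array.Copy(agg, 0, combined, inDim, inDim);

        // Linear projection followed by ReLU gives the new representation.
        var output = new double[outDim];
        for (int i = 0; i < outDim; i++)
        {
            double sum = 0;
            for (int j = 0; j < 2 * inDim; j++)
                sum += weights[i, j] * combined[j];
            output[i] = Math.Max(0.0, sum);
        }
        return output;
    }
}
```

Because the same weights are shared by every node, a node that was never seen during training gets a representation as soon as its features and neighbors are known.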

Example - Social Network Recommendations:

  • New user joins the platform (unseen during training)
  • GraphSAGE can still make recommendations by:
    1. Looking at the new user's connections
    2. Aggregating features from those connections
    3. Generating a representation for the new user

Key Features:

  • Inductive: Can generalize to new, unseen nodes
  • Scalable: Samples a fixed number of neighbors per node instead of using full neighborhoods
  • Flexible aggregators: Mean, MaxPool, or Sum
  • L2 normalization: Optional, improves training stability

Aggregator Types:

  • Mean: Average of neighbor features (most common)
  • MaxPool: Element-wise max (captures the most salient feature values)
  • Sum: Sum of neighbor features (retains neighborhood size, which Mean averages away)
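The three reductions can be compared side by side. This sketch uses a hypothetical `AggregatorSketch` enum rather than the library's `SAGEAggregatorType`, and it omits the learned per-neighbor transform that the max-pooling aggregator in Hamilton et al. applies before taking the maximum.

```csharp
using System;

// Hypothetical stand-in for the aggregator choice; not the library's enum.
public enum AggregatorSketch { Mean, MaxPool, Sum }

public static class AggregatorDemo
{
    // Element-wise aggregation of neighbor feature vectors (illustrative only).
    public static double[] Aggregate(double[][] neighbors, AggregatorSketch type)
    {
        if (neighbors.Length == 0)
            throw new ArgumentException("At least one neighbor is required.");
        int dim = neighbors[0].Length;
        var result = new double[dim];
        for (int j = 0; j < dim; j++)
        {
            // MaxPool starts from -infinity; Mean and Sum start from zero.
            double acc = type == AggregatorSketch.MaxPool ? double.NegativeInfinity : 0.0;
            foreach (var h in neighbors)
                acc = type == AggregatorSketch.MaxPool ? Math.Max(acc, h[j]) : acc + h[j];
            // Only Mean divides by the neighbor count.
            result[j] = type == AggregatorSketch.Mean ? acc / neighbors.Length : acc;
        }
        return result;
    }
}
```

With neighbors [1, 4] and [3, 2], Mean gives [2, 3], MaxPool gives [3, 4], and Sum gives [4, 6]. Duplicating every neighbor leaves Mean and MaxPool unchanged but doubles Sum, which is why Sum keeps information about neighborhood size.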

Architecture:

  1. Multiple stacked GraphSAGE layers (default layers all use the selected aggregator)
  2. Optional L2 normalization between layers
  3. Final classification or regression head
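Step 2 can be sketched as row-wise L2 normalization, which scales each node's embedding to unit length. The sketch assumes normalization is applied per node, as in Hamilton et al.; the small `eps` floor for all-zero rows is an assumption, not documented behavior of the `normalize` option.

```csharp
using System;

public static class L2NormSketch
{
    // Scales each row (one node's embedding) to unit Euclidean length.
    // eps guards against division by zero for all-zero rows (e.g. after ReLU).
    public static double[][] NormalizeRows(double[][] h, double eps = 1e-12)
    {
        var result = new double[h.Length][];
        for (int i = 0; i < h.Length; i++)
        {
            double sq = 0;
            foreach (double x in h[i]) sq += x * x;
            double norm = Math.Max(Math.Sqrt(sq), eps);
            result[i] = new double[h[i].Length];
            for (int j = 0; j < h[i].Length; j++)
                result[i][j] = h[i][j] / norm;
        }
        return result;
    }
}
```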

When to use GraphSAGE:

  • When new nodes appear frequently (evolving graphs)
  • When you need to generalize to new graphs
  • For large-scale graphs where full-batch training is infeasible
  • Social networks, recommendation systems, dynamic graphs

Constructors

GraphSAGENetwork(NeuralNetworkArchitecture<T>, SAGEAggregatorType, int, bool, double, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, ILossFunction<T>?, double)

Initializes a new instance of the GraphSAGENetwork<T> class with specified architecture.

public GraphSAGENetwork(NeuralNetworkArchitecture<T> architecture, SAGEAggregatorType aggregatorType = SAGEAggregatorType.Mean, int numLayers = 2, bool normalize = true, double dropoutRate = 0, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? optimizer = null, ILossFunction<T>? lossFunction = null, double maxGradNorm = 1)

Parameters

architecture NeuralNetworkArchitecture<T>

The neural network architecture defining the structure of the network.

aggregatorType SAGEAggregatorType

Type of aggregation function (default: Mean). Used only when creating default layers.

numLayers int

Number of GraphSAGE layers (default: 2). Used only when creating default layers.

normalize bool

Whether to apply L2 normalization (default: true). Used only when creating default layers.

dropoutRate double

Dropout rate applied during training (default: 0.0).

optimizer IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>

Optional optimizer for training.

lossFunction ILossFunction<T>

Optional loss function for training.

maxGradNorm double

Maximum gradient norm for clipping (default: 1.0).
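The parameter description does not say how clipping is performed. A common scheme is clip-by-global-norm, sketched here under the assumption that `maxGradNorm` plays the role of `maxNorm` over one flattened gradient vector; the class may instead clip per layer or per tensor.

```csharp
using System;

public static class GradClipSketch
{
    // Clip-by-global-norm: if the L2 norm of all gradients exceeds maxNorm,
    // rescale every entry by maxNorm / norm so the clipped norm equals maxNorm.
    public static double[] ClipByGlobalNorm(double[] grads, double maxNorm)
    {
        double sq = 0;
        foreach (double g in grads) sq += g * g;
        double norm = Math.Sqrt(sq);
        if (norm <= maxNorm || norm == 0) return (double[])grads.Clone();
        double scale = maxNorm / norm;
        var clipped = new double[grads.Length];
        for (int i = 0; i < grads.Length; i++) clipped[i] = grads[i] * scale;
        return clipped;
    }
}
```

With maxNorm = 1.0, a gradient [3, 4] (norm 5) becomes approximately [0.6, 0.8], while [0.3, 0.4] (norm 0.5) is returned unchanged. Rescaling the whole vector preserves the gradient's direction and only limits the step size.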

Remarks

For Beginners: Creating a GraphSAGE network:

// Create an architecture for node classification
var architecture = new NeuralNetworkArchitecture<double>(
    InputType.OneDimensional,
    NeuralNetworkTaskType.MultiClassClassification,
    NetworkComplexity.Simple,
    inputSize: 128,     // 128 profile features per user (node)
    outputSize: 5);     // 5 user categories

// Create GraphSAGE with default layers (Mean aggregator, 2 layers, L2 normalization)
var sage = new GraphSAGENetwork<double>(architecture);

// Or supply custom layers through the architecture. When layers are provided,
// aggregatorType, numLayers, and normalize are not used.
architecture.Layers.Add(new GraphSAGELayer<double>(...));
var sageCustom = new GraphSAGENetwork<double>(architecture);

// Train on graph data:
//   nodeFeatures:    [numNodes, 128]
//   adjacencyMatrix: [numNodes, numNodes]
//   labels:          label tensor for the 5 categories
sage.TrainOnGraph(nodeFeatures, adjacencyMatrix, labels, epochs: 200);

Properties

AggregatorType

Gets the aggregator type used in GraphSAGE layers.

public SAGEAggregatorType AggregatorType { get; }

Property Value

SAGEAggregatorType

DropoutRate

Gets the dropout rate applied during training.

public double DropoutRate { get; }

Property Value

double

HiddenDim

Gets the hidden dimension size for each layer.

public int HiddenDim { get; }

Property Value

int

IsLoRAEnabled

Gets whether LoRA fine-tuning is currently enabled.

public bool IsLoRAEnabled { get; }

Property Value

bool

LoRARank

Gets the LoRA rank when LoRA is enabled.

public int LoRARank { get; }

Property Value

int

Normalize

Gets whether L2 normalization is applied after each layer.

public bool Normalize { get; }

Property Value

bool

NumLayers

Gets the number of GraphSAGE layers in the network.

public int NumLayers { get; }

Property Value

int

Methods

Backward(Tensor<T>)

Performs a backward pass through the network to calculate gradients.

public Tensor<T> Backward(Tensor<T> outputGradient)

Parameters

outputGradient Tensor<T>

The gradient of the loss with respect to the network's output.

Returns

Tensor<T>

The gradient of the loss with respect to the network's input.

CreateNewInstance()

Creates a new instance of this network type.

protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()

Returns

IFullModel<T, Tensor<T>, Tensor<T>>

DeserializeNetworkSpecificData(BinaryReader)

Deserializes network-specific data from a binary reader.

protected override void DeserializeNetworkSpecificData(BinaryReader reader)

Parameters

reader BinaryReader

DisableLoRA()

Disables LoRA fine-tuning and restores original layers.

public void DisableLoRA()

EnableLoRAFineTuning(int, double, bool)

Enables LoRA fine-tuning for parameter-efficient training.

public void EnableLoRAFineTuning(int rank = 8, double alpha = -1, bool freezeBaseLayers = true)

Parameters

rank int
alpha double
freezeBaseLayers bool
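These parameters correspond to the usual LoRA ingredients. In the standard formulation, a frozen weight matrix W of shape [outDim, inDim] is augmented with trainable low-rank factors B [outDim, rank] and A [rank, inDim], scaled by alpha / rank. The sketch below shows the trainable-parameter count and the merge step performed by a method like `MergeLoRAWeights`. How this class treats the default alpha = -1 is not documented here, so the sketch simply takes alpha as given; `LoRASketch` is a hypothetical helper.

```csharp
using System;

public static class LoRASketch
{
    // Trainable parameters added by LoRA to one outDim x inDim weight matrix:
    // A has shape [rank, inDim] and B has shape [outDim, rank].
    public static int ParameterCount(int inDim, int outDim, int rank) => rank * (inDim + outDim);

    // Merging folds the low-rank update into the base weights:
    // W' = W + (alpha / rank) * (B x A).
    public static double[,] Merge(double[,] w, double[,] a, double[,] b, double alpha)
    {
        int outDim = w.GetLength(0), inDim = w.GetLength(1), rank = a.GetLength(0);
        double scale = alpha / rank;
        var merged = (double[,])w.Clone();
        for (int i = 0; i < outDim; i++)
            for (int j = 0; j < inDim; j++)
            {
                double delta = 0;
                for (int k = 0; k < rank; k++) delta += b[i, k] * a[k, j];
                merged[i, j] += scale * delta;
            }
        return merged;
    }
}
```

For a hypothetical 128 × 64 layer at rank 8, LoRA adds 8 × (128 + 64) = 1,536 trainable parameters, compared with 8,192 in the base matrix.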

Evaluate(Tensor<T>, Tensor<T>, Tensor<T>, bool[])

Evaluates the model on test data and returns accuracy.

public double Evaluate(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix, Tensor<T> labels, bool[] testMask)

Parameters

nodeFeatures Tensor<T>

Node feature tensor.

adjacencyMatrix Tensor<T>

Adjacency matrix.

labels Tensor<T>

Ground truth labels.

testMask bool[]

Boolean mask for test nodes.

Returns

double

Classification accuracy on test nodes.
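Masked accuracy can be sketched as follows, assuming labels are integer class indices and the predicted class is the argmax of each node's output scores. The label encoding that `Evaluate` expects (for example, one-hot rows) is not specified here, and returning 0 for an empty mask is an arbitrary choice of the sketch.

```csharp
using System;

public static class MaskedAccuracySketch
{
    // Accuracy over the nodes where testMask[i] is true.
    // logits[i] holds class scores for node i; labels[i] is its true class index.
    public static double Accuracy(double[][] logits, int[] labels, bool[] testMask)
    {
        int correct = 0, total = 0;
        for (int i = 0; i < logits.Length; i++)
        {
            if (!testMask[i]) continue;
            // Predicted class = index of the largest score (argmax).
            int pred = 0;
            for (int c = 1; c < logits[i].Length; c++)
                if (logits[i][c] > logits[i][pred]) pred = c;
            total++;
            if (pred == labels[i]) correct++;
        }
        return total == 0 ? 0.0 : (double)correct / total;
    }
}
```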

Forward(Tensor<T>, Tensor<T>)

Performs a forward pass through the network with node features and adjacency matrix.

public Tensor<T> Forward(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix)

Parameters

nodeFeatures Tensor<T>

Node feature tensor of shape [batchSize, numNodes, inputFeatures] or [numNodes, inputFeatures].

adjacencyMatrix Tensor<T>

Adjacency matrix of shape [batchSize, numNodes, numNodes] or [numNodes, numNodes].

Returns

Tensor<T>

The output tensor after processing through all layers.
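With a dense adjacency matrix, the Mean aggregator can be computed for all nodes at once as the row-normalized product D⁻¹AH, where D is the diagonal degree matrix and H holds the node features. The sketch assumes the unbatched [numNodes, numNodes] shape and leaves isolated nodes as zero vectors; whether `Forward` adds self-loops, handles weighted edges this way, or uses this exact formulation is an assumption.

```csharp
using System;

public static class DenseMeanAggregation
{
    // Mean of neighbor features for every node at once, computed from a dense
    // adjacency matrix: row i of the result is (sum_j A[i,j] * H[j]) / sum_j A[i,j].
    public static double[,] Aggregate(double[,] adjacency, double[,] features)
    {
        int n = adjacency.GetLength(0), d = features.GetLength(1);
        var result = new double[n, d];
        for (int i = 0; i < n; i++)
        {
            double degree = 0;
            for (int j = 0; j < n; j++) degree += adjacency[i, j];
            if (degree == 0) continue; // isolated node: leave a zero vector
            for (int j = 0; j < n; j++)
            {
                if (adjacency[i, j] == 0) continue;
                for (int k = 0; k < d; k++)
                    result[i, k] += adjacency[i, j] * features[j, k] / degree;
            }
        }
        return result;
    }
}
```

For a three-node path 0-1-2 with features [2, 0], [4, 2], and [6, 8], the aggregated rows are [4, 2], [4, 4], and [4, 2].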

GetLoRAParameterCount()

Gets the number of trainable LoRA parameters.

public int GetLoRAParameterCount()

Returns

int

GetModelMetadata()

Gets metadata about this model.

public override ModelMetadata<T> GetModelMetadata()

Returns

ModelMetadata<T>

GetNodeEmbeddings(Tensor<T>, Tensor<T>)

Generates node embeddings using the trained network.

public Tensor<T> GetNodeEmbeddings(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix)

Parameters

nodeFeatures Tensor<T>

Node feature tensor.

adjacencyMatrix Tensor<T>

Adjacency matrix.

Returns

Tensor<T>

Node embedding tensor from the second-to-last layer.

Remarks

For Beginners: Node embeddings are useful for:

  • Clustering: Group similar nodes together
  • Visualization: Plot nodes in 2D/3D using t-SNE or UMAP
  • Transfer learning: Use embeddings as features for other tasks
  • Similarity search: Find similar nodes efficiently
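Similarity search over embeddings can be as simple as ranking nodes by cosine similarity. This brute-force sketch assumes the embeddings have already been copied out of the tensor returned by `GetNodeEmbeddings` into a `double[][]`; for large graphs, an approximate nearest-neighbor index would usually replace the linear scan.

```csharp
using System;
using System.Linq;

public static class EmbeddingSearchSketch
{
    // Cosine similarity; returns 0 when either vector is all zeros.
    public static double Cosine(double[] a, double[] b)
    {
        double dot = 0, na = 0, nb = 0;
        for (int i = 0; i < a.Length; i++)
        {
            dot += a[i] * b[i];
            na += a[i] * a[i];
            nb += b[i] * b[i];
        }
        return (na == 0 || nb == 0) ? 0.0 : dot / (Math.Sqrt(na) * Math.Sqrt(nb));
    }

    // Indices of the k nodes most similar to `query`, excluding the query node itself.
    public static int[] MostSimilar(double[][] embeddings, int query, int k) =>
        Enumerable.Range(0, embeddings.Length)
            .Where(i => i != query)
            .OrderByDescending(i => Cosine(embeddings[query], embeddings[i]))
            .Take(k)
            .ToArray();
}
```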

GetParameterCount()

Gets the total number of trainable parameters in the network.

public int GetParameterCount()

Returns

int

GetParameters()

Gets all parameters as a vector.

public override Vector<T> GetParameters()

Returns

Vector<T>

InitializeLayers()

Initializes the layers of the neural network based on the provided architecture.

protected override void InitializeLayers()

MergeLoRAWeights()

Merges LoRA weights into the base layers and disables LoRA mode.

public void MergeLoRAWeights()

Predict(Tensor<T>)

Makes a prediction using the trained network.

public override Tensor<T> Predict(Tensor<T> input)

Parameters

input Tensor<T>

Returns

Tensor<T>

SerializeNetworkSpecificData(BinaryWriter)

Serializes network-specific data to a binary writer.

protected override void SerializeNetworkSpecificData(BinaryWriter writer)

Parameters

writer BinaryWriter

SetAdjacencyMatrix(Tensor<T>)

Sets the adjacency matrix for graph operations.

public void SetAdjacencyMatrix(Tensor<T> adjacencyMatrix)

Parameters

adjacencyMatrix Tensor<T>

Train(Tensor<T>, Tensor<T>)

Trains the network on a single batch of data.

public override void Train(Tensor<T> input, Tensor<T> expectedOutput)

Parameters

input Tensor<T>
expectedOutput Tensor<T>

TrainMiniBatch(Tensor<T>, Tensor<T>, Tensor<T>, int[], int, int, double, int)

Performs mini-batch training using neighbor sampling for scalability.

public void TrainMiniBatch(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix, Tensor<T> labels, int[] trainIndices, int batchSize = 512, int epochs = 200, double learningRate = 0.01, int numSamples = 25)

Parameters

nodeFeatures Tensor<T>

Full node feature tensor.

adjacencyMatrix Tensor<T>

Full adjacency matrix.

labels Tensor<T>

Label tensor for supervised learning.

trainIndices int[]

Indices of training nodes.

batchSize int

Number of nodes per batch (default: 512).

epochs int

Number of training epochs (default: 200).

learningRate double

Learning rate (default: 0.01).

numSamples int

Number of neighbors to sample per layer (default: 25).

Remarks

For Beginners: Mini-batch training for large graphs:

For very large graphs, training on all nodes at once is infeasible. Mini-batch training with neighbor sampling makes GraphSAGE scalable:

How it works:

  1. Sample a batch of target nodes (batchSize)
  2. For each target, sample a subset of its neighbors (numSamples per layer)
  3. Compute representations only for the sampled subgraph
  4. Update the model parameters

Benefits:

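  • Per-batch memory and compute depend on batchSize and numSamples rather than on the total number of nodes (the full adjacencyMatrix argument must still fit in memory)
  • Makes training feasible on graphs far too large for full-batch training
  • Random sampling adds a regularizing effect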
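Steps 1 and 2 of the procedure above can be sketched with a Fisher-Yates shuffle. This assumes an adjacency-list representation and uniform sampling without replacement, keeping all neighbors when a node has fewer than `numSamples`; some implementations sample with replacement instead, and the class's actual sampling strategy is not documented here. `NeighborSamplingSketch` is a hypothetical helper.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class NeighborSamplingSketch
{
    // In-place Fisher-Yates shuffle: every permutation is equally likely.
    private static void Shuffle<TItem>(TItem[] items, Random rng)
    {
        for (int i = items.Length - 1; i > 0; i--)
        {
            int j = rng.Next(i + 1);
            (items[i], items[j]) = (items[j], items[i]);
        }
    }

    // Step 1: shuffle the training node indices and split them into batches.
    public static List<int[]> MakeBatches(int[] trainIndices, int batchSize, Random rng)
    {
        var shuffled = (int[])trainIndices.Clone();
        Shuffle(shuffled, rng);
        var batches = new List<int[]>();
        for (int start = 0; start < shuffled.Length; start += batchSize)
            batches.Add(shuffled.Skip(start).Take(batchSize).ToArray());
        return batches;
    }

    // Step 2: sample up to numSamples neighbors of a node uniformly without replacement.
    public static int[] SampleNeighbors(List<int>[] adjacencyList, int node, int numSamples, Random rng)
    {
        var neighbors = adjacencyList[node].ToArray();
        if (neighbors.Length <= numSamples) return neighbors;
        Shuffle(neighbors, rng);
        return neighbors.Take(numSamples).ToArray();
    }
}
```

With the defaults batchSize = 512 and numSamples = 25, a two-layer model processes at most 512 × 25 = 12,800 first-hop and 512 × 25 × 25 = 320,000 second-hop sampled neighbors per batch, however large the graph is.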

TrainOnGraph(Tensor<T>, Tensor<T>, Tensor<T>, bool[]?, int, double)

Trains the GraphSAGE network on graph-structured data.

public void TrainOnGraph(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix, Tensor<T> labels, bool[]? trainMask = null, int epochs = 200, double learningRate = 0.01)

Parameters

nodeFeatures Tensor<T>

Node feature tensor of shape [numNodes, inputFeatures].

adjacencyMatrix Tensor<T>

Adjacency matrix of shape [numNodes, numNodes].

labels Tensor<T>

Label tensor for supervised learning.

trainMask bool[]

Optional boolean mask indicating which nodes to train on.

epochs int

Number of training epochs (default: 200).

learningRate double

Learning rate for optimization (default: 0.01).
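The trainMask restricts which nodes contribute to the loss. A minimal sketch of that masking for multi-class classification, assuming softmax cross-entropy, integer class labels, and that a null mask means every node is used; the loss the class actually applies is whatever `lossFunction` is configured (or its default), which is not documented here.

```csharp
#nullable enable
using System;

public static class MaskedLossSketch
{
    // Mean softmax cross-entropy over nodes where trainMask[i] is true.
    // A null mask means every node contributes (assumed behavior).
    public static double CrossEntropy(double[][] logits, int[] labels, bool[]? trainMask)
    {
        double total = 0;
        int count = 0;
        for (int i = 0; i < logits.Length; i++)
        {
            if (trainMask != null && !trainMask[i]) continue;
            // Log-sum-exp with max subtraction for numerical stability.
            double max = double.NegativeInfinity;
            foreach (double z in logits[i]) max = Math.Max(max, z);
            double sumExp = 0;
            foreach (double z in logits[i]) sumExp += Math.Exp(z - max);
            double logProb = logits[i][labels[i]] - max - Math.Log(sumExp);
            total -= logProb;
            count++;
        }
        return count == 0 ? 0.0 : total / count;
    }
}
```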

Remarks

For Beginners: Training GraphSAGE on graph data:

GraphSAGE learns to aggregate neighbor information through training. The aggregation functions become better at combining relevant features to produce informative node representations.

Training tips:

  • Use the Mean aggregator for most tasks (stable and effective)
  • Use MaxPool when a few neighbors with strong feature values matter more than the neighborhood average
  • L2 normalization helps with training stability
  • A relatively high learning rate, such as the default 0.01, often works well for GraphSAGE

UpdateParameters(Vector<T>)

Updates the parameters of all layers in the network.

public override void UpdateParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

A vector containing all parameters for the network.