Class GraphSAGENetwork<T>
- Namespace
- AiDotNet.NeuralNetworks
- Assembly
- AiDotNet.dll
Represents a GraphSAGE (Graph Sample and Aggregate) Network for inductive learning on graphs.
public class GraphSAGENetwork<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable
Type Parameters
T: The numeric type used for calculations, typically float or double.
- Inheritance: NeuralNetworkBase<T> → GraphSAGENetwork<T>
Remarks
GraphSAGE, introduced by Hamilton et al. (2017), is designed for inductive learning on graphs. Unlike transductive methods that require all nodes during training, GraphSAGE learns aggregator functions that can generalize to completely unseen nodes and even new graphs.
For Beginners: GraphSAGE learns how to combine neighbor information.
How it works:
- For each node, sample its neighbors
- Aggregate neighbor features using a learnable function
- Combine with node's own features
- Result: a new representation that captures local structure (see the sketch after this list)
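The core of each layer fits in a few lines. Below is a minimal, hypothetical sketch of one mean-aggregation step for a single node (illustrative only, not AiDotNet's actual implementation): it averages sampled neighbor features, concatenates them with the node's own features, and applies a learned linear transform followed by ReLU.
using System;
using System.Collections.Generic;

// Hypothetical helper: h_v = ReLU(W * concat(h_v, mean(neighbor features)))
static double[] SageMeanStep(double[] self, IReadOnlyList<double[]> neighbors, double[,] w)
{
    int d = self.Length;

    // 1. Aggregate: element-wise mean of the sampled neighbor features
    var mean = new double[d];
    foreach (var n in neighbors)
        for (int i = 0; i < d; i++)
            mean[i] += n[i] / neighbors.Count;

    // 2. Combine: concatenate the node's own features with the aggregate
    var concat = new double[2 * d];
    self.CopyTo(concat, 0);
    mean.CopyTo(concat, d);

    // 3. Transform: linear projection (w is [outDim, 2*d]) plus ReLU
    int outDim = w.GetLength(0);
    var h = new double[outDim];
    for (int o = 0; o < outDim; o++)
    {
        for (int i = 0; i < 2 * d; i++)
            h[o] += w[o, i] * concat[i];
        h[o] = Math.Max(0.0, h[o]); // ReLU
    }
    return h; // optionally L2-normalized when Normalize is enabled
}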
Example - Social Network Recommendations:
- New user joins the platform (unseen during training)
- GraphSAGE can still make recommendations by:
- Looking at the new user's connections
- Aggregating features from those connections
- Generating a representation for the new user
Key Features:
- Inductive: Can generalize to new, unseen nodes
- Scalable: Uses sampling, not full neighborhoods
- Flexible aggregators: Mean, MaxPool, or Sum
- L2 normalization: Optional for stable training
Aggregator Types:
- Mean: Average of neighbor features (most common)
- MaxPool: Element-wise max (captures salient features)
- Sum: Sum of neighbor features (preserves structure)
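For example, given two neighbor feature vectors [1, 4] and [3, 2]: Mean produces [2, 3], MaxPool produces [3, 4], and Sum produces [4, 6].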
Architecture:
- Multiple GraphSAGE layers with different aggregators
- Optional L2 normalization between layers
- Final classification or regression head
When to use GraphSAGE:
- When new nodes appear frequently (evolving graphs)
- When you need to generalize to new graphs
- For large-scale graphs where full-batch training is infeasible
- Social networks, recommendation systems, dynamic graphs
Constructors
GraphSAGENetwork(NeuralNetworkArchitecture<T>, SAGEAggregatorType, int, bool, double, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, ILossFunction<T>?, double)
Initializes a new instance of the GraphSAGENetwork<T> class with specified architecture.
public GraphSAGENetwork(NeuralNetworkArchitecture<T> architecture, SAGEAggregatorType aggregatorType = SAGEAggregatorType.Mean, int numLayers = 2, bool normalize = true, double dropoutRate = 0, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? optimizer = null, ILossFunction<T>? lossFunction = null, double maxGradNorm = 1)
Parameters
architecture (NeuralNetworkArchitecture<T>): The neural network architecture defining the structure of the network.
aggregatorType (SAGEAggregatorType): Type of aggregation function (default: Mean). Used only when creating default layers.
numLayers (int): Number of GraphSAGE layers (default: 2). Used only when creating default layers.
normalize (bool): Whether to apply L2 normalization (default: true). Used only when creating default layers.
dropoutRate (double): Dropout rate applied during training (default: 0.0).
optimizer (IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?): Optional optimizer for training.
lossFunction (ILossFunction<T>?): Optional loss function for training.
maxGradNorm (double): Maximum gradient norm for clipping (default: 1.0).
Remarks
For Beginners: Creating a GraphSAGE network:
// Create architecture for node classification
var architecture = new NeuralNetworkArchitecture<double>(
    InputType.OneDimensional,
    NeuralNetworkTaskType.MultiClassClassification,
    NetworkComplexity.Simple,
    inputSize: 128,  // User profile features
    outputSize: 5);  // 5 user categories

// Create GraphSAGE with default layers
var sage = new GraphSAGENetwork<double>(architecture);

// Or create with custom layers by adding them to the architecture:
architecture.Layers.Add(new GraphSAGELayer<double>(...));
var sageCustom = new GraphSAGENetwork<double>(architecture);

// Train on graph data
sage.TrainOnGraph(nodeFeatures, adjacencyMatrix, labels, epochs: 200);
Properties
AggregatorType
Gets the aggregator type used in GraphSAGE layers.
public SAGEAggregatorType AggregatorType { get; }
Property Value
- SAGEAggregatorType
DropoutRate
Gets the dropout rate applied during training.
public double DropoutRate { get; }
Property Value
- double
HiddenDim
Gets the hidden dimension size for each layer.
public int HiddenDim { get; }
Property Value
- int
IsLoRAEnabled
Gets whether LoRA fine-tuning is currently enabled.
public bool IsLoRAEnabled { get; }
Property Value
- bool
LoRARank
Gets the LoRA rank when LoRA is enabled.
public int LoRARank { get; }
Property Value
- int
Normalize
Gets whether L2 normalization is applied after each layer.
public bool Normalize { get; }
Property Value
- bool
NumLayers
Gets the number of GraphSAGE layers in the network.
public int NumLayers { get; }
Property Value
- int
Methods
Backward(Tensor<T>)
Performs a backward pass through the network to calculate gradients.
public Tensor<T> Backward(Tensor<T> outputGradient)
Parameters
outputGradient (Tensor<T>): The gradient of the loss with respect to the network's output.
Returns
- Tensor<T>
The gradient of the loss with respect to the network's input.
CreateNewInstance()
Creates a new instance of this network type.
protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
DeserializeNetworkSpecificData(BinaryReader)
Deserializes network-specific data from a binary reader.
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
Parameters
reader (BinaryReader): The binary reader to deserialize network-specific data from.
DisableLoRA()
Disables LoRA fine-tuning and restores original layers.
public void DisableLoRA()
EnableLoRAFineTuning(int, double, bool)
Enables LoRA fine-tuning for parameter-efficient training.
public void EnableLoRAFineTuning(int rank = 8, double alpha = -1, bool freezeBaseLayers = true)
Parameters
rank (int): The LoRA rank, controlling adapter capacity (default: 8).
alpha (double): LoRA scaling factor (default: -1).
freezeBaseLayers (bool): Whether to freeze the base layer weights during fine-tuning (default: true).
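A usage sketch of the LoRA workflow, using only the methods documented on this page; the fine-tuning step in the middle stands in for whatever training call fits your data:
// Enable parameter-efficient fine-tuning with a rank-8 adapter
sage.EnableLoRAFineTuning(rank: 8, freezeBaseLayers: true);
// ... fine-tune on new graph data, e.g. via TrainOnGraph ...
// Either fold the adapters into the base weights permanently:
sage.MergeLoRAWeights();
// ...or discard them and restore the original layers instead:
// sage.DisableLoRA();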
Evaluate(Tensor<T>, Tensor<T>, Tensor<T>, bool[])
Evaluates the model on test data and returns accuracy.
public double Evaluate(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix, Tensor<T> labels, bool[] testMask)
Parameters
nodeFeatures (Tensor<T>): Node feature tensor.
adjacencyMatrix (Tensor<T>): Adjacency matrix.
labels (Tensor<T>): Ground truth labels.
testMask (bool[]): Boolean mask for test nodes.
Returns
- double
Classification accuracy on test nodes.
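A usage sketch, assuming testMask marks the held-out nodes:
double accuracy = sage.Evaluate(nodeFeatures, adjacencyMatrix, labels, testMask);
Console.WriteLine($"Test accuracy: {accuracy:P1}");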
Forward(Tensor<T>, Tensor<T>)
Performs a forward pass through the network with node features and adjacency matrix.
public Tensor<T> Forward(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix)
Parameters
nodeFeatures (Tensor<T>): Node feature tensor of shape [batchSize, numNodes, inputFeatures] or [numNodes, inputFeatures].
adjacencyMatrix (Tensor<T>): Adjacency matrix of shape [batchSize, numNodes, numNodes] or [numNodes, numNodes].
Returns
- Tensor<T>
The output tensor after processing through all layers.
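A usage sketch based on the signature above:
// One explicit forward pass over the whole graph
var output = sage.Forward(nodeFeatures, adjacencyMatrix);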
GetLoRAParameterCount()
Gets the number of trainable LoRA parameters.
public int GetLoRAParameterCount()
Returns
- int
GetModelMetadata()
Gets metadata about this model.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
GetNodeEmbeddings(Tensor<T>, Tensor<T>)
Generates node embeddings using the trained network.
public Tensor<T> GetNodeEmbeddings(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix)
Parameters
nodeFeatures (Tensor<T>): Node feature tensor.
adjacencyMatrix (Tensor<T>): Adjacency matrix.
Returns
- Tensor<T>
Node embedding tensor from the second-to-last layer.
Remarks
For Beginners: Node embeddings are useful for:
- Clustering: Group similar nodes together
- Visualization: Plot nodes in 2D/3D using t-SNE or UMAP
- Transfer learning: Use embeddings as features for other tasks
- Similarity search: Find similar nodes efficiently
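A usage sketch; the downstream steps are placeholders:
// Extract embeddings from the second-to-last layer
var embeddings = sage.GetNodeEmbeddings(nodeFeatures, adjacencyMatrix);
// Feed embeddings to clustering, t-SNE/UMAP, or a downstream model here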
GetParameterCount()
Gets the total number of trainable parameters in the network.
public int GetParameterCount()
Returns
- int
GetParameters()
Gets all parameters as a vector.
public override Vector<T> GetParameters()
Returns
- Vector<T>
InitializeLayers()
Initializes the layers of the neural network based on the provided architecture.
protected override void InitializeLayers()
MergeLoRAWeights()
Merges LoRA weights into the base layers and disables LoRA mode.
public void MergeLoRAWeights()
Predict(Tensor<T>)
Makes a prediction using the trained network.
public override Tensor<T> Predict(Tensor<T> input)
Parameters
input (Tensor<T>): The input tensor.
Returns
- Tensor<T>
SerializeNetworkSpecificData(BinaryWriter)
Serializes network-specific data to a binary writer.
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
Parameters
writer (BinaryWriter): The binary writer to serialize network-specific data to.
SetAdjacencyMatrix(Tensor<T>)
Sets the adjacency matrix for graph operations.
public void SetAdjacencyMatrix(Tensor<T> adjacencyMatrix)
Parameters
adjacencyMatrix (Tensor<T>): The adjacency matrix to use for subsequent graph operations.
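A sketch of how this presumably pairs with Predict(Tensor<T>), which takes only node features; the pairing is an assumption based on the two signatures:
// Provide the graph structure once...
sage.SetAdjacencyMatrix(adjacencyMatrix);
// ...then predict from node features alone
var predictions = sage.Predict(nodeFeatures);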
Train(Tensor<T>, Tensor<T>)
Trains the network on a single batch of data.
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)
Parameters
input (Tensor<T>): The input tensor.
expectedOutput (Tensor<T>): The expected output tensor.
TrainMiniBatch(Tensor<T>, Tensor<T>, Tensor<T>, int[], int, int, double, int)
Performs mini-batch training using neighbor sampling for scalability.
public void TrainMiniBatch(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix, Tensor<T> labels, int[] trainIndices, int batchSize = 512, int epochs = 200, double learningRate = 0.01, int numSamples = 25)
Parameters
nodeFeatures (Tensor<T>): Full node feature tensor.
adjacencyMatrix (Tensor<T>): Full adjacency matrix.
labels (Tensor<T>): Label tensor for supervised learning.
trainIndices (int[]): Indices of training nodes.
batchSize (int): Number of nodes per batch (default: 512).
epochs (int): Number of training epochs (default: 200).
learningRate (double): Learning rate (default: 0.01).
numSamples (int): Number of neighbors to sample per layer (default: 25).
Remarks
For Beginners: Mini-batch training for large graphs:
For very large graphs, training on all nodes at once is infeasible. Mini-batch training with neighbor sampling makes GraphSAGE scalable:
How it works:
- Sample a batch of target nodes
- For each target, sample a subset of neighbors
- Compute representations only for sampled subgraph
- Update model parameters
Benefits:
- Constant memory usage regardless of graph size
- Can train on graphs with millions of nodes
- Provides regularization through random sampling
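A usage sketch; the index split below is hypothetical:
// Train on the first 1,000 labeled nodes with a 10-neighbor sampling fan-out
int[] trainIndices = new int[1000];
for (int i = 0; i < trainIndices.Length; i++) trainIndices[i] = i;
sage.TrainMiniBatch(nodeFeatures, adjacencyMatrix, labels, trainIndices,
    batchSize: 256, epochs: 100, learningRate: 0.01, numSamples: 10);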
TrainOnGraph(Tensor<T>, Tensor<T>, Tensor<T>, bool[]?, int, double)
Trains the GraphSAGE network on graph-structured data.
public void TrainOnGraph(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix, Tensor<T> labels, bool[]? trainMask = null, int epochs = 200, double learningRate = 0.01)
Parameters
nodeFeatures (Tensor<T>): Node feature tensor of shape [numNodes, inputFeatures].
adjacencyMatrix (Tensor<T>): Adjacency matrix of shape [numNodes, numNodes].
labels (Tensor<T>): Label tensor for supervised learning.
trainMask (bool[]?): Optional boolean mask indicating which nodes to train on.
epochs (int): Number of training epochs (default: 200).
learningRate (double): Learning rate for optimization (default: 0.01).
Remarks
For Beginners: Training GraphSAGE on graph data:
GraphSAGE learns to aggregate neighbor information through training. The aggregation functions become better at combining relevant features to produce informative node representations.
Training tips:
- Use Mean aggregator for most tasks (stable, effective)
- Use MaxPool for graphs where individual features are important
- L2 normalization helps with training stability
- Higher learning rates (0.01) often work well for GraphSAGE
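A usage sketch with an explicit train mask; the 80/20 split and the numNodes variable are illustrative:
// Mark the first 80% of nodes as training nodes
var trainMask = new bool[numNodes];
for (int i = 0; i < numNodes * 8 / 10; i++) trainMask[i] = true;
sage.TrainOnGraph(nodeFeatures, adjacencyMatrix, labels, trainMask,
    epochs: 200, learningRate: 0.01);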
UpdateParameters(Vector<T>)
Updates the parameters of all layers in the network.
public override void UpdateParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): A vector containing all parameters for the network.