Class GraphNeuralNetwork<T>

Namespace
AiDotNet.NeuralNetworks
Assembly
AiDotNet.dll

Represents a Graph Neural Network that can process data represented as graphs.

public class GraphNeuralNetwork<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable, IAuxiliaryLossLayer<T>, IDiagnosticsProvider

Type Parameters

T

The numeric type used for calculations, typically float or double.

Inheritance
GraphNeuralNetwork<T>
Implements
IFullModel<T, Tensor<T>, Tensor<T>>
IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>
IParameterizable<T, Tensor<T>, Tensor<T>>
ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>
IGradientComputable<T, Tensor<T>, Tensor<T>>

Remarks

A Graph Neural Network (GNN) is designed to work with data structured as graphs, where nodes represent entities and edges represent relationships between these entities. This implementation supports various activation functions for different layers and provides methods for predicting outputs from both vector inputs and graph inputs.

For Beginners: A Graph Neural Network is a type of neural network that works with connected data.

Think of it like analyzing a social network:

  • Each person is a "node" in the graph
  • Friendships between people are "edges" connecting the nodes
  • People have attributes (like age, location, interests) which are "node features"

GNNs are useful when:

  • The relationships between items are as important as the items themselves
  • You're working with network-like data (social networks, molecules, road systems)
  • You need to make predictions about how nodes influence each other

For example, GNNs can help predict which products a customer might like based on what similar customers have purchased, by analyzing the connections between customers and products.

Constructors

GraphNeuralNetwork(NeuralNetworkArchitecture<T>, ILossFunction<T>?, IActivationFunction<T>?, IActivationFunction<T>?, IActivationFunction<T>?, IActivationFunction<T>?)

Initializes a new instance of the GraphNeuralNetwork<T> class with scalar activation functions.

public GraphNeuralNetwork(NeuralNetworkArchitecture<T> architecture, ILossFunction<T>? lossFunction = null, IActivationFunction<T>? graphConvolutionalActivation = null, IActivationFunction<T>? activationLayerActivation = null, IActivationFunction<T>? finalDenseLayerActivation = null, IActivationFunction<T>? finalActivationLayerActivation = null)

Parameters

architecture NeuralNetworkArchitecture<T>

The neural network architecture defining the structure of the network.

lossFunction ILossFunction<T>

The loss function used for training. Optional; defaults to null.

graphConvolutionalActivation IActivationFunction<T>

The scalar activation function for graph convolutional layers.

activationLayerActivation IActivationFunction<T>

The scalar activation function for standard activation layers.

finalDenseLayerActivation IActivationFunction<T>

The scalar activation function for the final dense layer.

finalActivationLayerActivation IActivationFunction<T>

The scalar activation function for the final activation layer.

Remarks

This constructor creates a graph neural network with the specified architecture and scalar activation functions. Scalar activation functions operate on individual elements rather than entire vectors.

For Beginners: This creates a new graph neural network where the activation functions work on individual numbers separately.

When creating your network, you specify:

  • The overall structure (architecture)
  • Which activation functions to use at different stages

Scalar activation functions process each value independently, applying the same transformation to each number without considering relationships between values.

GraphNeuralNetwork(NeuralNetworkArchitecture<T>, ILossFunction<T>?, IVectorActivationFunction<T>?, IVectorActivationFunction<T>?, IVectorActivationFunction<T>?, IVectorActivationFunction<T>?)

Initializes a new instance of the GraphNeuralNetwork<T> class with vector activation functions.

public GraphNeuralNetwork(NeuralNetworkArchitecture<T> architecture, ILossFunction<T>? lossFunction = null, IVectorActivationFunction<T>? graphConvolutionalVectorActivation = null, IVectorActivationFunction<T>? activationLayerVectorActivation = null, IVectorActivationFunction<T>? finalDenseLayerVectorActivation = null, IVectorActivationFunction<T>? finalActivationLayerVectorActivation = null)

Parameters

architecture NeuralNetworkArchitecture<T>

The neural network architecture defining the structure of the network.

lossFunction ILossFunction<T>

The loss function used for training. Optional; defaults to null.

graphConvolutionalVectorActivation IVectorActivationFunction<T>

The vector activation function for graph convolutional layers.

activationLayerVectorActivation IVectorActivationFunction<T>

The vector activation function for standard activation layers.

finalDenseLayerVectorActivation IVectorActivationFunction<T>

The vector activation function for the final dense layer.

finalActivationLayerVectorActivation IVectorActivationFunction<T>

The vector activation function for the final activation layer.

Remarks

This constructor creates a graph neural network with the specified architecture and vector activation functions. Vector activation functions operate on entire vectors rather than individual elements.

For Beginners: This creates a new graph neural network where the activation functions work on groups of numbers together.

When creating your network, you specify:

  • The overall structure (architecture)
  • Which activation functions to use at different stages

Vector activation functions process multiple values as a group, which can help capture relationships between different values.
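The difference between the scalar and vector constructor overloads can be illustrated with a minimal sketch. The class and method names below (ActivationSketch, ReluElementwise, Softmax) are hypothetical and operate on plain arrays; they are not AiDotNet's IActivationFunction<T> or IVectorActivationFunction<T> implementations. ReLU and softmax are chosen only as familiar examples of each kind.

```csharp
using System;
using System.Linq;

// Illustrative only: plain-array stand-ins, not AiDotNet activation types.
public static class ActivationSketch
{
    // Scalar-style activation: each element is transformed on its own (ReLU here).
    public static double[] ReluElementwise(double[] values) =>
        values.Select(v => Math.Max(0.0, v)).ToArray();

    // Vector-style activation: every output depends on the whole input (softmax here).
    public static double[] Softmax(double[] values)
    {
        double max = values.Max(); // subtract the maximum for numerical stability
        double[] exps = values.Select(v => Math.Exp(v - max)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }
}
```

Changing one input of ReluElementwise affects only the matching output, whereas changing one input of Softmax shifts every output, because the results are normalized to sum to 1.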

Properties

AuxiliaryLossWeight

Gets or sets the weight for the graph smoothness auxiliary loss.

public T AuxiliaryLossWeight { get; set; }

Property Value

T

Remarks

This weight controls how much the graph smoothness regularization contributes to the total loss. The total loss is: main_loss + (auxiliary_weight * smoothness_loss). Typical values range from 0.01 to 0.1.

For Beginners: This controls how much the network should enforce similarity between connected nodes.

The weight determines the balance between:

  • Task accuracy (main loss) - making correct predictions
  • Graph smoothness (auxiliary loss) - keeping connected nodes similar

Common values:

  • 0.05 (default): Balanced smoothness regularization
  • 0.01-0.03: Light smoothness enforcement
  • 0.08-0.1: Strong smoothness enforcement

Higher values make the network focus more on keeping connected nodes similar, which can help with generalization but may reduce flexibility.
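The total-loss formula from the remarks can be sketched in a few lines. TotalLossSketch and Combine are hypothetical names, and the choice to return the main loss unchanged when the auxiliary loss is disabled is an assumption about how UseAuxiliaryLoss interacts with the weight, not confirmed library behavior.

```csharp
// Illustrative only: mirrors main_loss + (auxiliary_weight * smoothness_loss).
public static class TotalLossSketch
{
    public static double Combine(
        double mainLoss,
        double smoothnessLoss,
        double auxiliaryWeight = 0.05,   // default value quoted in the list above
        bool useAuxiliaryLoss = true)    // assumption: disabled means no contribution
    {
        if (!useAuxiliaryLoss)
        {
            return mainLoss;
        }
        return mainLoss + (auxiliaryWeight * smoothnessLoss);
    }
}
```

With a main loss of 0.40 and a smoothness loss of 2.0, the default weight of 0.05 adds 0.10 to the total, while a weight of 0.1 adds 0.20.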

UseAuxiliaryLoss

Gets or sets whether auxiliary loss (graph smoothness regularization) should be used during training.

public bool UseAuxiliaryLoss { get; set; }

Property Value

bool

Remarks

Graph smoothness regularization encourages connected nodes to have similar representations. This is based on the principle that nodes with edges between them should have similar features, which is a common assumption in many graph-based learning tasks.

For Beginners: Graph smoothness is like encouraging friends to be similar.

In a graph:

  • Nodes that are connected (like friends in a social network) should have similar features
  • This auxiliary loss penalizes the network when connected nodes have very different representations
  • It helps the network learn more meaningful patterns that respect the graph structure

For example:

  • In a social network, friends often have similar interests
  • In a molecule, bonded atoms influence each other's properties
  • In a citation network, papers that cite each other often cover similar topics

Enabling this helps the network learn representations that are consistent with the graph structure.

Methods

ComputeAuxiliaryLoss()

Computes the auxiliary loss for graph smoothness regularization.

public T ComputeAuxiliaryLoss()

Returns

T

The computed graph smoothness auxiliary loss.

Remarks

This method computes the graph smoothness loss, which encourages connected nodes to have similar representations. The loss is the sum of squared differences between the representations of connected nodes, weighted by the adjacency matrix:

L_smooth = Σ_edges A_{ij} * ||h_i - h_j||²

Here h_i is the learned representation of node i, and A_{ij} is the adjacency matrix entry for the edge between nodes i and j.

For Beginners: This calculates how different connected nodes are from each other.

Graph smoothness works by:

  1. Looking at each pair of connected nodes in the graph
  2. Measuring how different their learned representations are
  3. Penalizing large differences between connected nodes
  4. Summing up these penalties across all edges

This helps because:

  • It encourages the network to respect the graph structure
  • Connected nodes learn similar representations
  • The network generalizes better to new graph data

For example, in a social network, friends (connected nodes) will have similar learned features, which makes sense since friends often share interests.
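The smoothness formula and the steps above can be sketched with plain arrays. GraphSmoothnessSketch and SmoothnessLoss are hypothetical names, not part of AiDotNet. The sketch assumes an undirected graph with a symmetric adjacency matrix and counts each edge once; the library may count both directions or normalize the result differently.

```csharp
using System;

// Illustrative only: computes the sum over edges of A[i,j] * ||h_i - h_j||^2.
public static class GraphSmoothnessSketch
{
    // h[i] is the representation of node i; a[i, j] is the adjacency weight.
    public static double SmoothnessLoss(double[][] h, double[,] a)
    {
        int n = h.Length;
        if (a.GetLength(0) != n || a.GetLength(1) != n)
        {
            throw new ArgumentException("Adjacency matrix must be n x n for n nodes.");
        }

        double loss = 0.0;
        for (int i = 0; i < n; i++)
        {
            // Assumption: symmetric matrix, so each undirected pair is visited once.
            for (int j = i + 1; j < n; j++)
            {
                if (a[i, j] == 0.0) continue; // no edge between i and j

                double squaredDistance = 0.0;
                for (int k = 0; k < h[i].Length; k++)
                {
                    double diff = h[i][k] - h[j][k];
                    squaredDistance += diff * diff;
                }
                loss += a[i, j] * squaredDistance;
            }
        }
        return loss;
    }
}
```

For three nodes with representations (1, 0), (0, 1), and (1, 1) and unit-weight edges 0-1 and 1-2, the squared distances are 2 and 1, so the loss is 3.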

CreateNewInstance()

Creates a new instance of the same type as this neural network.

protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()

Returns

IFullModel<T, Tensor<T>, Tensor<T>>

A new instance of the same neural network type.

Remarks

For Beginners: This creates a blank version of the same type of neural network.

It's used internally by methods like DeepCopy and Clone to create the right type of network before copying the data into it.

DeserializeNetworkSpecificData(BinaryReader)

Deserializes Graph Neural Network-specific data from a binary reader.

protected override void DeserializeNetworkSpecificData(BinaryReader reader)

Parameters

reader BinaryReader

The BinaryReader to read the data from.

Remarks

This method reads the Graph Neural Network's specific configuration data from a binary stream. This includes activation function types and any other GNN-specific parameters. After reading this data, the GNN's state is fully restored to what it was when saved.

For Beginners: This method loads a previously saved GNN configuration.

Think of it like following a recipe to rebuild your neural network:

  • Reading what activation functions were used at different stages
  • Setting up the graph-specific components with the right configuration
  • Restoring any other special settings that make this GNN unique

This ensures that your loaded model will process information exactly the same way as when you saved it.

GetAuxiliaryLossDiagnostics()

Gets diagnostic information about the graph smoothness auxiliary loss.

public Dictionary<string, string> GetAuxiliaryLossDiagnostics()

Returns

Dictionary<string, string>

A dictionary containing diagnostic information about auxiliary losses.

Remarks

This method returns detailed diagnostics about the graph smoothness regularization, including the computed smoothness loss, weight applied, and whether the feature is enabled. This information is useful for monitoring training progress and debugging.

For Beginners: This provides information about how graph smoothness regularization is working.

The diagnostics include:

  • Total smoothness loss (how different connected nodes are)
  • Weight applied to the smoothness loss
  • Whether smoothness regularization is enabled
  • Whether node representations are being tracked

This helps you:

  • Monitor if smoothness regularization is helping training
  • Debug issues with graph structure learning
  • Understand the impact of smoothness enforcement on learning

You can use this information to adjust the auxiliary loss weight for better results.

GetDiagnostics()

Gets diagnostic information about this component's state and behavior. Overrides GetDiagnostics() to include auxiliary loss diagnostics.

public Dictionary<string, string> GetDiagnostics()

Returns

Dictionary<string, string>

A dictionary containing diagnostic metrics including both base layer diagnostics and auxiliary loss diagnostics from GetAuxiliaryLossDiagnostics().

GetModelMetadata()

Gets metadata about the Graph Neural Network model.

public override ModelMetadata<T> GetModelMetadata()

Returns

ModelMetadata<T>

A ModelMetadata<T> object containing information about the model.

Remarks

This method returns metadata about the Graph Neural Network, including its model type, number of layers, types of activation functions used, and serialized model data. This information is useful for model management and serialization.

For Beginners: This method provides a summary of your GNN's configuration.

The metadata includes:

  • The type of model (GraphNeuralNetwork)
  • Details about the network structure (layers, activations)
  • Performance metrics
  • Data needed to save and load the model

This is useful for:

  • Keeping track of different models you've created
  • Understanding a model's properties
  • Saving the model for later use

InitializeLayers()

Initializes the layers of the neural network based on the provided architecture.

protected override void InitializeLayers()

Remarks

This method sets up the layers of the graph neural network. If the architecture provides specific layers, those are used directly. Otherwise, default layers appropriate for a graph neural network are created.

For Beginners: This method sets up the building blocks of your neural network.

When initializing layers:

  • If you provided specific layers in the architecture, those are used
  • If not, the network creates a standard set of layers for graph processing

Think of this like assembling the components of the network before training begins.

Predict(Tensor<T>)

Performs a forward pass through the network to make a prediction using a standard input tensor.

public override Tensor<T> Predict(Tensor<T> input)

Parameters

input Tensor<T>

The input tensor to process.

Returns

Tensor<T>

The output tensor containing the prediction.

Remarks

This method passes the input data through each layer of the network sequentially. For a GraphNeuralNetwork, this method assumes the input tensor already contains the necessary node features and adjacency information in a preprocessed format, or that the input is for a portion of the network that uses standard layers rather than graph-specific ones.

For Beginners: While GNNs work best with explicit graph data (provided through the PredictGraph method), this method lets the network process preprocessed graph data or run the non-graph portions of the network.

It works by:

  1. Starting with your input data
  2. Passing it through each layer of the network in sequence
  3. Letting each layer transform the data based on its specific function

The result is a prediction based on the trained network's understanding of graph patterns.

PredictGraph(Tensor<T>, Tensor<T>)

Performs a forward pass through the network to generate a prediction from graph data.

public Tensor<T> PredictGraph(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix)

Parameters

nodeFeatures Tensor<T>

A tensor containing features for each node in the graph.

adjacencyMatrix Tensor<T>

A tensor representing the connections between nodes in the graph.

Returns

Tensor<T>

A tensor containing the prediction for the graph data.

Remarks

This method processes graph data through the network. It takes node features and an adjacency matrix as input and passes them through graph-specific and standard layers, applying appropriate transformations at each step. The method concludes with hybrid pooling to generate the final output. If graph smoothness regularization is enabled, it caches the node representations for auxiliary loss computation.

For Beginners: This method processes a graph (like a social network) through the neural network to make predictions.

You provide two pieces of information:

  • nodeFeatures: Information about each node (like age, interests for each person)
  • adjacencyMatrix: Information about how nodes are connected (like who is friends with whom)

The method:

  • Passes this information through specialized graph layers
  • Also passes it through standard neural network layers
  • Combines the results using a technique called "hybrid pooling"
  • If smoothness regularization is enabled, saves intermediate representations
  • Returns a prediction based on the entire graph structure

This is useful for tasks like predicting which users might become friends, which products a customer might like, or how a molecule might behave.

Exceptions

ArgumentNullException

Thrown when either nodeFeatures or adjacencyMatrix is null.

ArgumentException

Thrown when nodeFeatures and adjacencyMatrix have incompatible dimensions.

InvalidOperationException

Thrown when the network encounters an unsupported layer type or invalid output shape.
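The two PredictGraph inputs and the argument checks listed above can be illustrated with plain arrays. GraphInputSketch, BuildAdjacency, and Validate are hypothetical helpers, not AiDotNet APIs. The shapes are assumptions: node features as [nodeCount, featureCount] and the adjacency matrix as [nodeCount, nodeCount]. The actual Tensor<T> layout and the exact compatibility rule used by the library may differ.

```csharp
using System;

// Illustrative only: plain arrays standing in for the two input tensors.
public static class GraphInputSketch
{
    // Builds a symmetric 0/1 adjacency matrix from an undirected edge list.
    public static double[,] BuildAdjacency(int nodeCount, (int Source, int Target)[] edges)
    {
        var adjacency = new double[nodeCount, nodeCount];
        foreach (var (source, target) in edges)
        {
            adjacency[source, target] = 1.0;
            adjacency[target, source] = 1.0; // friendships go both ways
        }
        return adjacency;
    }

    // Mirrors the documented exceptions under the assumed shapes.
    public static void Validate(double[,] nodeFeatures, double[,] adjacency)
    {
        if (nodeFeatures is null) throw new ArgumentNullException(nameof(nodeFeatures));
        if (adjacency is null) throw new ArgumentNullException(nameof(adjacency));

        int nodeCount = nodeFeatures.GetLength(0);
        if (adjacency.GetLength(0) != nodeCount || adjacency.GetLength(1) != nodeCount)
        {
            throw new ArgumentException(
                "Adjacency matrix must be square with one row per node.");
        }
    }
}
```

For a three-person network where person 0 knows person 1 and person 1 knows person 2, BuildAdjacency(3, new[] { (0, 1), (1, 2) }) yields a 3 x 3 matrix that pairs with a 3-row feature array.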

SerializeNetworkSpecificData(BinaryWriter)

Serializes Graph Neural Network-specific data to a binary writer.

protected override void SerializeNetworkSpecificData(BinaryWriter writer)

Parameters

writer BinaryWriter

The BinaryWriter to write the data to.

Remarks

This method writes the Graph Neural Network's specific configuration data to a binary stream. This includes activation function types and any other GNN-specific parameters. This data is needed to reconstruct the GNN when deserializing.

For Beginners: This method saves the special configuration of your GNN.

Think of it like writing down the recipe for your neural network:

  • What activation functions it uses at different stages
  • How its graph-specific components are configured
  • Any other special settings that make this GNN unique

These details are crucial because they define how your GNN processes information, and they need to be saved along with the weights for the model to work correctly when loaded later.

Train(Tensor<T>, Tensor<T>)

Trains the Graph Neural Network on a single input-output pair.

public override void Train(Tensor<T> input, Tensor<T> expectedOutput)

Parameters

input Tensor<T>

The input tensor containing node features and adjacency information.

expectedOutput Tensor<T>

The expected output tensor for the given input.

Remarks

This method trains the network on a single batch of data. For graph neural networks, the input tensor must contain both node features and adjacency information in a structured format. This implementation assumes the input tensor holds the node features in its first half and the adjacency matrix in its second half.

For Beginners: This method teaches the neural network using example data.

For a graph neural network, the training process:

  1. Extracts node features and connection information from the input
  2. Makes a prediction using this graph data
  3. Compares the prediction to the expected output to calculate error
  4. Updates the network's internal values to reduce the error

Over time, with many examples, the network learns to make accurate predictions by understanding how nodes in a graph influence each other.

TrainGraph(Tensor<T>, Tensor<T>, Tensor<T>)

Trains the Graph Neural Network directly on graph data.

public void TrainGraph(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix, Tensor<T> expectedOutput)

Parameters

nodeFeatures Tensor<T>

A tensor containing features for each node in the graph.

adjacencyMatrix Tensor<T>

A tensor representing the connections between nodes in the graph.

expectedOutput Tensor<T>

The expected output tensor for the given graph input.

Remarks

This method provides a more direct interface for training the network on graph data by explicitly accepting node features and adjacency matrix as separate parameters. This is often more intuitive than combining them into a single input tensor.

For Beginners: This is a more straightforward way to train your graph neural network.

Instead of combining node information and connection information into one input, you can provide them separately:

  • nodeFeatures: Information about each node (e.g., user profiles in a social network)
  • adjacencyMatrix: Information about connections (e.g., who is friends with whom)
  • expectedOutput: What the network should predict for this graph

The network then learns to make predictions based on both the node attributes and how nodes are connected to each other.

UpdateParameters(Vector<T>)

Updates the parameters of all layers in the network using the provided parameter vector.

public override void UpdateParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

A vector containing updated parameters for all layers.

Remarks

This method distributes the provided parameter values to each layer in the network. It extracts the appropriate segment of the parameter vector for each layer based on the layer's parameter count.

For Beginners: This method updates all the learned values in the network.

During training, a neural network adjusts its internal values (parameters) to make better predictions. This method:

  1. Takes a long list of new parameter values
  2. Figures out which values belong to which layers
  3. Updates each layer with its corresponding values

Think of it like updating the settings on different parts of a machine to make it work better based on what it has learned.
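The distribution step described above can be sketched as slicing one flat array into per-layer segments. ParameterSplitSketch and SplitByLayer are hypothetical names, and the sketch uses plain double arrays rather than Vector<T>; how the library orders layers and stores their parameters is not shown here.

```csharp
using System;
using System.Collections.Generic;

// Illustrative only: splits a flat parameter array by each layer's parameter count.
public static class ParameterSplitSketch
{
    public static List<double[]> SplitByLayer(double[] parameters, int[] layerParameterCounts)
    {
        var segments = new List<double[]>();
        int offset = 0;

        foreach (int count in layerParameterCounts)
        {
            if (offset + count > parameters.Length)
            {
                throw new ArgumentException("Parameter vector is shorter than the layers require.");
            }

            // Copy this layer's slice out of the flat vector.
            var segment = new double[count];
            Array.Copy(parameters, offset, segment, 0, count);
            segments.Add(segment);

            offset += count; // the next layer starts where this one ended
        }

        return segments;
    }
}
```

Given six values and layer counts of 2, 3, and 1, the first layer receives values 1-2, the second receives values 3-5, and the third receives the last value.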