Class GraphNeuralNetwork<T>
Namespace: AiDotNet.NeuralNetworks
Assembly: AiDotNet.dll
Represents a Graph Neural Network that can process data represented as graphs.
public class GraphNeuralNetwork<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable, IAuxiliaryLossLayer<T>, IDiagnosticsProvider
Type Parameters
T: The numeric type used for calculations, typically float or double.
Inheritance: NeuralNetworkBase<T> → GraphNeuralNetwork<T>
Remarks
A Graph Neural Network (GNN) is designed to work with data structured as graphs, where nodes represent entities and edges represent relationships between these entities. This implementation supports different activation functions for different layers. It can predict outputs from standard tensor inputs (Predict) and from explicit graph inputs given as node features plus an adjacency matrix (PredictGraph).
For Beginners: A Graph Neural Network is a type of neural network that works with connected data.
Think of it like analyzing a social network:
- Each person is a "node" in the graph
- Friendships between people are "edges" connecting the nodes
- People have attributes (like age, location, interests) which are "node features"
GNNs are useful when:
- The relationships between items are as important as the items themselves
- You're working with network-like data (social networks, molecules, road systems)
- You need to make predictions about how nodes influence each other
For example, GNNs can help predict which products a customer might like based on what similar customers have purchased, by analyzing the connections between customers and products.
Constructors
GraphNeuralNetwork(NeuralNetworkArchitecture<T>, ILossFunction<T>?, IActivationFunction<T>?, IActivationFunction<T>?, IActivationFunction<T>?, IActivationFunction<T>?)
Initializes a new instance of the GraphNeuralNetwork<T> class with scalar activation functions.
public GraphNeuralNetwork(NeuralNetworkArchitecture<T> architecture, ILossFunction<T>? lossFunction = null, IActivationFunction<T>? graphConvolutionalActivation = null, IActivationFunction<T>? activationLayerActivation = null, IActivationFunction<T>? finalDenseLayerActivation = null, IActivationFunction<T>? finalActivationLayerActivation = null)
Parameters
architecture (NeuralNetworkArchitecture<T>): The neural network architecture defining the structure of the network.
lossFunction (ILossFunction<T>?): Optional loss function used during training; defaults to null.
graphConvolutionalActivation (IActivationFunction<T>?): The scalar activation function for graph convolutional layers.
activationLayerActivation (IActivationFunction<T>?): The scalar activation function for standard activation layers.
finalDenseLayerActivation (IActivationFunction<T>?): The scalar activation function for the final dense layer.
finalActivationLayerActivation (IActivationFunction<T>?): The scalar activation function for the final activation layer.
Remarks
This constructor creates a graph neural network with the specified architecture and scalar activation functions. Scalar activation functions operate on individual elements rather than entire vectors.
For Beginners: This creates a new graph neural network where the activation functions work on individual numbers separately.
When creating your network, you specify:
- The overall structure (architecture)
- Which activation functions to use at different stages
Scalar activation functions process each value independently, applying the same transformation to each number without considering relationships between values.
GraphNeuralNetwork(NeuralNetworkArchitecture<T>, ILossFunction<T>?, IVectorActivationFunction<T>?, IVectorActivationFunction<T>?, IVectorActivationFunction<T>?, IVectorActivationFunction<T>?)
Initializes a new instance of the GraphNeuralNetwork<T> class with vector activation functions.
public GraphNeuralNetwork(NeuralNetworkArchitecture<T> architecture, ILossFunction<T>? lossFunction = null, IVectorActivationFunction<T>? graphConvolutionalVectorActivation = null, IVectorActivationFunction<T>? activationLayerVectorActivation = null, IVectorActivationFunction<T>? finalDenseLayerVectorActivation = null, IVectorActivationFunction<T>? finalActivationLayerVectorActivation = null)
Parameters
architecture (NeuralNetworkArchitecture<T>): The neural network architecture defining the structure of the network.
lossFunction (ILossFunction<T>?): Optional loss function used during training; defaults to null.
graphConvolutionalVectorActivation (IVectorActivationFunction<T>?): The vector activation function for graph convolutional layers.
activationLayerVectorActivation (IVectorActivationFunction<T>?): The vector activation function for standard activation layers.
finalDenseLayerVectorActivation (IVectorActivationFunction<T>?): The vector activation function for the final dense layer.
finalActivationLayerVectorActivation (IVectorActivationFunction<T>?): The vector activation function for the final activation layer.
Remarks
This constructor creates a graph neural network with the specified architecture and vector activation functions. Vector activation functions operate on entire vectors rather than individual elements.
For Beginners: This creates a new graph neural network where the activation functions work on groups of numbers together.
When creating your network, you specify:
- The overall structure (architecture)
- Which activation functions to use at different stages
Vector activation functions process multiple values as a group, which can help capture relationships between different values.
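The sketch below makes the scalar/vector distinction concrete. It is self-contained C# and does not use the AiDotNet types. ReLU stands in for a scalar activation and softmax for a vector activation. The helper names ApplyScalar and Softmax are hypothetical, and the library's own IActivationFunction<T> and IVectorActivationFunction<T> implementations may differ.

```csharp
using System;
using System.Linq;

// Scalar activation: each element is transformed on its own (here, ReLU).
static double[] ApplyScalar(double[] values, Func<double, double> f) =>
    values.Select(f).ToArray();

// Vector activation: every output depends on all inputs (here, softmax).
static double[] Softmax(double[] values)
{
    double max = values.Max();                                   // subtract max for numerical stability
    double[] exps = values.Select(v => Math.Exp(v - max)).ToArray();
    double sum = exps.Sum();
    return exps.Select(e => e / sum).ToArray();                  // outputs sum to 1
}

double[] input = { -1.0, 0.0, 2.0 };
double[] relu = ApplyScalar(input, x => Math.Max(0.0, x));
double[] soft = Softmax(input);

Console.WriteLine(string.Join(", ", relu)); // 0, 0, 2
```

Changing any single input changes every softmax output, but it changes only the matching ReLU output.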
Properties
AuxiliaryLossWeight
Gets or sets the weight for the graph smoothness auxiliary loss.
public T AuxiliaryLossWeight { get; set; }
Property Value
- T
Remarks
This weight controls how much the graph smoothness regularization contributes to the total loss. The total loss is: main_loss + (auxiliary_weight * smoothness_loss). Typical values range from 0.01 to 0.1.
For Beginners: This controls how much the network should enforce similarity between connected nodes.
The weight determines the balance between:
- Task accuracy (main loss) - making correct predictions
- Graph smoothness (auxiliary loss) - keeping connected nodes similar
Common values:
- 0.05 (default): Balanced smoothness regularization
- 0.01-0.03: Light smoothness enforcement
- 0.08-0.1: Strong smoothness enforcement
Higher values make the network focus more on keeping connected nodes similar, which can help with generalization but may reduce flexibility.
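The sketch below illustrates the weighting formula above by computing the total loss for each of the listed weight ranges. TotalLoss is a hypothetical helper. Skipping the auxiliary term when UseAuxiliaryLoss is false is an assumption about how the flag is applied.

```csharp
using System;

// Hypothetical helper mirroring the documented formula:
// total = main_loss + (auxiliary_weight * smoothness_loss), added only when enabled (assumed gating).
static double TotalLoss(double mainLoss, double smoothnessLoss, double auxiliaryWeight, bool useAuxiliaryLoss) =>
    useAuxiliaryLoss ? mainLoss + auxiliaryWeight * smoothnessLoss : mainLoss;

double light  = TotalLoss(mainLoss: 0.80, smoothnessLoss: 2.0, auxiliaryWeight: 0.01, useAuxiliaryLoss: true);
double normal = TotalLoss(0.80, 2.0, 0.05, true);   // the documented default weight
double strong = TotalLoss(0.80, 2.0, 0.10, true);

Console.WriteLine($"{light} {normal} {strong}");
```

With a main loss of 0.8 and a smoothness loss of 2.0, weights of 0.01, 0.05 and 0.1 give totals of 0.82, 0.90 and 1.00.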
UseAuxiliaryLoss
Gets or sets whether auxiliary loss (graph smoothness regularization) should be used during training.
public bool UseAuxiliaryLoss { get; set; }
Property Value
- bool
Remarks
Graph smoothness regularization encourages connected nodes to have similar representations. This is based on the principle that nodes with edges between them should have similar features, which is a common assumption in many graph-based learning tasks.
For Beginners: Graph smoothness is like encouraging friends to be similar.
In a graph:
- Nodes that are connected (like friends in a social network) should have similar features
- This auxiliary loss penalizes the network when connected nodes have very different representations
- It helps the network learn more meaningful patterns that respect the graph structure
For example:
- In a social network, friends often have similar interests
- In a molecule, bonded atoms influence each other's properties
- In a citation network, papers that cite each other often cover similar topics
Enabling this helps the network learn representations that are consistent with the graph structure.
Methods
ComputeAuxiliaryLoss()
Computes the auxiliary loss for graph smoothness regularization.
public T ComputeAuxiliaryLoss()
Returns
- T
The computed graph smoothness auxiliary loss.
Remarks
This method computes the graph smoothness loss, which encourages connected nodes to have similar representations. The loss is computed as the sum of squared differences between representations of connected nodes, weighted by the adjacency matrix. Formula: L_smooth = Σ_edges ||h_i - h_j||² * A_{ij}
For Beginners: This calculates how different connected nodes are from each other.
Graph smoothness works by:
- Looking at each pair of connected nodes in the graph
- Measuring how different their learned representations are
- Penalizing large differences between connected nodes
- Summing up these penalties across all edges
This helps because:
- It encourages the network to respect the graph structure
- Connected nodes learn similar representations
- The network can generalize better to new graph data
For example, in a social network, friends (connected nodes) will have similar learned features, which makes sense since friends often share interests.
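The formula in the Remarks can be sketched in plain C# using double[,] arrays instead of Tensor<T>. SmoothnessLoss is a hypothetical helper, not the library's implementation. It sums over every (i, j) entry with a nonzero A_ij, so with a symmetric adjacency matrix each undirected edge contributes twice. This page does not say whether ComputeAuxiliaryLoss counts edges once or twice, or whether it normalizes the result.

```csharp
using System;

// Sketch of L_smooth = sum over (i, j) of A[i, j] * ||h_i - h_j||^2.
// h: N x D node representations; a: N x N adjacency matrix.
static double SmoothnessLoss(double[,] h, double[,] a)
{
    int n = h.GetLength(0), d = h.GetLength(1);
    if (a.GetLength(0) != n || a.GetLength(1) != n)
        throw new ArgumentException("Adjacency matrix must be N x N.");

    double loss = 0.0;
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
        {
            if (a[i, j] == 0.0) continue;            // only connected pairs contribute
            double sq = 0.0;
            for (int k = 0; k < d; k++)
            {
                double diff = h[i, k] - h[j, k];
                sq += diff * diff;                   // squared Euclidean distance
            }
            loss += a[i, j] * sq;                    // weighted by the edge value
        }
    return loss;
}

// Path graph 0 - 1 - 2 with a symmetric adjacency matrix.
double[,] adjacency = { { 0, 1, 0 }, { 1, 0, 1 }, { 0, 1, 0 } };
double[,] similar   = { { 1.0, 0.0 }, { 1.1, 0.0 }, { 1.2, 0.0 } };
double[,] different = { { 1.0, 0.0 }, { -1.0, 0.0 }, { 1.0, 0.0 } };

Console.WriteLine(SmoothnessLoss(similar, adjacency) < SmoothnessLoss(different, adjacency)); // True
```

For this path graph, the alternating representations give a loss of 16. The nearly equal ones give roughly 0.04.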
CreateNewInstance()
Creates a new instance of the same type as this neural network.
protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
A new instance of the same neural network type.
Remarks
For Beginners: This creates a blank version of the same type of neural network.
It's used internally by methods like DeepCopy and Clone to create the right type of network before copying the data into it.
DeserializeNetworkSpecificData(BinaryReader)
Deserializes Graph Neural Network-specific data from a binary reader.
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
Parameters
reader (BinaryReader): The BinaryReader to read the data from.
Remarks
This method reads the Graph Neural Network's specific configuration data from a binary stream. This includes activation function types and any other GNN-specific parameters. After reading this data, the GNN's state is fully restored to what it was when saved.
For Beginners: This method loads a previously saved GNN configuration.
Think of it like following a recipe to rebuild your neural network:
- Reading what activation functions were used at different stages
- Setting up the graph-specific components with the right configuration
- Restoring any other special settings that make this GNN unique
This ensures that your loaded model will process information exactly the same way as when you saved it.
GetAuxiliaryLossDiagnostics()
Gets diagnostic information about the graph smoothness auxiliary loss.
public Dictionary<string, string> GetAuxiliaryLossDiagnostics()
Returns
- Dictionary<string, string>
A dictionary containing diagnostic information about auxiliary losses.
Remarks
This method returns detailed diagnostics about the graph smoothness regularization, including the computed smoothness loss, weight applied, and whether the feature is enabled. This information is useful for monitoring training progress and debugging.
For Beginners: This provides information about how graph smoothness regularization is working.
The diagnostics include:
- Total smoothness loss (how different connected nodes are)
- Weight applied to the smoothness loss
- Whether smoothness regularization is enabled
- Whether node representations are being tracked
This helps you:
- Monitor if smoothness regularization is helping training
- Debug issues with graph structure learning
- Understand the impact of smoothness enforcement on learning
You can use this information to adjust the auxiliary loss weight for better results.
GetDiagnostics()
Gets diagnostic information about this component's state and behavior. Overrides GetDiagnostics() to include auxiliary loss diagnostics.
public Dictionary<string, string> GetDiagnostics()
Returns
- Dictionary<string, string>
A dictionary containing diagnostic metrics including both base layer diagnostics and auxiliary loss diagnostics from GetAuxiliaryLossDiagnostics().
GetModelMetadata()
Gets metadata about the Graph Neural Network model.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
A ModelMetadata<T> object containing information about the model.
Remarks
This method returns metadata about the Graph Neural Network, including its model type, number of layers, types of activation functions used, and serialized model data. This information is useful for model management and serialization.
For Beginners: This method provides a summary of your GNN's configuration.
The metadata includes:
- The type of model (GraphNeuralNetwork)
- Details about the network structure (layers, activations)
- Performance metrics
- Data needed to save and load the model
This is useful for:
- Keeping track of different models you've created
- Understanding a model's properties
- Saving the model for later use
InitializeLayers()
Initializes the layers of the neural network based on the provided architecture.
protected override void InitializeLayers()
Remarks
This method sets up the layers of the graph neural network. If the architecture provides specific layers, those are used directly. Otherwise, default layers appropriate for a graph neural network are created.
For Beginners: This method sets up the building blocks of your neural network.
When initializing layers:
- If you provided specific layers in the architecture, those are used
- If not, the network creates a standard set of layers for graph processing
Think of this like assembling the components of the network before training begins.
Predict(Tensor<T>)
Performs a forward pass through the network to make a prediction using a standard input tensor.
public override Tensor<T> Predict(Tensor<T> input)
Parameters
input (Tensor<T>): The input tensor to process.
Returns
- Tensor<T>
The output tensor containing the prediction.
Remarks
This method passes the input data through each layer of the network sequentially. For a GraphNeuralNetwork, this method assumes the input tensor already contains the necessary node features and adjacency information in a preprocessed format, or that the input is for a portion of the network that uses standard layers rather than graph-specific ones.
For Beginners: While GNNs work best with explicit graph data (provided through the PredictGraph method), this method allows the network to process pre-processed graph data or operate on non-graph portions of the network.
It works by:
- Starting with your input data
- Passing it through each layer of the network in sequence
- Letting each layer transform the data based on its specific function
The result is a prediction based on the trained network's understanding of graph patterns.
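The sequential pass described above is function composition. The sketch below models each layer as a Func<double[], double[]>. This is a simplification, because the actual layers operate on Tensor<T> values and carry trained parameters. Forward and the three stand-in layers are hypothetical.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Passes the input through each layer in order; each output feeds the next layer.
static double[] Forward(double[] input, IReadOnlyList<Func<double[], double[]>> layers)
{
    double[] current = input;
    foreach (var layer in layers)
        current = layer(current);
    return current;
}

var layers = new List<Func<double[], double[]>>
{
    x => x.Select(v => 2.0 * v).ToArray(),          // stand-in for a dense layer (scale by 2)
    x => x.Select(v => Math.Max(0.0, v)).ToArray(), // ReLU activation layer
    x => new[] { x.Sum() }                          // stand-in for a final reduction
};

double[] output = Forward(new[] { 1.0, -3.0, 0.5 }, layers);
Console.WriteLine(output[0]); // 3
```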
PredictGraph(Tensor<T>, Tensor<T>)
Performs a forward pass through the network to generate a prediction from graph data.
public Tensor<T> PredictGraph(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix)
Parameters
nodeFeatures (Tensor<T>): A tensor containing features for each node in the graph.
adjacencyMatrix (Tensor<T>): A tensor representing the connections between nodes in the graph.
Returns
- Tensor<T>
A tensor containing the prediction for the graph data.
Remarks
This method processes graph data through the network. It takes node features and an adjacency matrix as input and passes them through graph-specific and standard layers, applying appropriate transformations at each step. The method concludes with hybrid pooling to generate the final output. If graph smoothness regularization is enabled, it caches the node representations for auxiliary loss computation.
For Beginners: This method processes a graph (like a social network) through the neural network to make predictions.
You provide two pieces of information:
- nodeFeatures: Information about each node (like age, interests for each person)
- adjacencyMatrix: Information about how nodes are connected (like who is friends with whom)
The method:
- Passes this information through specialized graph layers
- Also passes it through standard neural network layers
- Combines the results using a technique called "hybrid pooling"
- If smoothness regularization is enabled, saves intermediate representations
- Returns a prediction based on the entire graph structure
This is useful for tasks like predicting which users might become friends, which products a customer might like, or how a molecule might behave.
Exceptions
- ArgumentNullException
Thrown when either nodeFeatures or adjacencyMatrix is null.
- ArgumentException
Thrown when nodeFeatures and adjacencyMatrix have incompatible dimensions.
- InvalidOperationException
Thrown when the network encounters an unsupported layer type or invalid output shape.
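The sketch below shows the general shape of such a pass under explicit assumptions. It runs one graph convolution that sums each node's own features with its neighbors' (A plus self-loops), applies a weight matrix and ReLU, and then pools over nodes. This page does not document the exact propagation rule, the normalization, or the definition of "hybrid pooling" used by GraphNeuralNetwork<T>. Concatenated mean and max pooling is used here purely as a stand-in. GraphConv and MeanMaxPool are hypothetical names. The argument checks mirror the exception list above.

```csharp
using System;

// Hypothetical graph convolution: x is N x F, a is N x N, w is F x H.
static double[,] GraphConv(double[,] x, double[,] a, double[,] w)
{
    if (x is null) throw new ArgumentNullException(nameof(x));
    if (a is null) throw new ArgumentNullException(nameof(a));
    if (w is null) throw new ArgumentNullException(nameof(w));
    int n = x.GetLength(0), f = x.GetLength(1), hDim = w.GetLength(1);
    if (a.GetLength(0) != n || a.GetLength(1) != n)
        throw new ArgumentException("Adjacency matrix must be N x N for N nodes.");
    if (w.GetLength(0) != f)
        throw new ArgumentException("Weight rows must match the feature count.");

    var result = new double[n, hDim];
    for (int i = 0; i < n; i++)
    {
        // Aggregate node i's own features plus its neighbors' (self-loop added).
        var agg = new double[f];
        for (int j = 0; j < n; j++)
        {
            double edge = a[i, j] + (i == j ? 1.0 : 0.0);
            if (edge == 0.0) continue;
            for (int k = 0; k < f; k++) agg[k] += edge * x[j, k];
        }
        // Linear transform followed by ReLU.
        for (int c = 0; c < hDim; c++)
        {
            double s = 0.0;
            for (int k = 0; k < f; k++) s += agg[k] * w[k, c];
            result[i, c] = Math.Max(0.0, s);
        }
    }
    return result;
}

// Assumed pooling: concatenate the per-column mean and max over all nodes.
static double[] MeanMaxPool(double[,] h)
{
    int n = h.GetLength(0), d = h.GetLength(1);
    var pooled = new double[2 * d];
    for (int c = 0; c < d; c++)
    {
        double sum = 0.0, max = double.NegativeInfinity;
        for (int i = 0; i < n; i++) { sum += h[i, c]; max = Math.Max(max, h[i, c]); }
        pooled[c] = sum / n;
        pooled[d + c] = max;
    }
    return pooled;
}

double[,] features = { { 1.0 }, { 2.0 }, { 3.0 } };       // 3 nodes, 1 feature each
double[,] adj = { { 0, 1, 0 }, { 1, 0, 1 }, { 0, 1, 0 } }; // path graph 0 - 1 - 2
double[,] weights = { { 1.0 } };                           // identity transform
double[] graphOutput = MeanMaxPool(GraphConv(features, adj, weights));
Console.WriteLine(string.Join(", ", graphOutput));
```

In this example, the aggregated node values are 3, 6 and 5. The pooled output is therefore their mean (14/3) followed by their maximum (6).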
SerializeNetworkSpecificData(BinaryWriter)
Serializes Graph Neural Network-specific data to a binary writer.
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
Parameters
writer (BinaryWriter): The BinaryWriter to write the data to.
Remarks
This method writes the Graph Neural Network's specific configuration data to a binary stream. This includes activation function types and any other GNN-specific parameters. This data is needed to reconstruct the GNN when deserializing.
For Beginners: This method saves the special configuration of your GNN.
Think of it like writing down the recipe for your neural network:
- What activation functions it uses at different stages
- How its graph-specific components are configured
- Any other special settings that make this GNN unique
These details are crucial because they define how your GNN processes information, and they need to be saved along with the weights for the model to work correctly when loaded later.
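Deserialization must read back exactly what serialization wrote, and the standard BinaryWriter and BinaryReader classes can illustrate this. The layout used here (an entry count followed by one string per activation slot) and the activation names are hypothetical. This page does not describe the actual format that SerializeNetworkSpecificData writes.

```csharp
using System;
using System.IO;
using System.Linq;

// Hypothetical layout: entry count, then one length-prefixed string per activation slot.
static byte[] WriteActivations(string[] activationNames)
{
    using var stream = new MemoryStream();
    using (var writer = new BinaryWriter(stream))
    {
        writer.Write(activationNames.Length);
        foreach (string name in activationNames)
            writer.Write(name);
    }
    return stream.ToArray();                         // ToArray still works after the stream is closed
}

// Reads the fields back in exactly the order and types they were written.
static string[] ReadActivations(byte[] data)
{
    using var reader = new BinaryReader(new MemoryStream(data));
    int count = reader.ReadInt32();
    var names = new string[count];
    for (int i = 0; i < count; i++)
        names[i] = reader.ReadString();
    return names;
}

string[] saved = { "ReLU", "ReLU", "Identity", "Softmax" };
string[] restored = ReadActivations(WriteActivations(saved));
Console.WriteLine(string.Join(", ", restored)); // ReLU, ReLU, Identity, Softmax
```

If the reader and writer disagree on field order or type, every value read after the first mismatch is corrupted.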
Train(Tensor<T>, Tensor<T>)
Trains the Graph Neural Network on a single input-output pair.
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)
Parameters
input (Tensor<T>): The input tensor containing node features and adjacency information.
expectedOutput (Tensor<T>): The expected output tensor for the given input.
Remarks
This method trains the network on a single batch of data. For graph neural networks, the input tensor needs to contain both node features and adjacency information in a structured format. This implementation assumes the input tensor contains node features in the first half and adjacency matrix in the second half.
For Beginners: This method teaches the neural network using example data.
For a graph neural network, the training process:
- Extracts node features and connection information from the input
- Makes a prediction using this graph data
- Compares the prediction to the expected output to calculate error
- Updates the network's internal values to reduce the error
Over time, with many examples, the network learns to make accurate predictions by understanding how nodes in a graph influence each other.
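The half-and-half layout described in the Remarks can be sketched as follows. SplitInput is a hypothetical helper that operates on a flat double[] instead of Tensor<T>. If the split is by element count, both parts must be the same size. For N nodes with F features each, that means N × F = N × N.

```csharp
using System;

// Splits a flat buffer into [node features | adjacency] halves, following the layout in the Remarks.
static (double[] NodeFeatures, double[] Adjacency) SplitInput(double[] input)
{
    if (input.Length % 2 != 0)
        throw new ArgumentException("Input length must be even to split into two halves.");
    int half = input.Length / 2;
    return (input[..half], input[half..]);           // range slicing copies each half
}

// 2 nodes x 2 features, followed by a 2 x 2 adjacency matrix (row-major).
double[] packed = { 0.5, 1.0, 0.2, 0.8, 0, 1, 1, 0 };
var (nodeFeatures, adjacency) = SplitInput(packed);
Console.WriteLine($"{nodeFeatures.Length} feature values, {adjacency.Length} adjacency entries"); // 4 feature values, 4 adjacency entries
```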
TrainGraph(Tensor<T>, Tensor<T>, Tensor<T>)
Trains the Graph Neural Network directly on graph data.
public void TrainGraph(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix, Tensor<T> expectedOutput)
Parameters
nodeFeatures (Tensor<T>): A tensor containing features for each node in the graph.
adjacencyMatrix (Tensor<T>): A tensor representing the connections between nodes in the graph.
expectedOutput (Tensor<T>): The expected output tensor for the given graph input.
Remarks
This method provides a more direct interface for training the network on graph data by explicitly accepting node features and adjacency matrix as separate parameters. This is often more intuitive than combining them into a single input tensor.
For Beginners: This is a more straightforward way to train your graph neural network.
Instead of combining node information and connection information into one input, you can provide them separately:
- nodeFeatures: Information about each node (e.g., user profiles in a social network)
- adjacencyMatrix: Information about connections (e.g., who is friends with whom)
- expectedOutput: What the network should predict for this graph
The network then learns to make predictions based on both the node attributes and how nodes are connected to each other.
UpdateParameters(Vector<T>)
Updates the parameters of all layers in the network using the provided parameter vector.
public override void UpdateParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): A vector containing updated parameters for all layers.
Remarks
This method distributes the provided parameter values to each layer in the network. It extracts the appropriate segment of the parameter vector for each layer based on the layer's parameter count.
For Beginners: This method updates all the learned values in the network.
During training, a neural network adjusts its internal values (parameters) to make better predictions. This method:
- Takes a long list of new parameter values
- Figures out which values belong to which layers
- Updates each layer with its corresponding values
Think of it like updating the settings on different parts of a machine to make it work better based on what it has learned.
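The segment-by-segment distribution described above can be sketched with each layer represented only by its parameter count. SplitParameters is a hypothetical helper. Its length checks are an addition of this sketch, not documented behavior of UpdateParameters.

```csharp
using System;
using System.Collections.Generic;

// Walks the flat vector in layer order, giving each layer the next `count` values.
static List<double[]> SplitParameters(double[] parameters, int[] parameterCounts)
{
    var segments = new List<double[]>();
    int offset = 0;
    foreach (int count in parameterCounts)
    {
        if (offset + count > parameters.Length)
            throw new ArgumentException("Parameter vector is shorter than the layers require.");
        segments.Add(parameters[offset..(offset + count)]); // this layer's slice
        offset += count;
    }
    if (offset != parameters.Length)
        throw new ArgumentException("Parameter vector has more values than the layers use.");
    return segments;
}

// Three layers with 3, 0 (e.g. a parameter-free activation layer) and 2 parameters.
double[] all = { 0.1, 0.2, 0.3, 0.4, 0.5 };
List<double[]> perLayer = SplitParameters(all, new[] { 3, 0, 2 });
Console.WriteLine(string.Join(" ", perLayer.ConvertAll(s => s.Length))); // 3 0 2
```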