
Class GraphIsomorphismNetwork<T>

Namespace
AiDotNet.NeuralNetworks
Assembly
AiDotNet.dll

Represents a Graph Isomorphism Network (GIN), a graph neural network that combines sum aggregation with MLPs to distinguish graph structures as well as the Weisfeiler-Lehman test.

public class GraphIsomorphismNetwork<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable

Type Parameters

T

The numeric type used for calculations, typically float or double.

Inheritance
GraphIsomorphismNetwork<T>
Implements
IFullModel<T, Tensor<T>, Tensor<T>>
IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>
IParameterizable<T, Tensor<T>, Tensor<T>>
ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>
IGradientComputable<T, Tensor<T>, Tensor<T>>

Remarks

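Graph Isomorphism Networks (GIN), introduced by Xu et al. (2019), are provably as powerful as the 1-dimensional Weisfeiler-Lehman (1-WL) graph isomorphism test at distinguishing graph structures, which is the most any message-passing graph neural network can achieve. Each GIN layer sums the features of a node's neighbors, adds the node's own features scaled by (1 + ε), where the epsilon ε can be learned, and passes the result through a multi-layer perceptron (MLP):

h_v^(k) = MLP^(k)((1 + ε^(k)) · h_v^(k-1) + Σ_{u ∈ N(v)} h_u^(k-1))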

For Beginners: GIN is designed to tell different graph structures apart as reliably as a message-passing GNN can.

How it works:

  • Sum neighbor features (preserves multiset information)
  • Combine with self features using learnable weighting (1 + epsilon)
  • Transform through a 2-layer MLP
  • Result: node representations as expressive as any message-passing GNN can produce
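
The update rule can be traced by hand on a three-node path graph. This is a minimal sketch with plain C# arrays, not the AiDotNet implementation: the local functions GinLayer and Mlp, the ReLU activation, and the hand-picked weights are assumptions chosen so the arithmetic is easy to follow.

```csharp
using System;

// Node features for a 3-node path graph 0 - 1 - 2 (one feature per node).
double[][] h = { new[] { 1.0 }, new[] { 2.0 }, new[] { 3.0 } };
// Adjacency matrix without self-loops: GIN adds each node's own features separately.
double[,] adj = { { 0, 1, 0 }, { 1, 0, 1 }, { 0, 1, 0 } };

// Hypothetical 2-layer MLP (1 -> 2 -> 1) with ReLU in between.
// With these weights the MLP computes |s|, which keeps the numbers readable.
double[,] w1 = { { 1.0, -1.0 } };    // [in, hidden]
double[] b1 = { 0.0, 0.0 };
double[,] w2 = { { 1.0 }, { 1.0 } }; // [hidden, out]
double[] b2 = { 0.0 };

double[][] updated = GinLayer(h, adj, epsilon: 0.0);
Console.WriteLine(string.Join(", ", Array.ConvertAll(updated, r => r[0])));

double[][] GinLayer(double[][] x, double[,] a, double epsilon)
{
    int n = x.Length, d = x[0].Length;
    var result = new double[n][];
    for (int v = 0; v < n; v++)
    {
        // (1 + epsilon) * own features ...
        var combined = new double[d];
        for (int f = 0; f < d; f++) combined[f] = (1.0 + epsilon) * x[v][f];
        // ... plus the SUM of all neighbor features.
        for (int u = 0; u < n; u++)
            if (a[v, u] != 0)
                for (int f = 0; f < d; f++) combined[f] += x[u][f];
        result[v] = Mlp(combined);
    }
    return result;
}

double[] Mlp(double[] input)
{
    var hidden = new double[b1.Length];
    for (int j = 0; j < hidden.Length; j++)
    {
        double s = b1[j];
        for (int i = 0; i < input.Length; i++) s += input[i] * w1[i, j];
        hidden[j] = Math.Max(0.0, s); // ReLU
    }
    var output = new double[b2.Length];
    for (int k = 0; k < output.Length; k++)
    {
        double s = b2[k];
        for (int j = 0; j < hidden.Length; j++) s += hidden[j] * w2[j, k];
        output[k] = s;
    }
    return output;
}
```

With epsilon = 0, node 1 combines its own feature 2 with its neighbors' features 1 and 3, giving 6 before the MLP. A learned epsilon changes how strongly that own feature counts.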

Example - Chemical Structure Analysis:

  • Distinguishing molecules with subtle structural differences
  • GIN can distinguish some molecules that GNNs using mean or max aggregation map to the same representation
  • Critical for drug discovery where small differences matter

Key Features:

  • Provably powerful: as expressive as the 1-WL test
  • Learnable epsilon: Optimizes self vs neighbor weighting
  • Two-layer MLP: Provides non-linear transformation capacity
  • Sum aggregation: Preserves structural information

Why GIN is powerful:

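  • Mean and max aggregation lose information: neighbor features {1, 1, 1} and {1} both give mean 1 and max 1
  • Sum aggregation keeps the multiset size: 1 + 1 + 1 = 3, while the sum of {1} is 1
  • An MLP can approximate the injective update function the expressiveness proof requires
  • Learnable epsilon lets the model tune how much a node's own features count relative to its neighbors'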
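
The {1, 1, 1} versus {1} example can be checked directly. This small sketch compares the three aggregators on those two neighbor multisets using LINQ's Sum, Average, and Max; it does not use any AiDotNet types.

```csharp
using System;
using System.Linq;

// The two neighbor multisets from the example.
double[] threeOnes = { 1, 1, 1 };
double[] singleOne = { 1 };

var aggregators = new (string Name, Func<double[], double> Aggregate)[]
{
    ("sum", xs => xs.Sum()),
    ("mean", xs => xs.Average()),
    ("max", xs => xs.Max()),
};

foreach (var (name, aggregate) in aggregators)
{
    double a = aggregate(threeOnes), b = aggregate(singleOne);
    string verdict = a == b ? "identical (structure lost)" : "different (structure kept)";
    Console.WriteLine($"{name,-4}: {a} vs {b} -> {verdict}");
}
```

Only sum separates the two neighborhoods, which is why GIN aggregates with a sum.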

Architecture:

  1. Multiple GIN layers with sum aggregation
  2. Each layer has learnable epsilon and 2-layer MLP
  3. Optional graph-level readout for classification

When to use GIN:

  • When structural differentiation is critical
  • Molecular property prediction
  • Chemical compound classification
  • Any task where graph structure similarity matters

Constructors

GraphIsomorphismNetwork(NeuralNetworkArchitecture<T>, int, int, bool, double, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, ILossFunction<T>?, double)

Initializes a new instance of the GraphIsomorphismNetwork<T> class with the specified architecture.

public GraphIsomorphismNetwork(NeuralNetworkArchitecture<T> architecture, int mlpHiddenDim = 64, int numLayers = 5, bool learnEpsilon = true, double initialEpsilon = 0, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? optimizer = null, ILossFunction<T>? lossFunction = null, double maxGradNorm = 1)

Parameters

architecture NeuralNetworkArchitecture<T>

The neural network architecture defining the structure of the network.

mlpHiddenDim int

Hidden dimension for MLP within GIN layers (default: 64). Used only when creating default layers.

numLayers int

Number of GIN layers (default: 5). Used only when creating default layers.

learnEpsilon bool

Whether to learn epsilon parameter (default: true). Used only when creating default layers.

initialEpsilon double

Initial value for epsilon (default: 0.0). Used only when creating default layers.

optimizer IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>

Optional optimizer for training.

lossFunction ILossFunction<T>

Optional loss function for training.

maxGradNorm double

Maximum gradient norm for clipping (default: 1.0).
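
A common meaning of a maximum gradient norm is clipping by the global L2 norm: if the norm over all gradient entries exceeds the limit, every entry is scaled by the same factor, which keeps the update direction. The sketch below assumes that scheme; this page does not say whether AiDotNet clips by the global norm or per tensor, and ClipByGlobalNorm is a hypothetical helper.

```csharp
using System;
using System.Linq;

double[] ClipByGlobalNorm(double[] gradients, double maxGradNorm)
{
    // L2 norm over all gradient entries.
    double norm = Math.Sqrt(gradients.Sum(g => g * g));
    if (norm <= maxGradNorm || norm == 0.0) return (double[])gradients.Clone();
    double scale = maxGradNorm / norm;
    return gradients.Select(g => g * scale).ToArray();
}

double[] grads = { 3.0, 4.0 };                   // norm is 5, above the limit
double[] clipped = ClipByGlobalNorm(grads, 1.0); // same direction, norm reduced to 1
Console.WriteLine(string.Join(", ", clipped));
```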

Remarks

For Beginners: Creating a GIN network:

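using AiDotNet.NeuralNetworks;
// NeuralNetworkArchitecture<T>, InputType, NeuralNetworkTaskType, NetworkComplexity,
// and Tensor<T> may live in other AiDotNet namespaces; add those using directives too.

// Create architecture for molecular property prediction
var architecture = new NeuralNetworkArchitecture<double>(
    InputType.OneDimensional,
    NeuralNetworkTaskType.MultiClassClassification,
    NetworkComplexity.Simple,
    inputSize: 9,        // Atom features per node
    outputSize: 2);      // Binary classification (two classes)

// Create GIN with default layers (5 GIN layers, MLP hidden size 64)
var gin = new GraphIsomorphismNetwork<double>(architecture);

// Or supply custom layers in the architecture; the default-layer parameters are then ignored
// var ginCustom = new GraphIsomorphismNetwork<double>(architectureWithCustomLayers);

// Train on molecular graphs. These inputs are assumed to be loaded elsewhere:
//   molecules:         List<Tensor<double>>, one [numAtoms, 9] feature tensor per molecule
//   adjacencyMatrices: List<Tensor<double>>, one [numAtoms, numAtoms] matrix per molecule
//   labels:            Tensor<double> holding one label per molecule
gin.TrainOnGraphs(molecules, adjacencyMatrices, labels, epochs: 100);

// Measure accuracy on held-out molecules prepared the same way
double accuracy = gin.EvaluateGraphs(testMolecules, testAdjacencyMatrices, testLabels);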

Properties

InitialEpsilon

Gets the initial epsilon value for GIN layers.

public double InitialEpsilon { get; }

Property Value

double

IsLoRAEnabled

Gets whether LoRA fine-tuning is currently enabled.

public bool IsLoRAEnabled { get; }

Property Value

bool

LearnEpsilon

Gets whether epsilon is learnable in GIN layers.

public bool LearnEpsilon { get; }

Property Value

bool

LoRARank

Gets the LoRA rank when LoRA is enabled.

public int LoRARank { get; }

Property Value

int

MlpHiddenDim

Gets the hidden dimension of the MLP in each GIN layer.

public int MlpHiddenDim { get; }

Property Value

int

NumLayers

Gets the number of GIN layers in the network.

public int NumLayers { get; }

Property Value

int

Methods

Backward(Tensor<T>)

Performs a backward pass through the network to calculate gradients.

public Tensor<T> Backward(Tensor<T> outputGradient)

Parameters

outputGradient Tensor<T>

The gradient of the loss with respect to the network's output.

Returns

Tensor<T>

The gradient of the loss with respect to the network's input.

CreateNewInstance()

Creates a new instance of this network type.

protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()

Returns

IFullModel<T, Tensor<T>, Tensor<T>>

DeserializeNetworkSpecificData(BinaryReader)

Deserializes network-specific data from a binary reader.

protected override void DeserializeNetworkSpecificData(BinaryReader reader)

Parameters

reader BinaryReader

The binary reader to read the network-specific data from.

DisableLoRA()

Disables LoRA fine-tuning and restores original layers.

public void DisableLoRA()

EnableLoRAFineTuning(int, double, bool)

Enables LoRA fine-tuning for parameter-efficient training.

public void EnableLoRAFineTuning(int rank = 8, double alpha = -1, bool freezeBaseLayers = true)

Parameters

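rank int

Rank of the low-rank adapter matrices (default: 8). A smaller rank trains fewer parameters.

alpha double

Scaling factor applied to the LoRA update (default: -1).

freezeBaseLayers bool

Whether to freeze the original layer weights during fine-tuning (default: true).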
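
In the standard LoRA formulation, a layer keeps its frozen weight matrix W and learns two small matrices, A (rank × in) and B (out × rank), so the adapted layer computes W·x + (alpha / rank)·B·A·x. The sketch below assumes that formulation with plain arrays; how AiDotNet interprets the default alpha = -1 is not documented here, so the example passes alpha explicitly. It also shows why LoRA is parameter-efficient: the adapter has rank × (in + out) weights instead of in × out. LoraForward, MatVec, and LoraParams are hypothetical helpers.

```csharp
using System;

// Frozen base weights W: [out = 2, in = 3].
double[,] w = { { 1, 0, 0 }, { 0, 1, 0 } };
// LoRA factors with rank = 1: A is [rank, in], B is [out, rank].
double[,] a = { { 1, 1, 1 } };
double[,] b = { { 0.5 }, { 0 } };
int rank = 1;
double alpha = 2.0;

double[] x = { 1, 2, 3 };
double[] y = LoraForward(w, a, b, x, alpha / rank);
Console.WriteLine(string.Join(", ", y));
Console.WriteLine($"64x64 layer, rank 8: {LoraParams(64, 64, 8)} LoRA weights vs {64 * 64} full weights");

double[] MatVec(double[,] m, double[] v)
{
    var r = new double[m.GetLength(0)];
    for (int i = 0; i < r.Length; i++)
        for (int j = 0; j < v.Length; j++)
            r[i] += m[i, j] * v[j];
    return r;
}

double[] LoraForward(double[,] baseW, double[,] loraA, double[,] loraB, double[] input, double scale)
{
    double[] baseOut = MatVec(baseW, input);               // W x (frozen path)
    double[] update = MatVec(loraB, MatVec(loraA, input)); // B (A x) (trainable low-rank path)
    for (int i = 0; i < baseOut.Length; i++) baseOut[i] += scale * update[i];
    return baseOut;
}

int LoraParams(int inDim, int outDim, int r) => r * (inDim + outDim);
```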

Evaluate(Tensor<T>, Tensor<T>, Tensor<T>, bool[])

Evaluates the model on test data and returns accuracy for node classification.

public double Evaluate(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix, Tensor<T> labels, bool[] testMask)

Parameters

nodeFeatures Tensor<T>

Node feature tensor.

adjacencyMatrix Tensor<T>

Adjacency matrix.

labels Tensor<T>

Ground truth labels.

testMask bool[]

Boolean mask for test nodes.

Returns

double

Classification accuracy on test nodes.
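
Masked evaluation scores only the nodes whose testMask entry is true. The sketch below illustrates that computation with plain arrays, assuming the network produces one score per class for each node and that labels are class indices; AiDotNet's Tensor<T> layouts (for example, one-hot labels) may differ, and MaskedAccuracy is a hypothetical helper, not part of the library.

```csharp
using System;

double MaskedAccuracy(double[][] scores, int[] classLabels, bool[] mask)
{
    int correct = 0, total = 0;
    for (int v = 0; v < scores.Length; v++)
    {
        if (!mask[v]) continue; // skip nodes outside the test split
        int predicted = 0;
        for (int c = 1; c < scores[v].Length; c++)
            if (scores[v][c] > scores[v][predicted]) predicted = c; // argmax over classes
        if (predicted == classLabels[v]) correct++;
        total++;
    }
    return total == 0 ? 0.0 : (double)correct / total;
}

double[][] nodeScores = { new[] { 0.9, 0.1 }, new[] { 0.2, 0.8 }, new[] { 0.6, 0.4 }, new[] { 0.3, 0.7 } };
int[] nodeLabels = { 0, 1, 1, 1 };
bool[] testMask = { false, true, true, true }; // node 0 is a training node
double accuracy = MaskedAccuracy(nodeScores, nodeLabels, testMask);
Console.WriteLine(accuracy);
```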

EvaluateGraphs(List<Tensor<T>>, List<Tensor<T>>, Tensor<T>)

Evaluates the model on test graphs and returns accuracy for graph classification.

public double EvaluateGraphs(List<Tensor<T>> graphs, List<Tensor<T>> adjacencyMatrices, Tensor<T> graphLabels)

Parameters

graphs List<Tensor<T>>

List of graph node feature tensors.

adjacencyMatrices List<Tensor<T>>

List of adjacency matrices.

graphLabels Tensor<T>

Ground truth labels for each graph.

Returns

double

Classification accuracy on test graphs.

Forward(Tensor<T>, Tensor<T>)

Performs a forward pass through the network with node features and adjacency matrix.

public Tensor<T> Forward(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix)

Parameters

nodeFeatures Tensor<T>

Node feature tensor of shape [batchSize, numNodes, inputFeatures] or [numNodes, inputFeatures].

adjacencyMatrix Tensor<T>

Adjacency matrix of shape [batchSize, numNodes, numNodes] or [numNodes, numNodes].

Returns

Tensor<T>

The output tensor after processing through all layers.
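
For an unbatched call, adjacencyMatrix is a [numNodes, numNodes] matrix. The sketch below builds one from an undirected edge list with plain arrays; converting it to Tensor<T> is omitted because the Tensor<T> constructor is not documented on this page. Whether AiDotNet expects self-loops is also not stated; the sketch leaves the diagonal at zero on the assumption that each node's own features enter through the (1 + epsilon) term. BuildAdjacency is a hypothetical helper.

```csharp
using System;

double[,] BuildAdjacency(int numNodes, (int U, int V)[] edges)
{
    var adjacency = new double[numNodes, numNodes];
    foreach (var (u, v) in edges)
    {
        // Undirected graph: set both directions.
        adjacency[u, v] = 1.0;
        adjacency[v, u] = 1.0;
    }
    return adjacency; // diagonal stays 0 (no self-loops)
}

// A 4-node ring, such as a simplified cyclic fragment of a molecule.
var ring = BuildAdjacency(4, new[] { (0, 1), (1, 2), (2, 3), (3, 0) });
for (int i = 0; i < 4; i++)
    Console.WriteLine(string.Join(" ", ring[i, 0], ring[i, 1], ring[i, 2], ring[i, 3]));
```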

GetGraphRepresentation(Tensor<T>, Tensor<T>)

Computes a graph-level representation by combining sum, mean, and max pooling of node features across all layers.

public Tensor<T> GetGraphRepresentation(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix)

Parameters

nodeFeatures Tensor<T>

Node feature tensor.

adjacencyMatrix Tensor<T>

Adjacency matrix.

Returns

Tensor<T>

Graph-level representation combining multiple readout strategies.

Remarks

For Beginners: Hierarchical graph representations:

This method creates rich graph-level embeddings by:

  1. Processing through all GIN layers
  2. At each layer, computing sum, mean, and max of node features
  3. Concatenating all layer representations

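Because each GIN layer extends a node's view by one more hop, early layers capture local structure and later layers capture larger neighborhoods; concatenating all layers keeps both.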
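
The readout steps above can be sketched as follows, assuming each layer's node embeddings are available as a [numNodes, dim] array. For every layer, the sketch appends the sum, mean, and max of each feature across nodes, then concatenates the layers. The ordering AiDotNet uses inside the concatenated vector is not specified here, so treat the layout as illustrative; Readout is a hypothetical helper.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

double[] Readout(IReadOnlyList<double[][]> layerOutputs)
{
    var representation = new List<double>();
    foreach (double[][] nodes in layerOutputs) // one entry per GIN layer
    {
        int dim = nodes[0].Length;
        for (int f = 0; f < dim; f++)
        {
            double[] column = nodes.Select(n => n[f]).ToArray(); // feature f across all nodes
            representation.Add(column.Sum());
            representation.Add(column.Average());
            representation.Add(column.Max());
        }
    }
    return representation.ToArray();
}

// Two layers, three nodes, one feature each (made-up values).
var layer1 = new[] { new[] { 1.0 }, new[] { 2.0 }, new[] { 3.0 } };
var layer2 = new[] { new[] { 3.0 }, new[] { 6.0 }, new[] { 5.0 } };
double[] graphVector = Readout(new[] { layer1, layer2 });
Console.WriteLine(string.Join(", ", graphVector));
```

The output length is numLayers × dim × 3, so the graph-level vector grows with depth.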

GetLoRAParameterCount()

Gets the number of trainable LoRA parameters.

public int GetLoRAParameterCount()

Returns

int

GetModelMetadata()

Gets metadata about this model.

public override ModelMetadata<T> GetModelMetadata()

Returns

ModelMetadata<T>

GetParameterCount()

Gets the total number of trainable parameters in the network.

public int GetParameterCount()

Returns

int
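
For a rough sense of the default network's size, here is a per-layer count under one assumed structure: the MLP is two fully connected layers with biases (in → hidden → out), plus one scalar epsilon when LearnEpsilon is true. This is an assumption, not the documented layer layout; default layers may include further components, and GetParameterCount on an actual instance is authoritative. GinLayerParameters is a hypothetical helper.

```csharp
using System;

int GinLayerParameters(int inputDim, int hiddenDim, int outputDim, bool learnEpsilon)
{
    int firstLinear = inputDim * hiddenDim + hiddenDim;   // weights + biases
    int secondLinear = hiddenDim * outputDim + outputDim; // weights + biases
    return firstLinear + secondLinear + (learnEpsilon ? 1 : 0);
}

// First layer: 9 atom features in, mlpHiddenDim = 64, assumed 64 features out.
Console.WriteLine(GinLayerParameters(9, 64, 64, learnEpsilon: true)); // 4801
```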

GetParameters()

Gets all parameters as a vector.

public override Vector<T> GetParameters()

Returns

Vector<T>

InitializeLayers()

Initializes the layers of the neural network based on the provided architecture.

protected override void InitializeLayers()

MergeLoRAWeights()

Merges LoRA weights into the base layers and disables LoRA mode.

public void MergeLoRAWeights()
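
Merging folds the low-rank update into the base weights, W_merged = W + (alpha / rank)·B·A, so inference afterward costs the same as the original layer. This is a minimal sketch under the standard LoRA formulation, which is an assumption about what MergeLoRAWeights does internally; Merge is a hypothetical helper working on plain arrays.

```csharp
using System;

double[,] Merge(double[,] w, double[,] a, double[,] b, double scale)
{
    int outDim = w.GetLength(0), inDim = w.GetLength(1), rank = a.GetLength(0);
    var merged = (double[,])w.Clone();
    for (int i = 0; i < outDim; i++)
        for (int j = 0; j < inDim; j++)
        {
            double delta = 0;
            for (int k = 0; k < rank; k++) delta += b[i, k] * a[k, j]; // (B A)[i, j]
            merged[i, j] += scale * delta;
        }
    return merged;
}

double[,] baseW = { { 1, 0, 0 }, { 0, 1, 0 } };
double[,] loraA = { { 1, 1, 1 } };    // [rank = 1, in = 3]
double[,] loraB = { { 0.5 }, { 0 } }; // [out = 2, rank = 1]
double[,] mergedW = Merge(baseW, loraA, loraB, scale: 2.0);
// Row 0 becomes [2, 1, 1]; row 1 is unchanged because B's second row is zero.
Console.WriteLine($"{mergedW[0, 0]} {mergedW[0, 1]} {mergedW[0, 2]}");
```

After merging, the adapter matrices are no longer needed, which is consistent with LoRA mode being disabled.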

Predict(Tensor<T>)

Makes a prediction using the trained network.

public override Tensor<T> Predict(Tensor<T> input)

Parameters

input Tensor<T>

Returns

Tensor<T>

SerializeNetworkSpecificData(BinaryWriter)

Serializes network-specific data to a binary writer.

protected override void SerializeNetworkSpecificData(BinaryWriter writer)

Parameters

writer BinaryWriter

The binary writer to write the network-specific data to.
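
Network-specific data plausibly includes the settings exposed as properties (MlpHiddenDim, NumLayers, LearnEpsilon, InitialEpsilon). The sketch below round-trips those four values with BinaryWriter and BinaryReader; the field order and the tuple container are hypothetical, not the format AiDotNet actually writes. The point it illustrates is that the deserializer must read fields in exactly the order the serializer wrote them.

```csharp
using System;
using System.IO;

// Hypothetical container for the GIN-specific settings.
var original = (MlpHiddenDim: 64, NumLayers: 5, LearnEpsilon: true, InitialEpsilon: 0.0);

byte[] bytes;
using (var stream = new MemoryStream())
using (var writer = new BinaryWriter(stream))
{
    writer.Write(original.MlpHiddenDim);   // 4 bytes
    writer.Write(original.NumLayers);      // 4 bytes
    writer.Write(original.LearnEpsilon);   // 1 byte
    writer.Write(original.InitialEpsilon); // 8 bytes
    writer.Flush();
    bytes = stream.ToArray();
}

using var reader = new BinaryReader(new MemoryStream(bytes));
// Tuple elements are evaluated left to right, so fields are read in the written order.
var restored = (MlpHiddenDim: reader.ReadInt32(), NumLayers: reader.ReadInt32(),
                LearnEpsilon: reader.ReadBoolean(), InitialEpsilon: reader.ReadDouble());
Console.WriteLine(restored == original);
```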

SetAdjacencyMatrix(Tensor<T>)

Sets the adjacency matrix for graph operations.

public void SetAdjacencyMatrix(Tensor<T> adjacencyMatrix)

Parameters

adjacencyMatrix Tensor<T>

Train(Tensor<T>, Tensor<T>)

Trains the network on a single batch of data.

public override void Train(Tensor<T> input, Tensor<T> expectedOutput)

Parameters

input Tensor<T>
expectedOutput Tensor<T>

TrainOnGraph(Tensor<T>, Tensor<T>, Tensor<T>, bool[]?, int, double)

Trains the GIN network for node classification on a single graph.

public void TrainOnGraph(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix, Tensor<T> labels, bool[]? trainMask = null, int epochs = 200, double learningRate = 0.01)

Parameters

nodeFeatures Tensor<T>

Node feature tensor of shape [numNodes, inputFeatures].

adjacencyMatrix Tensor<T>

Adjacency matrix of shape [numNodes, numNodes].

labels Tensor<T>

Label tensor for supervised learning.

trainMask bool[]

Optional boolean mask indicating which nodes to train on.

epochs int

Number of training epochs (default: 200).

learningRate double

Learning rate for optimization (default: 0.01).
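
With a training mask, only the selected nodes contribute to the loss, while features still propagate over the whole graph. The sketch below assumes a softmax cross-entropy averaged over masked nodes; the loss function the network actually uses may differ, and the behavior when trainMask is null is not stated on this page. MaskedCrossEntropy is a hypothetical helper.

```csharp
using System;
using System.Linq;

double MaskedCrossEntropy(double[][] nodeLogits, int[] nodeLabels, bool[] mask)
{
    double total = 0;
    int count = 0;
    for (int v = 0; v < nodeLogits.Length; v++)
    {
        if (!mask[v]) continue; // nodes outside the mask contribute nothing
        double[] z = nodeLogits[v];
        double maxLogit = z.Max(); // shift by the max for numerical stability
        double logSumExp = maxLogit + Math.Log(z.Sum(zi => Math.Exp(zi - maxLogit)));
        total += logSumExp - z[nodeLabels[v]]; // -log softmax(z)[label]
        count++;
    }
    return count == 0 ? 0.0 : total / count;
}

double[][] logits = { new[] { 2.0, 0.0 }, new[] { 0.0, 0.0 }, new[] { 0.0, 5.0 } };
int[] labels = { 0, 1, 0 };
bool[] trainMask = { true, true, false }; // node 2 is badly misclassified but excluded
double loss = MaskedCrossEntropy(logits, labels, trainMask);
Console.WriteLine(loss);
```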

TrainOnGraphs(List<Tensor<T>>, List<Tensor<T>>, Tensor<T>, int, double)

Trains the GIN network on multiple graphs for graph classification.

public void TrainOnGraphs(List<Tensor<T>> graphs, List<Tensor<T>> adjacencyMatrices, Tensor<T> graphLabels, int epochs = 100, double learningRate = 0.01)

Parameters

graphs List<Tensor<T>>

List of graph node feature tensors.

adjacencyMatrices List<Tensor<T>>

List of adjacency matrices.

graphLabels Tensor<T>

Labels for each graph.

epochs int

Number of training epochs (default: 100).

learningRate double

Learning rate for optimization (default: 0.01).

Remarks

For Beginners: Graph classification with GIN:

GIN is particularly effective for graph-level tasks like:

  • Molecular property prediction (e.g., toxicity, activity)
  • Social network classification
  • Document classification based on citation graphs

How graph classification works:

  1. Process each graph through GIN layers
  2. Aggregate node features to get graph-level representation
  3. Classify using the aggregated representation

UpdateParameters(Vector<T>)

Updates the parameters of all layers in the network.

public override void UpdateParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

A vector containing all parameters for the network.
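
GetParameters and UpdateParameters act as inverses as long as both walk the layers in the same order and the vector length equals the total parameter count. The sketch below shows that round trip with plain arrays under those assumptions; it does not use the real Vector<T> type or layer classes, and Flatten and Unflatten are hypothetical helpers.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

double[] Flatten(List<double[]> layerParameters) =>
    layerParameters.SelectMany(layer => layer).ToArray();

void Unflatten(List<double[]> layerParameters, double[] vector)
{
    int expected = layerParameters.Sum(layer => layer.Length);
    if (vector.Length != expected)
        throw new ArgumentException($"Expected {expected} parameters, got {vector.Length}.");
    int offset = 0;
    foreach (double[] p in layerParameters)
    {
        Array.Copy(vector, offset, p, 0, p.Length); // overwrite this layer's slice
        offset += p.Length;
    }
}

var layers = new List<double[]> { new[] { 0.1, 0.2 }, new[] { 0.3, 0.4, 0.5 } };
double[] flat = Flatten(layers);
Unflatten(layers, flat.Select(x => x * 2).ToArray()); // e.g. write back updated values
Console.WriteLine(string.Join(", ", layers.SelectMany(layer => layer)));
```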