Class GraphIsomorphismNetwork<T>
- Namespace
- AiDotNet.NeuralNetworks
- Assembly
- AiDotNet.dll
Represents a Graph Isomorphism Network (GIN) for powerful graph representation learning.
public class GraphIsomorphismNetwork<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable
Type Parameters
T: The numeric type used for calculations, typically float or double.
- Inheritance
- object → NeuralNetworkBase<T> → GraphIsomorphismNetwork<T>
Remarks
Graph Isomorphism Networks (GIN), introduced by Xu et al. (2019), are provably as powerful as the Weisfeiler-Lehman (WL) graph isomorphism test for distinguishing graph structures. GIN uses sum aggregation with a learnable epsilon parameter and applies a multi-layer perceptron (MLP) for powerful feature transformation.
For Beginners: GIN is designed for tasks where telling graph structures apart is critical.
How it works:
- Sum neighbor features (preserves multiset information)
- Combine with self features using learnable weighting (1 + epsilon)
- Transform through a 2-layer MLP
- Result: maximally expressive graph representation (see the update rule below)
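Concretely, layer k of a GIN updates node v's features as (Xu et al., 2019):

$$h_v^{(k)} = \mathrm{MLP}^{(k)}\!\left(\left(1 + \epsilon^{(k)}\right) \cdot h_v^{(k-1)} + \sum_{u \in \mathcal{N}(v)} h_u^{(k-1)}\right)$$

where N(v) is the set of neighbors of node v and epsilon^(k) is the (optionally learnable) self-weighting parameter for layer k.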
Example - Chemical Structure Analysis:
- Distinguishing molecules with subtle structural differences
- GIN can tell apart molecules that simpler GNNs confuse
- Critical for drug discovery where small differences matter
Key Features:
- Provably powerful: As expressive as WL test
- Learnable epsilon: Optimizes self vs neighbor weighting
- Two-layer MLP: Provides non-linear transformation capacity
- Sum aggregation: Preserves structural information
Why GIN is powerful:
- Mean/max pooling loses information (e.g., can't distinguish {1,1,1} from {1})
- Sum aggregation preserves the multiset: {1,1,1} != {1} (see the example after this list)
- MLP can approximate complex functions
- Learnable epsilon finds optimal self-weighting
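A quick numeric illustration of the multiset point above, in plain C# (no library dependencies):

// Requires: using System.Linq;
double[] neighborsA = { 1.0, 1.0, 1.0 }; // multiset {1, 1, 1}
double[] neighborsB = { 1.0 };           // multiset {1}

double meanA = neighborsA.Average(); // 1.0
double meanB = neighborsB.Average(); // 1.0 -> mean cannot distinguish the two

double sumA = neighborsA.Sum(); // 3.0
double sumB = neighborsB.Sum(); // 1.0 -> sum keeps them apart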
Architecture:
- Multiple GIN layers with sum aggregation
- Each layer has learnable epsilon and 2-layer MLP
- Optional graph-level readout for classification
When to use GIN:
- When structural differentiation is critical
- Molecular property prediction
- Chemical compound classification
- Any task where graph structure similarity matters
Constructors
GraphIsomorphismNetwork(NeuralNetworkArchitecture<T>, int, int, bool, double, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, ILossFunction<T>?, double)
Initializes a new instance of the GraphIsomorphismNetwork<T> class with specified architecture.
public GraphIsomorphismNetwork(NeuralNetworkArchitecture<T> architecture, int mlpHiddenDim = 64, int numLayers = 5, bool learnEpsilon = true, double initialEpsilon = 0, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? optimizer = null, ILossFunction<T>? lossFunction = null, double maxGradNorm = 1)
Parameters
architecture (NeuralNetworkArchitecture<T>): The neural network architecture defining the structure of the network.
mlpHiddenDim (int): Hidden dimension for the MLP within GIN layers (default: 64). Used only when creating default layers.
numLayers (int): Number of GIN layers (default: 5). Used only when creating default layers.
learnEpsilon (bool): Whether to learn the epsilon parameter (default: true). Used only when creating default layers.
initialEpsilon (double): Initial value for epsilon (default: 0.0). Used only when creating default layers.
optimizer (IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>): Optional optimizer for training.
lossFunction (ILossFunction<T>): Optional loss function for training.
maxGradNorm (double): Maximum gradient norm for clipping (default: 1.0).
Remarks
For Beginners: Creating a GIN network:

// Create architecture for molecular property prediction
var architecture = new NeuralNetworkArchitecture<double>(
    InputType.OneDimensional,
    NeuralNetworkTaskType.MultiClassClassification,
    NetworkComplexity.Simple,
    inputSize: 9,   // Atom features
    outputSize: 2); // Two classes

// Create GIN with default layers
var gin = new GraphIsomorphismNetwork<double>(architecture);

// Or supply custom layers by adding them to the architecture
var ginCustom = new GraphIsomorphismNetwork<double>(architectureWithCustomLayers);

// Train on molecular graphs
gin.TrainOnGraphs(molecules, adjacencyMatrices, labels, epochs: 100);
Properties
InitialEpsilon
Gets the initial epsilon value for GIN layers.
public double InitialEpsilon { get; }
Property Value
- double
IsLoRAEnabled
Gets whether LoRA fine-tuning is currently enabled.
public bool IsLoRAEnabled { get; }
Property Value
- bool
LearnEpsilon
Gets whether epsilon is learnable in GIN layers.
public bool LearnEpsilon { get; }
Property Value
- bool
LoRARank
Gets the LoRA rank when LoRA is enabled.
public int LoRARank { get; }
Property Value
- int
MlpHiddenDim
Gets the hidden dimension size for MLP in each layer.
public int MlpHiddenDim { get; }
Property Value
- int
NumLayers
Gets the number of GIN layers in the network.
public int NumLayers { get; }
Property Value
- int
Methods
Backward(Tensor<T>)
Performs a backward pass through the network to calculate gradients.
public Tensor<T> Backward(Tensor<T> outputGradient)
Parameters
outputGradient (Tensor<T>): The gradient of the loss with respect to the network's output.
Returns
- Tensor<T>
The gradient of the loss with respect to the network's input.
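A minimal sketch of how Forward and Backward pair up in a manual gradient step. Here nodeFeatures, adjacency, and labels are assumed to be prepared elsewhere, and ComputeLossGradient is a hypothetical helper standing in for the gradient of your loss function:

// Forward pass (caches the activations Backward needs)
var output = gin.Forward(nodeFeatures, adjacency);

// Gradient of the loss with respect to the output (hypothetical helper)
Tensor<double> outputGradient = ComputeLossGradient(output, labels);

// Backpropagate to get the gradient with respect to the input
Tensor<double> inputGradient = gin.Backward(outputGradient);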
CreateNewInstance()
Creates a new instance of this network type.
protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
DeserializeNetworkSpecificData(BinaryReader)
Deserializes network-specific data from a binary reader.
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
Parameters
reader (BinaryReader): The reader from which network-specific data is read.
DisableLoRA()
Disables LoRA fine-tuning and restores original layers.
public void DisableLoRA()
EnableLoRAFineTuning(int, double, bool)
Enables LoRA fine-tuning for parameter-efficient training.
public void EnableLoRAFineTuning(int rank = 8, double alpha = -1, bool freezeBaseLayers = true)
Parameters
rank (int): The rank of the LoRA decomposition (default: 8).
alpha (double): The LoRA scaling factor (default: -1).
freezeBaseLayers (bool): Whether to freeze the base layers so only LoRA parameters train (default: true).
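A sketch of a typical LoRA workflow using the methods on this class (assumes a trained gin instance; fineTuneGraphs, fineTuneAdjacency, and fineTuneLabels are placeholder names for data prepared elsewhere):

// Enable parameter-efficient fine-tuning with rank-8 adapters
gin.EnableLoRAFineTuning(rank: 8);
Console.WriteLine($"Trainable LoRA parameters: {gin.GetLoRAParameterCount()}");

// Fine-tune; with freezeBaseLayers left at its default, only the adapters update
gin.TrainOnGraphs(fineTuneGraphs, fineTuneAdjacency, fineTuneLabels, epochs: 20);

// Fold the adapters into the base weights for deployment
gin.MergeLoRAWeights();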
Evaluate(Tensor<T>, Tensor<T>, Tensor<T>, bool[])
Evaluates the model on test data and returns accuracy for node classification.
public double Evaluate(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix, Tensor<T> labels, bool[] testMask)
Parameters
nodeFeatures (Tensor<T>): Node feature tensor.
adjacencyMatrix (Tensor<T>): Adjacency matrix.
labels (Tensor<T>): Ground truth labels.
testMask (bool[]): Boolean mask for test nodes.
Returns
- double
Classification accuracy on test nodes.
EvaluateGraphs(List<Tensor<T>>, List<Tensor<T>>, Tensor<T>)
Evaluates the model on test graphs and returns accuracy for graph classification.
public double EvaluateGraphs(List<Tensor<T>> graphs, List<Tensor<T>> adjacencyMatrices, Tensor<T> graphLabels)
Parameters
graphs (List<Tensor<T>>): List of graph node feature tensors.
adjacencyMatrices (List<Tensor<T>>): List of adjacency matrices.
graphLabels (Tensor<T>): Ground truth labels for each graph.
Returns
- double
Classification accuracy on test graphs.
Forward(Tensor<T>, Tensor<T>)
Performs a forward pass through the network with node features and adjacency matrix.
public Tensor<T> Forward(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix)
Parameters
nodeFeatures (Tensor<T>): Node feature tensor of shape [batchSize, numNodes, inputFeatures] or [numNodes, inputFeatures].
adjacencyMatrix (Tensor<T>): Adjacency matrix of shape [batchSize, numNodes, numNodes] or [numNodes, numNodes].
Returns
- Tensor<T>
The output tensor after processing through all layers.
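A minimal usage sketch; tensor construction is omitted, so assume nodeFeatures ([numNodes, inputFeatures]) and adjacency ([numNodes, numNodes]) already exist:

// Run a single graph through all GIN layers
Tensor<double> output = gin.Forward(nodeFeatures, adjacency);
// output holds the result of processing through all layers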
GetGraphRepresentation(Tensor<T>, Tensor<T>)
Gets graph-level representations using sum, mean, and max pooling combined.
public Tensor<T> GetGraphRepresentation(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix)
Parameters
nodeFeatures (Tensor<T>): Node feature tensor.
adjacencyMatrix (Tensor<T>): Adjacency matrix.
Returns
- Tensor<T>
Graph-level representation combining multiple readout strategies.
Remarks
For Beginners: Hierarchical graph representations:
This method creates rich graph-level embeddings by:
- Processing through all GIN layers
- At each layer, computing sum, mean, and max of node features
- Concatenating all layer representations
This captures both local (early layers) and global (later layers) structure.
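For example, whole-graph embeddings can be compared across molecules (a sketch; the feature and adjacency tensors are assumed to be prepared elsewhere):

var embeddingA = gin.GetGraphRepresentation(moleculeAFeatures, moleculeAAdjacency);
var embeddingB = gin.GetGraphRepresentation(moleculeBFeatures, moleculeBAdjacency);
// Compare the embeddings (e.g., cosine similarity) to estimate structural similarity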
GetLoRAParameterCount()
Gets the number of trainable LoRA parameters.
public int GetLoRAParameterCount()
Returns
- int
GetModelMetadata()
Gets metadata about this model.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
GetParameterCount()
Gets the total number of trainable parameters in the network.
public int GetParameterCount()
Returns
- int
GetParameters()
Gets all parameters as a vector.
public override Vector<T> GetParameters()
Returns
- Vector<T>
InitializeLayers()
Initializes the layers of the neural network based on the provided architecture.
protected override void InitializeLayers()
MergeLoRAWeights()
Merges LoRA weights into the base layers and disables LoRA mode.
public void MergeLoRAWeights()
Predict(Tensor<T>)
Makes a prediction using the trained network.
public override Tensor<T> Predict(Tensor<T> input)
Parameters
input (Tensor<T>): The input tensor to make a prediction for.
Returns
- Tensor<T>
SerializeNetworkSpecificData(BinaryWriter)
Serializes network-specific data to a binary writer.
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
Parameters
writer (BinaryWriter): The writer to which network-specific data is written.
SetAdjacencyMatrix(Tensor<T>)
Sets the adjacency matrix for graph operations.
public void SetAdjacencyMatrix(Tensor<T> adjacencyMatrix)
Parameters
adjacencyMatrix (Tensor<T>): The adjacency matrix to use for subsequent graph operations.
Train(Tensor<T>, Tensor<T>)
Trains the network on a single batch of data.
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)
Parameters
input (Tensor<T>): The input tensor for the batch.
expectedOutput (Tensor<T>): The expected output tensor for the batch.
TrainOnGraph(Tensor<T>, Tensor<T>, Tensor<T>, bool[]?, int, double)
Trains the GIN network on a single graph with node classification.
public void TrainOnGraph(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix, Tensor<T> labels, bool[]? trainMask = null, int epochs = 200, double learningRate = 0.01)
Parameters
nodeFeatures (Tensor<T>): Node feature tensor of shape [numNodes, inputFeatures].
adjacencyMatrix (Tensor<T>): Adjacency matrix of shape [numNodes, numNodes].
labels (Tensor<T>): Label tensor for supervised learning.
trainMask (bool[]): Optional boolean mask indicating which nodes to train on.
epochs (int): Number of training epochs (default: 200).
learningRate (double): Learning rate for optimization (default: 0.01).
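A sketch of node classification on a single graph, pairing TrainOnGraph with Evaluate. The tensors and the complementary trainMask/testMask arrays are assumed to be prepared elsewhere:

// Train only on nodes where trainMask is true
gin.TrainOnGraph(nodeFeatures, adjacency, labels, trainMask, epochs: 200, learningRate: 0.01);

// Measure accuracy on the held-out nodes
double accuracy = gin.Evaluate(nodeFeatures, adjacency, labels, testMask);
Console.WriteLine($"Test accuracy: {accuracy:P1}");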
TrainOnGraphs(List<Tensor<T>>, List<Tensor<T>>, Tensor<T>, int, double)
Trains the GIN network on multiple graphs for graph classification.
public void TrainOnGraphs(List<Tensor<T>> graphs, List<Tensor<T>> adjacencyMatrices, Tensor<T> graphLabels, int epochs = 100, double learningRate = 0.01)
Parameters
graphs (List<Tensor<T>>): List of graph node feature tensors.
adjacencyMatrices (List<Tensor<T>>): List of adjacency matrices.
graphLabels (Tensor<T>): Labels for each graph.
epochs (int): Number of training epochs (default: 100).
learningRate (double): Learning rate for optimization (default: 0.01).
Remarks
For Beginners: Graph classification with GIN:
GIN is particularly effective for graph-level tasks like:
- Molecular property prediction (e.g., toxicity, activity)
- Social network classification
- Document classification based on citation graphs
How graph classification works:
- Process each graph through GIN layers
- Aggregate node features to get graph-level representation
- Classify using the aggregated representation (see the sketch below)
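A sketch of that loop end to end, pairing TrainOnGraphs with EvaluateGraphs (the per-graph tensor lists are assumed to be prepared elsewhere):

// One feature tensor and one adjacency matrix per training graph
gin.TrainOnGraphs(trainGraphs, trainAdjacency, trainLabels, epochs: 100, learningRate: 0.01);

// Accuracy on a held-out set of graphs
double testAccuracy = gin.EvaluateGraphs(testGraphs, testAdjacency, testLabels);
Console.WriteLine($"Graph classification accuracy: {testAccuracy:P1}");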
UpdateParameters(Vector<T>)
Updates the parameters of all layers in the network.
public override void UpdateParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): A vector containing all parameters for the network.