
Class GraphClassificationModel<T>

Namespace: AiDotNet.NeuralNetworks.Tasks.Graph
Assembly: AiDotNet.dll

Implements a complete neural network model for graph classification tasks.

public class GraphClassificationModel<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable

Type Parameters

T

The numeric type used for calculations, typically float or double.

Inheritance
NeuralNetworkBase<T> -> GraphClassificationModel<T>
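Implements
INeuralNetworkModel<T>
INeuralNetwork<T>
IFullModel<T, Tensor<T>, Tensor<T>>
IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>
IModelSerializer
ICheckpointableModel
IParameterizable<T, Tensor<T>, Tensor<T>>
IFeatureAware
IFeatureImportance<T>
ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>
IGradientComputable<T, Tensor<T>, Tensor<T>>
IJitCompilable<T>
IInterpretableModel<T>
IInputGradientComputable<T>
IDisposable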

Remarks

Graph classification assigns labels to entire graphs based on their structure and features. The model consists of:

1. Node-level processing (GNN layers)
2. Graph-level pooling (aggregate node embeddings)
3. Classification head (fully connected layers)

For Beginners: This model classifies whole graphs.

Architecture pipeline:

Step 1: Node Encoding
Input: Graph with node features
Process: Stack of GNN layers
Output: Node embeddings [num_nodes, hidden_dim]

Step 2: Graph Pooling (KEY STEP!)
Input: Node embeddings from variable-sized graph
Process: Aggregate to fixed-size representation
Output: Graph embedding [hidden_dim]

Step 3: Classification
Input: Graph embedding [hidden_dim]
Process: Fully connected layers
Output: Class probabilities [num_classes]
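The three steps can be sketched in a few lines of plain C#. This is not the AiDotNet implementation: it uses double[] arrays instead of Tensor<T>, a simplified layer that averages each node with its neighbors (an assumption standing in for whatever graph convolution the library uses), and hand-picked hypothetical weights (wGnn, wOut) rather than trained ones.

```csharp
using System;
using System.Linq;

// Toy graph: 4 nodes on a path 0-1-2-3, with 2 input features per node.
double[][] x =
{
    new[] { 1.0, 0.0 },
    new[] { 0.0, 1.0 },
    new[] { 1.0, 1.0 },
    new[] { 0.5, 0.5 },
};
int[][] neighbors = { new[] { 1 }, new[] { 0, 2 }, new[] { 1, 3 }, new[] { 2 } };

// Hypothetical weights for illustration only (a trained model learns these).
double[,] wGnn = { { 0.5, -0.2, 0.1 }, { 0.3, 0.8, -0.4 } };    // 2 -> 3 (hidden_dim = 3)
double[,] wOut = { { 1.0, -1.0 }, { -0.5, 0.5 }, { 0.2, 0.3 } }; // 3 -> 2 (num_classes = 2)

// Step 1: one message-passing layer: average self + neighbors, project, then ReLU.
double[][] h = x.Select((_, i) =>
{
    var group = neighbors[i].Append(i).ToArray();
    double[] avg = Enumerable.Range(0, x[i].Length)
        .Select(f => group.Average(j => x[j][f])).ToArray();
    return Relu(MatVec(avg, wGnn));
}).ToArray();                                            // shape [num_nodes, hidden_dim]

// Step 2: mean pooling collapses any number of nodes into one vector.
double[] g = Enumerable.Range(0, h[0].Length)
    .Select(f => h.Average(row => row[f])).ToArray();    // shape [hidden_dim]

// Step 3: linear classifier followed by softmax gives class probabilities.
double[] probs = Softmax(MatVec(g, wOut));               // shape [num_classes]
Console.WriteLine(string.Join(", ", probs.Select(p => p.ToString("F3"))));

static double[] MatVec(double[] v, double[,] w) =>
    Enumerable.Range(0, w.GetLength(1))
        .Select(c => Enumerable.Range(0, v.Length).Sum(r => v[r] * w[r, c])).ToArray();

static double[] Relu(double[] v) => v.Select(a => Math.Max(0.0, a)).ToArray();

static double[] Softmax(double[] z)
{
    double max = z.Max();                                // subtract max for numerical stability
    double[] e = z.Select(a => Math.Exp(a - max)).ToArray();
    double sum = e.Sum();
    return e.Select(a => a / sum).ToArray();
}
```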

Why pooling is crucial:

  • Graphs have variable sizes (10 nodes vs 100 nodes)
  • The classifier needs a fixed-size input
  • It is like summarizing a book of any length into a review of fixed length (say, 200 words)
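Pooling makes the output size depend only on the embedding dimension, not on the node count. The sketch below shows three common pooling operators on plain arrays. Only GraphPooling.Mean is confirmed by the constructor's default; whether the GraphPooling enum also offers max and sum pooling is an assumption, and Pool and MakeGraph are hypothetical helpers.

```csharp
using System;
using System.Linq;

// Pool a [num_nodes, dim] matrix of node embeddings into one [dim] vector.
static double[] Pool(double[][] nodes, string mode)
{
    int dim = nodes[0].Length;
    return Enumerable.Range(0, dim).Select(f =>
    {
        var column = nodes.Select(row => row[f]);   // feature f across all nodes
        return mode switch
        {
            "mean" => column.Average(),
            "max" => column.Max(),
            "sum" => column.Sum(),
            _ => throw new ArgumentException($"unknown pooling mode: {mode}"),
        };
    }).ToArray();
}

// Deterministic toy embeddings: node i, feature f has value ((i + f) % 7) / 7.0.
static double[][] MakeGraph(int numNodes, int dim) =>
    Enumerable.Range(0, numNodes)
        .Select(i => Enumerable.Range(0, dim).Select(f => (i + f) % 7 / 7.0).ToArray())
        .ToArray();

double[][] small = MakeGraph(10, 4);    // 10 nodes, 4 features each
double[][] large = MakeGraph(100, 4);   // 100 nodes, 4 features each

foreach (string poolMode in new[] { "mean", "max", "sum" })
{
    Console.WriteLine($"{poolMode}: {Pool(small, poolMode).Length} values for 10 nodes, " +
                      $"{Pool(large, poolMode).Length} values for 100 nodes");
}
```

Mean pooling is insensitive to graph size, sum pooling grows with the number of nodes, and max pooling keeps only the strongest activation per feature; which behavior suits a task depends on whether graph size itself carries signal.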

Example: Molecular toxicity prediction

Molecule (graph) -> GNN layers -> Molecule embedding -> Classifier -> Toxic? (Yes/No)

Small molecule (10 atoms):
  10 nodes -> GNN -> 10 embeddings -> Pool -> 1 graph embedding -> Classify

Large molecule (50 atoms):
  50 nodes -> GNN -> 50 embeddings -> Pool -> 1 graph embedding -> Classify

Both produce a graph embedding of the same size despite their different input sizes!

Constructors

GraphClassificationModel(NeuralNetworkArchitecture<T>, int, int, int, double, GraphPooling, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, ILossFunction<T>?, double)

Initializes a new instance of the GraphClassificationModel<T> class.

public GraphClassificationModel(NeuralNetworkArchitecture<T> architecture, int hiddenDim = 64, int embeddingDim = 128, int numGnnLayers = 3, double dropoutRate = 0.5, GraphClassificationModel<T>.GraphPooling poolingType = GraphPooling.Mean, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? optimizer = null, ILossFunction<T>? lossFunction = null, double maxGradNorm = 1)

Parameters

architecture NeuralNetworkArchitecture<T>

The neural network architecture defining input/output sizes and layers.

hiddenDim int

Hidden dimension for intermediate layers (default: 64).

embeddingDim int

Dimension of graph embedding after pooling (default: 128).

numGnnLayers int

Number of graph convolutional layers (default: 3).
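Each GNN layer mixes a node's representation with those of its direct neighbors, so numGnnLayers controls how far information travels: after k layers, a node's embedding can depend on nodes up to k hops away. The sketch below, assuming standard one-hop aggregation per layer, tracks that receptive field on a six-node path graph; ReceptiveField is a hypothetical helper.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Path graph 0-1-2-3-4-5 as an adjacency list.
int[][] adj =
{
    new[] { 1 }, new[] { 0, 2 }, new[] { 1, 3 },
    new[] { 2, 4 }, new[] { 3, 5 }, new[] { 4 },
};

// Nodes whose input features can influence `target` after `layers` rounds of
// neighbor aggregation. Each round expands the set by exactly one hop.
static HashSet<int> ReceptiveField(int[][] graph, int target, int layers)
{
    var field = new HashSet<int> { target };
    for (int l = 0; l < layers; l++)
    {
        foreach (int node in field.ToList())   // snapshot, so one hop per round
            foreach (int nb in graph[node])
                field.Add(nb);
    }
    return field;
}

for (int k = 1; k <= 3; k++)
{
    var reached = ReceptiveField(adj, 0, k).OrderBy(n => n);
    Console.WriteLine($"{k} layer(s): node 0 depends on nodes {string.Join(", ", reached)}");
}
```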

dropoutRate double

Dropout rate for regularization (default: 0.5).
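This page does not say which dropout variant the model uses. A minimal sketch of inverted dropout, the common formulation: during training each activation is zeroed with probability dropoutRate and the survivors are scaled by 1 / (1 - dropoutRate), so no rescaling is needed at inference time. The Dropout helper is hypothetical.

```csharp
using System;
using System.Linq;

// Inverted dropout: in training, zero each activation with probability `rate` and
// scale survivors by 1 / (1 - rate) so the expected value is unchanged.
// In evaluation mode the input passes through untouched.
static double[] Dropout(double[] activations, double rate, bool training, Random rng)
{
    if (!training || rate <= 0.0) return (double[])activations.Clone();
    double keep = 1.0 - rate;
    return activations.Select(a => rng.NextDouble() < rate ? 0.0 : a / keep).ToArray();
}

var random = new Random(0);
double[] hidden = Enumerable.Repeat(1.0, 10).ToArray();
double[] dropped = Dropout(hidden, 0.5, training: true, random);
Console.WriteLine(string.Join(" ", dropped));   // every entry is either 0 or 2
```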

poolingType GraphClassificationModel<T>.GraphPooling

Method used to pool node embeddings into a single graph embedding (default: GraphPooling.Mean).

optimizer IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>

Optional optimizer for training.

lossFunction ILossFunction<T>

Optional loss function for training.

maxGradNorm double

Maximum gradient norm for clipping (default: 1.0).
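Clipping by norm rescales a gradient whose L2 norm exceeds maxGradNorm so that its norm equals the limit, leaving its direction unchanged. A minimal sketch on a plain array; whether the model clips one global norm across all parameters (as below) or each layer separately is an assumption, and ClipByNorm is a hypothetical helper.

```csharp
using System;
using System.Linq;

// Clip by L2 norm: if ||g|| > maxNorm, rescale g so that ||g|| == maxNorm.
static double[] ClipByNorm(double[] grad, double maxNorm)
{
    double norm = Math.Sqrt(grad.Sum(v => v * v));
    if (norm <= maxNorm || norm == 0.0) return (double[])grad.Clone();
    double scale = maxNorm / norm;
    return grad.Select(v => v * scale).ToArray();
}

double[] gradient = { 3.0, 4.0 };               // L2 norm = 5
double[] clipped = ClipByNorm(gradient, 1.0);   // approximately [0.6, 0.8], norm 1
Console.WriteLine(string.Join(", ", clipped));
```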

Remarks

For Beginners: Creating a graph classification model:

```csharp
// Create architecture for molecular property prediction
var architecture = new NeuralNetworkArchitecture<double>(
    InputType.OneDimensional,
    NeuralNetworkTaskType.MultiClassClassification,
    NetworkComplexity.Simple,
    inputSize: 9,      // 9 features per atom (node)
    outputSize: 2);    // Binary classification (toxic/not toxic)

// Create model with default layers
var model = new GraphClassificationModel<double>(architecture);

// Train on a graph classification task
// (`task` is a GraphClassificationTask<double> holding the train/validation/test graphs)
var history = model.TrainOnTask(task, epochs: 100, learningRate: 0.001);

// Evaluate on the task's test graphs
double testScore = model.EvaluateOnTask(task);
```

Properties

DropoutRate

Gets the dropout rate for regularization.

public double DropoutRate { get; }

Property Value

double

EmbeddingDim

Gets the graph embedding dimension after pooling.

public int EmbeddingDim { get; }

Property Value

int

HiddenDim

Gets the hidden dimension size.

public int HiddenDim { get; }

Property Value

int

InputFeatures

Gets the number of input features per node.

public int InputFeatures { get; }

Property Value

int

NumClasses

Gets the number of output classes.

public int NumClasses { get; }

Property Value

int

NumGnnLayers

Gets the number of GNN layers.

public int NumGnnLayers { get; }

Property Value

int

Methods

Backward(Tensor<T>)

Performs a backward pass through the network.

public Tensor<T> Backward(Tensor<T> outputGradient)

Parameters

outputGradient Tensor<T>

Gradient of the loss with respect to the network output.

Returns

Tensor<T>

Gradient of the loss with respect to the network input.
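Backward propagates the output gradient back through the classification head, the pooling step, and the GNN layers. As a sketch of one link in that chain, assuming mean pooling (the constructor default): since g = (1/N) * sum_i h_i, each of the N node embeddings receives dL/dg divided by N. MeanPoolBackward is a hypothetical helper on plain arrays, not part of the library.

```csharp
using System;
using System.Linq;

// Backward pass of mean pooling: forward computes g = (1/N) * sum_i h_i,
// so every node receives the upstream gradient dL/dg scaled by 1/N.
static double[][] MeanPoolBackward(double[] gradGraph, int numNodes) =>
    Enumerable.Range(0, numNodes)
        .Select(_ => gradGraph.Select(v => v / numNodes).ToArray())
        .ToArray();

double[] upstream = { 0.4, -0.8 };                  // dL/dg for a 2-dim graph embedding
double[][] perNode = MeanPoolBackward(upstream, 4); // 4 nodes -> [4, 2] gradient
foreach (var row in perNode) Console.WriteLine(string.Join(", ", row));
```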

CreateNewInstance()

Creates a new instance of this network type for cloning or deserialization.

protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()

Returns

IFullModel<T, Tensor<T>, Tensor<T>>

DeserializeNetworkSpecificData(BinaryReader)

Deserializes network-specific data from a binary reader.

protected override void DeserializeNetworkSpecificData(BinaryReader reader)

Parameters

reader BinaryReader

The binary reader to read the serialized data from.

EvaluateOnTask(GraphClassificationTask<T>)

Evaluates the model on test graphs.

public double EvaluateOnTask(GraphClassificationTask<T> task)

Parameters

task GraphClassificationTask<T>

The graph classification task whose test graphs are evaluated.

Returns

double

Forward(Tensor<T>)

Performs a forward pass through the network.

public Tensor<T> Forward(Tensor<T> nodeFeatures)

Parameters

nodeFeatures Tensor<T>

Node feature tensor [num_nodes, input_features].

Returns

Tensor<T>

Output predictions [num_classes].

GetModelMetadata()

Gets metadata about this model for serialization and identification.

public override ModelMetadata<T> GetModelMetadata()

Returns

ModelMetadata<T>

GetParameters()

Gets all parameters as a vector.

public override Vector<T> GetParameters()

Returns

Vector<T>

InitializeLayers()

Initializes the layers of the neural network based on the provided architecture.

protected override void InitializeLayers()

Predict(Tensor<T>)

Makes a prediction using the trained network.

public override Tensor<T> Predict(Tensor<T> input)

Parameters

input Tensor<T>

The input tensor containing node features.

Returns

Tensor<T>

The prediction tensor with class probabilities.
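Predict returns class probabilities rather than a label. A minimal sketch of the usual final step, picking the most probable class, assuming the probabilities have already been copied out of the returned Tensor<T> into a double[] (the copying API is not shown on this page). The probability values and the [not toxic, toxic] class order are hypothetical.

```csharp
using System;

// Turn a class-probability vector into a predicted label plus its confidence.
static (int Label, double Confidence) ArgMax(double[] probabilities)
{
    int best = 0;
    for (int i = 1; i < probabilities.Length; i++)
        if (probabilities[i] > probabilities[best]) best = i;
    return (best, probabilities[best]);
}

double[] predicted = { 0.27, 0.73 };   // hypothetical output for [not toxic, toxic]
var (label, confidence) = ArgMax(predicted);
Console.WriteLine($"class {label} with probability {confidence:P0}");
```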

SerializeNetworkSpecificData(BinaryWriter)

Serializes network-specific data to a binary writer.

protected override void SerializeNetworkSpecificData(BinaryWriter writer)

Parameters

writer BinaryWriter

The binary writer to write the serialized data to.

SetAdjacencyMatrix(Tensor<T>)

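Sets the adjacency matrix for all graph layers in the model. Forward, Predict, and Train take only node features, so the structure of the current graph must be supplied through this method before calling them.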

public void SetAdjacencyMatrix(Tensor<T> adjacencyMatrix)

Parameters

adjacencyMatrix Tensor<T>

The adjacency matrix of the graph, typically of shape [num_nodes, num_nodes].
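This page does not say whether SetAdjacencyMatrix expects a raw 0/1 adjacency matrix or an already normalized one. For reference, the sketch below builds the symmetric normalization A_hat = D^(-1/2) (A + I) D^(-1/2), used by the graph convolutional network of Kipf and Welling, from an edge list. Whether the model's layers apply this normalization internally is an assumption to check; NormalizedAdjacency is a hypothetical helper, and it returns a plain double[,] rather than a Tensor<T>.

```csharp
using System;

// Build A_hat = D^(-1/2) (A + I) D^(-1/2) for an undirected graph.
static double[,] NormalizedAdjacency(int numNodes, (int From, int To)[] edges)
{
    var a = new double[numNodes, numNodes];
    foreach (var (u, v) in edges) { a[u, v] = 1.0; a[v, u] = 1.0; }  // undirected edges
    for (int i = 0; i < numNodes; i++) a[i, i] = 1.0;                // self-loops

    // Degree of each node in A + I, inverted and square-rooted.
    var invSqrtDeg = new double[numNodes];
    for (int i = 0; i < numNodes; i++)
    {
        double degree = 0.0;
        for (int j = 0; j < numNodes; j++) degree += a[i, j];
        invSqrtDeg[i] = 1.0 / Math.Sqrt(degree);
    }

    // Scale entry (i, j) by 1 / sqrt(deg_i * deg_j).
    for (int i = 0; i < numNodes; i++)
        for (int j = 0; j < numNodes; j++)
            a[i, j] *= invSqrtDeg[i] * invSqrtDeg[j];
    return a;
}

// Triangle 0-1-2 plus a pendant node 3 attached to node 2.
var aHat = NormalizedAdjacency(4, new[] { (0, 1), (1, 2), (0, 2), (2, 3) });
for (int r = 0; r < 4; r++)
{
    for (int c = 0; c < 4; c++) Console.Write($"{aHat[r, c]:F3} ");
    Console.WriteLine();
}
```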

Train(Tensor<T>, Tensor<T>)

Trains the network on a single batch of data.

public override void Train(Tensor<T> input, Tensor<T> expectedOutput)

Parameters

input Tensor<T>

The input node features.

expectedOutput Tensor<T>

The expected output (labels).
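The loss that Train minimizes comes from the lossFunction passed to the constructor, and its default is not documented on this page. For classification, categorical cross-entropy is the usual choice; the sketch below assumes that loss and assumes expectedOutput is a one-hot label vector for a single graph. CrossEntropy is a hypothetical helper and the logits are made-up values.

```csharp
using System;
using System.Linq;

// Cross-entropy for one graph: L = -sum_k y_k * log(p_k), where p = softmax(logits)
// and y is a one-hot label. Uses log p_k = z_k - logSumExp(z) for numerical stability.
static double CrossEntropy(double[] logits, double[] oneHot)
{
    double max = logits.Max();
    double logSumExp = max + Math.Log(logits.Sum(z => Math.Exp(z - max)));
    return -logits.Select((z, k) => oneHot[k] * (z - logSumExp)).Sum();
}

double[] scores = { 2.0, 0.5 };                              // hypothetical logits
double lossRight = CrossEntropy(scores, new[] { 1.0, 0.0 }); // true label: class 0
double lossWrong = CrossEntropy(scores, new[] { 0.0, 1.0 }); // true label: class 1
Console.WriteLine($"loss if class 0 is correct: {lossRight:F4}");
Console.WriteLine($"loss if class 1 is correct: {lossWrong:F4}");
```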

TrainOnTask(GraphClassificationTask<T>, int, double)

Trains the model on a graph classification task.

public Dictionary<string, List<double>> TrainOnTask(GraphClassificationTask<T> task, int epochs, double learningRate = 0.001)

Parameters

task GraphClassificationTask<T>

The graph classification task with training/validation/test graphs.

epochs int

Number of training epochs.

learningRate double

Learning rate for optimization (default: 0.001).

Returns

Dictionary<string, List<double>>

Training history with loss and accuracy per epoch.

UpdateParameters(Vector<T>)

Updates the parameters of all layers in the network.

public override void UpdateParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

A vector containing all parameters for the network.
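GetParameters and UpdateParameters exchange all weights as one flat Vector<T>, so both must walk the layers in the same order and agree on each layer's parameter count. The sketch below illustrates that round trip with plain arrays; the layer contents and the Unflatten helper are hypothetical, and the exact concatenation order AiDotNet uses is an assumption.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical per-layer parameter arrays (for example GNN weights, classifier weights).
var layers = new List<double[]>
{
    new[] { 0.1, 0.2, 0.3 },
    new[] { 0.4, 0.5 },
    new[] { 0.6 },
};

// GetParameters-style flatten: concatenate every layer's parameters in layer order.
double[] flat = layers.SelectMany(p => p).ToArray();

// UpdateParameters-style unflatten: walk the vector with an offset, giving each layer
// exactly as many values as it owns. The vector length must match the total count.
static void Unflatten(List<double[]> target, double[] vector)
{
    int expected = target.Sum(p => p.Length);
    if (vector.Length != expected)
        throw new ArgumentException($"expected {expected} parameters, got {vector.Length}");
    int offset = 0;
    foreach (var layerParams in target)
    {
        Array.Copy(vector, offset, layerParams, 0, layerParams.Length);
        offset += layerParams.Length;
    }
}

// Example: modify the flat vector (here a uniform shift), then write it back.
double[] updated = flat.Select(w => w - 0.01).ToArray();
Unflatten(layers, updated);
Console.WriteLine(string.Join(", ", layers.Select(p => p[0])));
```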