Class GraphClassificationModel<T>
- Namespace: AiDotNet.NeuralNetworks.Tasks.Graph
- Assembly: AiDotNet.dll
Implements a complete neural network model for graph classification tasks.
public class GraphClassificationModel<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable
Type Parameters
T: The numeric type used for calculations, typically float or double.
- Inheritance: object → NeuralNetworkBase<T> → GraphClassificationModel<T>
Remarks
Graph classification assigns labels to entire graphs based on their structure and features. The model consists of:
1. Node-level processing (GNN layers)
2. Graph-level pooling (aggregate node embeddings)
3. Classification head (fully connected layers)
For Beginners: This model classifies whole graphs.
Architecture pipeline:
Step 1: Node Encoding
Input: Graph with node features
Process: Stack of GNN layers
Output: Node embeddings [num_nodes, hidden_dim]
Step 2: Graph Pooling (KEY STEP!)
Input: Node embeddings from variable-sized graph
Process: Aggregate to fixed-size representation
Output: Graph embedding [hidden_dim]
Step 3: Classification
Input: Graph embedding [hidden_dim]
Process: Fully connected layers
Output: Class probabilities [num_classes]
Why pooling is crucial:
- Graphs have variable sizes (10 nodes vs 100 nodes)
- Need fixed-size representation for classification
- Like summarizing a book (any length) into a fixed review (200 words)
Example: Molecular toxicity prediction
Molecule (graph) -> GNN layers -> Molecule embedding -> Classifier -> Toxic? (Yes/No)
Small molecule (10 atoms):
10 nodes -> GNN -> 10 embeddings -> Pool -> 1 graph embedding -> Classify
Large molecule (50 atoms):
50 nodes -> GNN -> 50 embeddings -> Pool -> 1 graph embedding -> Classify
Both produce the same-sized graph embedding despite different input sizes!
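To make the pooling step concrete, here is a minimal mean-pooling sketch over plain arrays. It is illustrative only and does not reflect the library's internal Tensor<T>-based implementation:
// Mean pooling: any number of node embeddings collapses to one fixed-size vector
static double[] MeanPool(double[][] nodeEmbeddings)
{
    int hiddenDim = nodeEmbeddings[0].Length;
    var graphEmbedding = new double[hiddenDim];
    foreach (var node in nodeEmbeddings)
        for (int d = 0; d < hiddenDim; d++)
            graphEmbedding[d] += node[d];
    for (int d = 0; d < hiddenDim; d++)
        graphEmbedding[d] /= nodeEmbeddings.Length;
    return graphEmbedding; // [hidden_dim], no matter how many nodes came in
}
Whether the graph has 10 nodes or 50, the result has length hidden_dim, which is what makes a fixed-size classification head possible.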
Constructors
GraphClassificationModel(NeuralNetworkArchitecture<T>, int, int, int, double, GraphPooling, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, ILossFunction<T>?, double)
Initializes a new instance of the GraphClassificationModel<T> class.
public GraphClassificationModel(NeuralNetworkArchitecture<T> architecture, int hiddenDim = 64, int embeddingDim = 128, int numGnnLayers = 3, double dropoutRate = 0.5, GraphClassificationModel<T>.GraphPooling poolingType = GraphPooling.Mean, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? optimizer = null, ILossFunction<T>? lossFunction = null, double maxGradNorm = 1)
Parameters
architecture (NeuralNetworkArchitecture<T>): The neural network architecture defining input/output sizes and layers.
hiddenDim (int): Hidden dimension for intermediate layers (default: 64).
embeddingDim (int): Dimension of the graph embedding after pooling (default: 128).
numGnnLayers (int): Number of graph convolutional layers (default: 3).
dropoutRate (double): Dropout rate for regularization (default: 0.5).
poolingType (GraphClassificationModel<T>.GraphPooling): Method for pooling node embeddings into the graph embedding (default: Mean).
optimizer (IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>): Optional optimizer for training.
lossFunction (ILossFunction<T>): Optional loss function for training.
maxGradNorm (double): Maximum gradient norm for clipping (default: 1.0).
Remarks
For Beginners: Creating a graph classification model:
// Create architecture for molecular property prediction
var architecture = new NeuralNetworkArchitecture<double>(
    InputType.OneDimensional,
    NeuralNetworkTaskType.MultiClassClassification,
    NetworkComplexity.Simple,
    inputSize: 9,   // Atom features
    outputSize: 2); // Binary classification (toxic/not toxic)
// Create model with default layers
var model = new GraphClassificationModel<double>(architecture);
// Train on graph classification task
var history = model.TrainOnTask(task, epochs: 100, learningRate: 0.001);
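The same constructor with every optional parameter spelled out, as a sketch; the values shown are simply the documented defaults, not tuning recommendations:
var model = new GraphClassificationModel<double>(
    architecture,
    hiddenDim: 64,
    embeddingDim: 128,
    numGnnLayers: 3,
    dropoutRate: 0.5,
    poolingType: GraphClassificationModel<double>.GraphPooling.Mean,
    optimizer: null,      // optional; null lets the model pick its default
    lossFunction: null,   // optional; null lets the model pick its default
    maxGradNorm: 1.0);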
Properties
DropoutRate
Gets the dropout rate for regularization.
public double DropoutRate { get; }
Property Value
- double
EmbeddingDim
Gets the graph embedding dimension after pooling.
public int EmbeddingDim { get; }
Property Value
- int
HiddenDim
Gets the hidden dimension size.
public int HiddenDim { get; }
Property Value
- int
InputFeatures
Gets the number of input features per node.
public int InputFeatures { get; }
Property Value
- int
NumClasses
Gets the number of output classes.
public int NumClasses { get; }
Property Value
- int
NumGnnLayers
Gets the number of GNN layers.
public int NumGnnLayers { get; }
Property Value
- int
Methods
Backward(Tensor<T>)
Performs a backward pass through the network.
public Tensor<T> Backward(Tensor<T> outputGradient)
Parameters
outputGradient (Tensor<T>): Gradient of the loss with respect to the output.
Returns
- Tensor<T>
Gradient with respect to input.
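A sketch of how Forward and Backward pair up in a manual training step. The tensors (adjacency, nodeFeatures, lossGradient) are assumed to exist already; constructing them is outside the scope of this sketch:
model.SetAdjacencyMatrix(adjacency);                          // graph structure first
Tensor<double> predictions = model.Forward(nodeFeatures);     // [num_classes]
// ... compute the loss and its gradient with respect to predictions ...
Tensor<double> inputGradient = model.Backward(lossGradient);  // gradient w.r.t. input
For routine training, the Train and TrainOnTask methods below handle the forward and backward passes internally.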
CreateNewInstance()
Creates a new instance of this network type for cloning or deserialization.
protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
DeserializeNetworkSpecificData(BinaryReader)
Deserializes network-specific data from a binary reader.
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
Parameters
reader (BinaryReader): The binary reader to read the network-specific data from.
EvaluateOnTask(GraphClassificationTask<T>)
Evaluates the model on test graphs.
public double EvaluateOnTask(GraphClassificationTask<T> task)
Parameters
task (GraphClassificationTask<T>): The graph classification task containing the test graphs to evaluate on.
Returns
- double
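A usage sketch, assuming a GraphClassificationTask<double> named task has already been built with test graphs:
double testScore = model.EvaluateOnTask(task);
Console.WriteLine($"Test score: {testScore:F4}");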
Forward(Tensor<T>)
Performs a forward pass through the network.
public Tensor<T> Forward(Tensor<T> nodeFeatures)
Parameters
nodeFeatures (Tensor<T>): Node feature tensor [num_nodes, input_features].
Returns
- Tensor<T>
Output predictions [num_classes].
GetModelMetadata()
Gets metadata about this model for serialization and identification.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
GetParameters()
Gets all parameters as a vector.
public override Vector<T> GetParameters()
Returns
- Vector<T>
InitializeLayers()
Initializes the layers of the neural network based on the provided architecture.
protected override void InitializeLayers()
Predict(Tensor<T>)
Makes a prediction using the trained network.
public override Tensor<T> Predict(Tensor<T> input)
Parameters
input (Tensor<T>): The input tensor containing node features.
Returns
- Tensor<T>
The prediction tensor with class probabilities.
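An inference sketch for a single graph. Building the adjacency and node-feature tensors is assumed to happen elsewhere in your data pipeline:
model.SetAdjacencyMatrix(adjacency);                         // structure first (see SetAdjacencyMatrix below)
Tensor<double> probabilities = model.Predict(nodeFeatures);  // [num_classes]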
SerializeNetworkSpecificData(BinaryWriter)
Serializes network-specific data to a binary writer.
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
Parameters
writer (BinaryWriter): The binary writer to write the network-specific data to.
SetAdjacencyMatrix(Tensor<T>)
Sets the adjacency matrix for all graph layers in the model.
public void SetAdjacencyMatrix(Tensor<T> adjacencyMatrix)
Parameters
adjacencyMatrix (Tensor<T>): The graph adjacency matrix.
Train(Tensor<T>, Tensor<T>)
Trains the network on a single batch of data.
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)
Parameters
input (Tensor<T>): The input node features.
expectedOutput (Tensor<T>): The expected output (labels).
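A sketch of a minimal epoch loop built on Train. The per-graph tensors are assumed to come from your own data loading, and the tuple shape here is purely illustrative:
// trainGraphs: hypothetical collection of (features, adjacency, labels) tuples
for (int epoch = 0; epoch < 100; epoch++)
{
    foreach (var (features, adjacency, labels) in trainGraphs)
    {
        model.SetAdjacencyMatrix(adjacency); // graph structure for this sample
        model.Train(features, labels);       // one training step on this graph
    }
}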
TrainOnTask(GraphClassificationTask<T>, int, double)
Trains the model on a graph classification task.
public Dictionary<string, List<double>> TrainOnTask(GraphClassificationTask<T> task, int epochs, double learningRate = 0.001)
Parameters
task (GraphClassificationTask<T>): The graph classification task with training/validation/test graphs.
epochs (int): Number of training epochs.
learningRate (double): Learning rate for optimization (default: 0.001).
Returns
- Dictionary<string, List<double>>
Training history with loss and accuracy per epoch.
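A sketch of consuming the returned history. The concrete key names are not documented here, so the loop below iterates the dictionary without assuming them:
var history = model.TrainOnTask(task, epochs: 100, learningRate: 0.001);
foreach (var kvp in history)
    Console.WriteLine($"{kvp.Key}: final value = {kvp.Value[kvp.Value.Count - 1]:F4}");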
UpdateParameters(Vector<T>)
Updates the parameters of all layers in the network.
public override void UpdateParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): A vector containing all parameters for the network.
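A sketch of the GetParameters/UpdateParameters round trip, the usual pattern when driving the network from an external optimizer or restoring saved parameters:
Vector<double> parameters = model.GetParameters();
// ... produce an updated vector of the same length externally ...
model.UpdateParameters(parameters);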