Class GraphAttentionNetwork<T>

Namespace
AiDotNet.NeuralNetworks
Assembly
AiDotNet.dll

Represents a Graph Attention Network (GAT) that uses attention mechanisms to process graph-structured data.

public class GraphAttentionNetwork<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable

Type Parameters

T

The numeric type used for calculations, typically float or double.

Inheritance
object
NeuralNetworkBase<T>
GraphAttentionNetwork<T>

Implements
INeuralNetworkModel<T>
INeuralNetwork<T>
IFullModel<T, Tensor<T>, Tensor<T>>
IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>
IModelSerializer
ICheckpointableModel
IParameterizable<T, Tensor<T>, Tensor<T>>
IFeatureAware
IFeatureImportance<T>
ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>
IGradientComputable<T, Tensor<T>, Tensor<T>>
IJitCompilable<T>
IInterpretableModel<T>
IInputGradientComputable<T>
IDisposable

Remarks

Graph Attention Networks introduce attention mechanisms to graph neural networks, allowing the model to learn which neighbors are most important for each node. Unlike GCN, which combines neighbors with fixed, degree-based weights, GAT learns attention weights that determine how much each neighbor contributes to a node's representation.

For Beginners: GAT is like having a smart filter for your social network.

How it works:

  • Each node looks at its neighbors and decides which ones are most important
  • Important neighbors get more "attention" (higher weights)
  • Less relevant neighbors get less attention
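
How attention is computed, concretely: in the original GAT paper (Veličković et al., 2018), the weight that node i assigns to neighbor j is

  α_ij = softmax_j( LeakyReLU( aᵀ [W·h_i ‖ W·h_j] ) )

where h_i and h_j are the feature vectors of the two nodes, W is a learned projection matrix, a is a learned attention vector, and ‖ denotes concatenation. The softmax normalizes over all neighbors of i, so each node's attention weights sum to 1.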

Example - Movie Recommendations:

  • You're a node connected to movies you've watched
  • Some movies better represent your taste than others
  • GAT learns to pay more attention to movies that define your preferences
  • Result: Better recommendations by focusing on what matters most

Key Features:

  • Multi-head attention: Multiple attention "perspectives" for robustness
  • Dynamic weights: Attention weights are learned, not fixed
  • Dropout support: Prevents overfitting during training
  • Configurable heads: Adjust number of attention heads for your task
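
All four knobs are exposed through the constructor documented below. A minimal sketch (the architecture value is assumed to be built as in the constructor example):

// 3 GAT layers with 4 attention heads each and lighter dropout
var gat = new GraphAttentionNetwork<double>(
    architecture, numHeads: 4, numLayers: 3, dropoutRate: 0.5);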

Architecture: The standard GAT architecture consists of:

  1. Multiple GAT layers with attention mechanisms
  2. Optional dropout between layers
  3. Final classification or regression head

When to use GAT:

  • When some neighbors are more informative than others
  • When you need interpretable importance scores
  • For heterogeneous graphs where relationships vary in importance
  • Citation networks, social networks, knowledge graphs

Constructors

GraphAttentionNetwork(NeuralNetworkArchitecture<T>, int, int, double, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, ILossFunction<T>?, double)

Initializes a new instance of the GraphAttentionNetwork<T> class with specified architecture.

public GraphAttentionNetwork(NeuralNetworkArchitecture<T> architecture, int numHeads = 8, int numLayers = 2, double dropoutRate = 0.6, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? optimizer = null, ILossFunction<T>? lossFunction = null, double maxGradNorm = 1)

Parameters

architecture NeuralNetworkArchitecture<T>

The neural network architecture defining the structure of the network.

numHeads int

Number of attention heads per layer (default: 8). Used only when creating default layers.

numLayers int

Number of GAT layers (default: 2). Used only when creating default layers.

dropoutRate double

Dropout rate for attention coefficients (default: 0.6).

optimizer IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>

Optional optimizer for training.

lossFunction ILossFunction<T>

Optional loss function for training.

maxGradNorm double

Maximum gradient norm for clipping (default: 1.0).

Remarks

For Beginners: Creating a GAT network:

// Create architecture for node classification
var architecture = new NeuralNetworkArchitecture<double>(
    InputType.OneDimensional,
    NeuralNetworkTaskType.MultiClassClassification,
    NetworkComplexity.Simple,
    inputSize: 1433,    // Cora has 1433 word features
    outputSize: 7);     // 7 paper categories

// Create GAT with default layers
var gat = new GraphAttentionNetwork<double>(architecture);

// Or create with custom layers by adding them to architecture
var gatCustom = new GraphAttentionNetwork<double>(architectureWithCustomLayers);

// Train on graph data
gat.TrainOnGraph(nodeFeatures, adjacencyMatrix, labels, epochs: 200);

Properties

DropoutRate

Gets the dropout rate applied to attention coefficients during training.

public double DropoutRate { get; }

Property Value

double

HiddenDim

Gets the hidden dimension size for each layer.

public int HiddenDim { get; }

Property Value

int

IsLoRAEnabled

Gets whether LoRA fine-tuning is currently enabled.

public bool IsLoRAEnabled { get; }

Property Value

bool

LoRARank

Gets the LoRA rank when LoRA is enabled.

public int LoRARank { get; }

Property Value

int

NumHeads

Gets the number of attention heads used in each GAT layer.

public int NumHeads { get; }

Property Value

int

NumLayers

Gets the number of GAT layers in the network.

public int NumLayers { get; }

Property Value

int
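
These properties are read-only snapshots of the configuration chosen at construction time. A quick way to inspect a network (assuming the gat instance from the constructor example):

Console.WriteLine($"{gat.NumLayers} layers, {gat.NumHeads} heads, hidden dim {gat.HiddenDim}, dropout {gat.DropoutRate}");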

Methods

Backward(Tensor<T>)

Performs a backward pass through the network to calculate gradients.

public Tensor<T> Backward(Tensor<T> outputGradient)

Parameters

outputGradient Tensor<T>

The gradient of the loss with respect to the network's output.

Returns

Tensor<T>

The gradient of the loss with respect to the network's input.

CreateNewInstance()

Creates a new instance of this network type for cloning or deserialization.

protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()

Returns

IFullModel<T, Tensor<T>, Tensor<T>>

A new GraphAttentionNetwork instance.

DeserializeNetworkSpecificData(BinaryReader)

Deserializes network-specific data from a binary reader.

protected override void DeserializeNetworkSpecificData(BinaryReader reader)

Parameters

reader BinaryReader

The binary reader to deserialize from.

DisableLoRA()

Disables LoRA fine-tuning and restores original layers.

public void DisableLoRA()

Remarks

This removes the LoRA adapters and restores the original base layers. Any LoRA adaptations that were not merged will be lost.

EnableLoRAFineTuning(int, double, bool)

Enables LoRA (Low-Rank Adaptation) fine-tuning for parameter-efficient training.

public void EnableLoRAFineTuning(int rank = 8, double alpha = -1, bool freezeBaseLayers = true)

Parameters

rank int

The rank of the LoRA decomposition (default: 8).

alpha double

The LoRA scaling factor. Pass -1 (the default) to use the same value as rank.

freezeBaseLayers bool

Whether to freeze base layer parameters (default: true).

Remarks

For Beginners: LoRA allows you to fine-tune the GAT network with far fewer trainable parameters:

// Create and pre-train a GAT network
// (architecture is built as in the constructor example above)
var gat = new GraphAttentionNetwork<double>(architecture, numHeads: 8);
gat.TrainOnGraph(features, adjacency, labels, epochs: 200);

// Enable LoRA for efficient fine-tuning on new task
gat.EnableLoRAFineTuning(rank: 8, alpha: 16);

// Now only ~4% of parameters are trainable!
Console.WriteLine($"LoRA parameters: {gat.GetLoRAParameterCount()}");
Console.WriteLine($"Total parameters: {gat.GetParameterCount()}");

// Fine-tune on new data
gat.TrainOnGraph(newFeatures, newAdjacency, newLabels, epochs: 50);

// Optionally merge LoRA weights for deployment
gat.MergeLoRAWeights();
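
GetLoRATrainablePercentage() reports the trainable fraction directly, which is convenient for logging:

Console.WriteLine($"Trainable: {gat.GetLoRATrainablePercentage():F1}%");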

Evaluate(Tensor<T>, Tensor<T>, Tensor<T>, bool[])

Evaluates the model on test data and returns accuracy.

public double Evaluate(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix, Tensor<T> labels, bool[] testMask)

Parameters

nodeFeatures Tensor<T>

Node feature tensor.

adjacencyMatrix Tensor<T>

Adjacency matrix.

labels Tensor<T>

Ground truth labels.

testMask bool[]

Boolean mask for test nodes.

Returns

double

Classification accuracy on test nodes.
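
Remarks

For Beginners: A minimal evaluation sketch on a held-out node split (the mask construction is illustrative; numNodes and the tensors are assumed from earlier examples):

// Hold out the last 200 nodes as the test set
var testMask = new bool[numNodes];
for (int i = numNodes - 200; i < numNodes; i++) testMask[i] = true;

double accuracy = gat.Evaluate(nodeFeatures, adjacencyMatrix, labels, testMask);
Console.WriteLine($"Test accuracy: {accuracy}");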

Forward(Tensor<T>, Tensor<T>)

Performs a forward pass through the network with node features and adjacency matrix.

public Tensor<T> Forward(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix)

Parameters

nodeFeatures Tensor<T>

Node feature tensor of shape [batchSize, numNodes, inputFeatures] or [numNodes, inputFeatures].

adjacencyMatrix Tensor<T>

Adjacency matrix of shape [batchSize, numNodes, numNodes] or [numNodes, numNodes].

Returns

Tensor<T>

The output tensor after processing through all layers.
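
Remarks

A minimal sketch of a single forward pass (tensor construction omitted; shapes as documented above):

// nodeFeatures: [numNodes, inputFeatures], adjacencyMatrix: [numNodes, numNodes]
var output = gat.Forward(nodeFeatures, adjacencyMatrix);
// output has one row of scores per node, e.g. [numNodes, outputSize] for node classification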

GetAttentionWeights()

Gets attention weights from all GAT layers for interpretability.

public List<Tensor<T>?> GetAttentionWeights()

Returns

List<Tensor<T>?>

List of attention weight tensors (currently returns nulls as implementation is pending).

Remarks

Note: This method is a placeholder. Full attention coefficient retrieval requires exposing internal state from GraphAttentionLayer, which will be added in a future update.

GetLoRAParameterCount()

Gets the number of trainable LoRA parameters when LoRA is enabled.

public int GetLoRAParameterCount()

Returns

int

The count of LoRA parameters, or 0 if LoRA is not enabled.

GetLoRATrainablePercentage()

Gets the percentage of parameters that are trainable when using LoRA.

public double GetLoRATrainablePercentage()

Returns

double

The percentage of trainable parameters (0-100).

GetModelMetadata()

Gets metadata about this model for serialization and identification.

public override ModelMetadata<T> GetModelMetadata()

Returns

ModelMetadata<T>

Model metadata including type and configuration.

GetParameterCount()

Gets the total number of trainable parameters in the network.

public int GetParameterCount()

Returns

int

The total number of trainable parameters in the network.

GetParameters()

Gets all parameters as a vector.

public override Vector<T> GetParameters()

Returns

Vector<T>

A vector containing all trainable parameters of the network.

InitializeLayers()

Initializes the layers of the neural network based on the provided architecture.

protected override void InitializeLayers()

MergeLoRAWeights()

Merges LoRA weights into the base layers and disables LoRA mode.

public void MergeLoRAWeights()

Remarks

For Beginners: After fine-tuning with LoRA, you can "bake in" the learned adaptations to create a standard network for deployment:

  • Before merge: Forward pass requires computing both base and LoRA outputs
  • After merge: Single forward pass through merged layers (faster)

This is useful when deploying the fine-tuned model to production where you want maximum inference speed and don't need to track LoRA parameters separately.
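
A typical deployment sequence (a sketch using only methods documented on this page):

gat.MergeLoRAWeights();                       // bake the LoRA deltas into the base weights
var predictions = gat.Predict(nodeFeatures);  // now a single standard forward pass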

Predict(Tensor<T>)

Makes a prediction using the trained network.

public override Tensor<T> Predict(Tensor<T> input)

Parameters

input Tensor<T>

The input tensor containing node features.

Returns

Tensor<T>

The prediction tensor.

Remarks

For Beginners: This is the main method for using a trained GAT network. Pass in node features and get predictions back. For classification, the output will be class probabilities for each node. If no adjacency matrix has been set, a fully-connected adjacency matrix is generated for convenience. Note that this treats every node as connected to every other node, which can mask real graph structure; call SetAdjacencyMatrix(Tensor<T>) to supply the true graph.
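
A minimal sketch (tensors assumed from earlier examples):

gat.SetAdjacencyMatrix(adjacencyMatrix);      // supply the real graph structure first
var predictions = gat.Predict(nodeFeatures);  // per-node predictions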

SerializeNetworkSpecificData(BinaryWriter)

Serializes network-specific data to a binary writer.

protected override void SerializeNetworkSpecificData(BinaryWriter writer)

Parameters

writer BinaryWriter

The binary writer to serialize to.

SetAdjacencyMatrix(Tensor<T>)

Sets the adjacency matrix for graph operations.

public void SetAdjacencyMatrix(Tensor<T> adjacencyMatrix)

Parameters

adjacencyMatrix Tensor<T>

The adjacency matrix defining graph structure (shape [numNodes, numNodes]).

Train(Tensor<T>, Tensor<T>)

Trains the network on a single batch of data.

public override void Train(Tensor<T> input, Tensor<T> expectedOutput)

Parameters

input Tensor<T>

The input node features.

expectedOutput Tensor<T>

The expected output (labels).

Remarks

For Beginners: This method performs one training step. For full training, call TrainOnGraph which handles multiple epochs and adjacency matrix setup. If no adjacency matrix has been set, a fully-connected adjacency matrix is generated for convenience. This means every node is treated as connected to every other node, which can hide the true graph structure unless you provide an explicit adjacency matrix.
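
For example, a manual loop equivalent to what TrainOnGraph automates (a sketch; tensors assumed from earlier examples):

gat.SetAdjacencyMatrix(adjacencyMatrix);  // avoid the fully-connected fallback
for (int epoch = 0; epoch < 200; epoch++)
    gat.Train(nodeFeatures, labels);      // one gradient step per call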

TrainOnGraph(Tensor<T>, Tensor<T>, Tensor<T>, bool[]?, int, double)

Trains the GAT network on graph-structured data.

public void TrainOnGraph(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix, Tensor<T> labels, bool[]? trainMask = null, int epochs = 200, double learningRate = 0.005)

Parameters

nodeFeatures Tensor<T>

Node feature tensor of shape [numNodes, inputFeatures].

adjacencyMatrix Tensor<T>

Adjacency matrix of shape [numNodes, numNodes].

labels Tensor<T>

Label tensor for supervised learning.

trainMask bool[]

Optional boolean mask indicating which nodes to train on.

epochs int

Number of training epochs (default: 200).

learningRate double

Learning rate for optimization (default: 0.005).
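
Remarks

A typical semi-supervised call that trains only on labeled nodes (the trainMask value is illustrative):

// trainMask[i] == true for nodes whose labels may be used in the loss
gat.TrainOnGraph(nodeFeatures, adjacencyMatrix, labels,
    trainMask: trainMask, epochs: 300, learningRate: 0.01);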

UpdateParameters(Vector<T>)

Updates the parameters of all layers in the network.

public override void UpdateParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

A vector containing all parameters for the network.