Class GraphAttentionNetwork<T>
- Namespace: AiDotNet.NeuralNetworks
- Assembly: AiDotNet.dll
Represents a Graph Attention Network (GAT) that uses attention mechanisms to process graph-structured data.
public class GraphAttentionNetwork<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable
Type Parameters
T: The numeric type used for calculations, typically float or double.
- Inheritance
- NeuralNetworkBase<T> → GraphAttentionNetwork<T>
Remarks
Graph Attention Networks introduce attention mechanisms to graph neural networks, allowing the model to learn which neighbors are most important for each node. Unlike GCN which treats all neighbors equally, GAT learns attention weights that determine how much each neighbor contributes to a node's representation.
For Beginners: GAT is like having a smart filter for your social network.
How it works:
- Each node looks at its neighbors and decides which ones are most important
- Important neighbors get more "attention" (higher weights)
- Less relevant neighbors get less attention
Example - Movie Recommendations:
- You're a node connected to movies you've watched
- Some movies better represent your taste than others
- GAT learns to pay more attention to movies that define your preferences
- Result: Better recommendations by focusing on what matters most
Key Features:
- Multi-head attention: Multiple attention "perspectives" for robustness
- Dynamic weights: Attention weights are learned, not fixed
- Dropout support: Prevents overfitting during training
- Configurable heads: Adjust number of attention heads for your task
Architecture: The standard GAT architecture consists of:
- Multiple GAT layers with attention mechanisms
- Optional dropout between layers
- Final classification or regression head
When to use GAT:
- When some neighbors are more informative than others
- When you need interpretable importance scores
- For heterogeneous graphs where relationships vary in importance
- Citation networks, social networks, knowledge graphs
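For reference, the attention weights described above follow the original GAT formulation (Veličković et al., 2018): a shared linear transform is applied to each node, a small learnable attention vector scores each edge, and the scores are normalized over each neighborhood:

$$e_{ij} = \mathrm{LeakyReLU}\!\left(\mathbf{a}^{\top}\,[\mathbf{W}h_i \,\|\, \mathbf{W}h_j]\right), \qquad \alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})}$$

Here $\mathbf{W}$ is the shared weight matrix, $\mathbf{a}$ is the attention vector, $\|$ denotes concatenation, and $\alpha_{ij}$ is how much attention node $i$ pays to neighbor $j$. The updated representation is $h_i' = \sigma\big(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}\,\mathbf{W}h_j\big)$, computed independently by each attention head.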
Constructors
GraphAttentionNetwork(NeuralNetworkArchitecture<T>, int, int, double, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, ILossFunction<T>?, double)
Initializes a new instance of the GraphAttentionNetwork<T> class with the specified architecture.
public GraphAttentionNetwork(NeuralNetworkArchitecture<T> architecture, int numHeads = 8, int numLayers = 2, double dropoutRate = 0.6, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? optimizer = null, ILossFunction<T>? lossFunction = null, double maxGradNorm = 1)
Parameters
- architecture (NeuralNetworkArchitecture<T>): The neural network architecture defining the structure of the network.
- numHeads (int): Number of attention heads per layer (default: 8). Used only when creating default layers.
- numLayers (int): Number of GAT layers (default: 2). Used only when creating default layers.
- dropoutRate (double): Dropout rate for attention coefficients (default: 0.6).
- optimizer (IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?): Optional optimizer for training.
- lossFunction (ILossFunction<T>?): Optional loss function for training.
- maxGradNorm (double): Maximum gradient norm for clipping (default: 1.0).
Remarks
For Beginners: Creating a GAT network:
```csharp
// Create architecture for node classification
var architecture = new NeuralNetworkArchitecture<double>(
    InputType.OneDimensional,
    NeuralNetworkTaskType.MultiClassClassification,
    NetworkComplexity.Simple,
    inputSize: 1433,  // Cora has 1433 word features
    outputSize: 7);   // 7 paper categories

// Create GAT with default layers
var gat = new GraphAttentionNetwork<double>(architecture);

// Or create with custom layers by adding them to the architecture
var gatCustom = new GraphAttentionNetwork<double>(architectureWithCustomLayers);

// Train on graph data
gat.TrainOnGraph(nodeFeatures, adjacencyMatrix, labels, epochs: 200);
```
Properties
DropoutRate
Gets the dropout rate applied to attention coefficients during training.
public double DropoutRate { get; }
Property Value
- double
HiddenDim
Gets the hidden dimension size for each layer.
public int HiddenDim { get; }
Property Value
- int
IsLoRAEnabled
Gets whether LoRA fine-tuning is currently enabled.
public bool IsLoRAEnabled { get; }
Property Value
- bool
LoRARank
Gets the LoRA rank when LoRA is enabled.
public int LoRARank { get; }
Property Value
- int
NumHeads
Gets the number of attention heads used in each GAT layer.
public int NumHeads { get; }
Property Value
- int
NumLayers
Gets the number of GAT layers in the network.
public int NumLayers { get; }
Property Value
- int
Methods
Backward(Tensor<T>)
Performs a backward pass through the network to calculate gradients.
public Tensor<T> Backward(Tensor<T> outputGradient)
Parameters
- outputGradient (Tensor<T>): The gradient of the loss with respect to the network's output.
Returns
- Tensor<T>
The gradient of the loss with respect to the network's input.
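A minimal sketch of where Backward fits in a manual training step. `ComputeLossGradient` is a hypothetical placeholder for the gradient supplied by your configured ILossFunction<T>, not part of this API:

```csharp
// Forward pass with the graph structure
var output = gat.Forward(nodeFeatures, adjacencyMatrix);

// Hypothetical helper: gradient of the loss w.r.t. the output
var outputGradient = ComputeLossGradient(output, labels);

// Backward pass: propagates gradients through every GAT layer and
// returns the gradient of the loss w.r.t. the input node features
var inputGradient = gat.Backward(outputGradient);
```

For routine training, prefer Train or TrainOnGraph, which run this loop for you.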
CreateNewInstance()
Creates a new instance of this network type for cloning or deserialization.
protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
A new GraphAttentionNetwork instance.
DeserializeNetworkSpecificData(BinaryReader)
Deserializes network-specific data from a binary reader.
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
Parameters
- reader (BinaryReader): The binary reader to deserialize from.
DisableLoRA()
Disables LoRA fine-tuning and restores original layers.
public void DisableLoRA()
Remarks
This removes the LoRA adapters and restores the original base layers. Any LoRA adaptations that were not merged will be lost.
EnableLoRAFineTuning(int, double, bool)
Enables LoRA (Low-Rank Adaptation) fine-tuning for parameter-efficient training.
public void EnableLoRAFineTuning(int rank = 8, double alpha = -1, bool freezeBaseLayers = true)
Parameters
- rank (int): The rank of the LoRA decomposition (default: 8).
- alpha (double): The LoRA scaling factor (default: same as rank).
- freezeBaseLayers (bool): Whether to freeze base layer parameters (default: true).
Remarks
For Beginners: LoRA allows you to fine-tune the GAT network with far fewer trainable parameters:
```csharp
// Create and pre-train a GAT network
var gat = new GraphAttentionNetwork<double>(architecture, numHeads: 8);
gat.TrainOnGraph(features, adjacency, labels, epochs: 200);

// Enable LoRA for efficient fine-tuning on a new task
gat.EnableLoRAFineTuning(rank: 8, alpha: 16);

// Now only ~4% of parameters are trainable!
Console.WriteLine($"LoRA parameters: {gat.GetLoRAParameterCount()}");
Console.WriteLine($"Total parameters: {gat.GetParameterCount()}");

// Fine-tune on new data
gat.TrainOnGraph(newFeatures, newAdjacency, newLabels, epochs: 50);

// Optionally merge LoRA weights for deployment
gat.MergeLoRAWeights();
```
Evaluate(Tensor<T>, Tensor<T>, Tensor<T>, bool[])
Evaluates the model on test data and returns accuracy.
public double Evaluate(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix, Tensor<T> labels, bool[] testMask)
Parameters
- nodeFeatures (Tensor<T>): Node feature tensor.
- adjacencyMatrix (Tensor<T>): Adjacency matrix.
- labels (Tensor<T>): Ground truth labels.
- testMask (bool[]): Boolean mask for test nodes.
Returns
- double
Classification accuracy on test nodes.
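A sketch of a typical train/evaluate workflow; the 80/20 mask split below is illustrative:

```csharp
// Partition nodes into training and test sets via boolean masks
var trainMask = new bool[numNodes];
var testMask = new bool[numNodes];
for (int i = 0; i < numNodes; i++)
{
    trainMask[i] = i < numNodes * 0.8;  // first 80% of nodes train
    testMask[i] = !trainMask[i];        // remaining 20% test
}

gat.TrainOnGraph(nodeFeatures, adjacencyMatrix, labels, trainMask, epochs: 200);
double accuracy = gat.Evaluate(nodeFeatures, adjacencyMatrix, labels, testMask);
Console.WriteLine($"Test accuracy: {accuracy:P1}");
```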
Forward(Tensor<T>, Tensor<T>)
Performs a forward pass through the network with node features and adjacency matrix.
public Tensor<T> Forward(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix)
Parameters
- nodeFeatures (Tensor<T>): Node feature tensor of shape [batchSize, numNodes, inputFeatures] or [numNodes, inputFeatures].
- adjacencyMatrix (Tensor<T>): Adjacency matrix of shape [batchSize, numNodes, numNodes] or [numNodes, numNodes].
Returns
- Tensor<T>
The output tensor after processing through all layers.
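For example, a single-graph forward pass without a batch dimension, assuming the tensors were built to the shapes above:

```csharp
// nodeFeatures: [numNodes, inputFeatures]
// adjacencyMatrix: [numNodes, numNodes]
var output = gat.Forward(nodeFeatures, adjacencyMatrix);
// output holds one row of predictions per node
```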
GetAttentionWeights()
Gets attention weights from all GAT layers for interpretability.
public List<Tensor<T>?> GetAttentionWeights()
Returns
- List<Tensor<T>?>
List of attention weight tensors (currently returns nulls as the implementation is pending).
Remarks
Note: This method is a placeholder. Full attention coefficient retrieval requires exposing internal state from GraphAttentionLayer, which will be added in a future update.
GetLoRAParameterCount()
Gets the number of trainable LoRA parameters when LoRA is enabled.
public int GetLoRAParameterCount()
Returns
- int
The count of LoRA parameters, or 0 if LoRA is not enabled.
GetLoRATrainablePercentage()
Gets the percentage of parameters that are trainable when using LoRA.
public double GetLoRATrainablePercentage()
Returns
- double
The percentage of trainable parameters (0-100).
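For instance, to report the savings after enabling LoRA:

```csharp
gat.EnableLoRAFineTuning(rank: 8);
Console.WriteLine($"Trainable: {gat.GetLoRATrainablePercentage():F1}% " +
                  $"({gat.GetLoRAParameterCount()} of {gat.GetParameterCount()} parameters)");
```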
GetModelMetadata()
Gets metadata about this model for serialization and identification.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
Model metadata including type and configuration.
GetParameterCount()
Gets the total number of trainable parameters in the network.
public int GetParameterCount()
Returns
- int
The total number of trainable parameters in the network.
GetParameters()
Gets all parameters as a vector.
public override Vector<T> GetParameters()
Returns
- Vector<T>
A vector containing all of the network's parameters.
InitializeLayers()
Initializes the layers of the neural network based on the provided architecture.
protected override void InitializeLayers()
MergeLoRAWeights()
Merges LoRA weights into the base layers and disables LoRA mode.
public void MergeLoRAWeights()
Remarks
For Beginners: After fine-tuning with LoRA, you can "bake in" the learned adaptations to create a standard network for deployment:
- Before merge: Forward pass requires computing both base and LoRA outputs
- After merge: Single forward pass through merged layers (faster)
This is useful when deploying the fine-tuned model to production where you want maximum inference speed and don't need to track LoRA parameters separately.
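A minimal deployment sketch (serialization is omitted here because the exact IModelSerializer call is not shown in this section):

```csharp
gat.MergeLoRAWeights();                   // bake LoRA deltas into the base weights
var predictions = gat.Predict(features);  // single forward pass, no LoRA overhead
```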
Predict(Tensor<T>)
Makes a prediction using the trained network.
public override Tensor<T> Predict(Tensor<T> input)
Parameters
- input (Tensor<T>): The input tensor containing node features.
Returns
- Tensor<T>
The prediction tensor.
Remarks
For Beginners: This is the main method for using a trained GAT network. Pass in node features and get predictions back. For classification, the output will be class probabilities for each node. If no adjacency matrix has been set, a fully-connected adjacency matrix is generated for convenience. Note that this treats every node as connected to every other node, which can mask real graph structure; call SetAdjacencyMatrix(Tensor<T>) to supply the true graph.
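A sketch of inference with an explicit graph structure, assuming the tensors are prepared elsewhere:

```csharp
// Supply the real graph before predicting; otherwise every node
// is treated as connected to every other node
gat.SetAdjacencyMatrix(adjacencyMatrix);
var predictions = gat.Predict(nodeFeatures);
// For classification: one row of class probabilities per node
```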
SerializeNetworkSpecificData(BinaryWriter)
Serializes network-specific data to a binary writer.
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
Parameters
- writer (BinaryWriter): The binary writer to serialize to.
SetAdjacencyMatrix(Tensor<T>)
Sets the adjacency matrix for graph operations.
public void SetAdjacencyMatrix(Tensor<T> adjacencyMatrix)
Parameters
- adjacencyMatrix (Tensor<T>): The adjacency matrix defining graph structure (shape [numNodes, numNodes]).
Train(Tensor<T>, Tensor<T>)
Trains the network on a single batch of data.
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)
Parameters
- input (Tensor<T>): The input node features.
- expectedOutput (Tensor<T>): The expected output (labels).
Remarks
For Beginners: This method performs one training step. For full training, call TrainOnGraph which handles multiple epochs and adjacency matrix setup. If no adjacency matrix has been set, a fully-connected adjacency matrix is generated for convenience. This means every node is treated as connected to every other node, which can hide the true graph structure unless you provide an explicit adjacency matrix.
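A minimal sketch of a manual epoch loop built on Train; TrainOnGraph wraps this pattern for you:

```csharp
gat.SetAdjacencyMatrix(adjacencyMatrix);  // provide the real graph structure
for (int epoch = 0; epoch < 200; epoch++)
{
    gat.Train(nodeFeatures, labels);      // one optimization step per call
}
```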
TrainOnGraph(Tensor<T>, Tensor<T>, Tensor<T>, bool[]?, int, double)
Trains the GAT network on graph-structured data.
public void TrainOnGraph(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix, Tensor<T> labels, bool[]? trainMask = null, int epochs = 200, double learningRate = 0.005)
Parameters
- nodeFeatures (Tensor<T>): Node feature tensor of shape [numNodes, inputFeatures].
- adjacencyMatrix (Tensor<T>): Adjacency matrix of shape [numNodes, numNodes].
- labels (Tensor<T>): Label tensor for supervised learning.
- trainMask (bool[]?): Optional boolean mask indicating which nodes to train on.
- epochs (int): Number of training epochs (default: 200).
- learningRate (double): Learning rate for optimization (default: 0.005).
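For example, a semi-supervised split where only a few nodes carry labels (the 140-node training mask mirrors the standard Cora split and is illustrative):

```csharp
// Only the first 140 nodes are used for supervision
var trainMask = new bool[numNodes];
for (int i = 0; i < 140; i++) trainMask[i] = true;

gat.TrainOnGraph(nodeFeatures, adjacencyMatrix, labels, trainMask,
                 epochs: 200, learningRate: 0.005);
```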
UpdateParameters(Vector<T>)
Updates the parameters of all layers in the network.
public override void UpdateParameters(Vector<T> parameters)
Parameters
- parameters (Vector<T>): A vector containing all parameters for the network.
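A sketch of round-tripping parameters with GetParameters, assuming Vector<T> exposes an indexer and Length; the scaling shown is purely illustrative:

```csharp
// Read the current parameters, modify them, and write them back
var parameters = gat.GetParameters();
for (int i = 0; i < parameters.Length; i++)
{
    parameters[i] = parameters[i] * 0.99;  // e.g. simple weight decay
}
gat.UpdateParameters(parameters);
```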