Class LinkPredictionModel<T>
Namespace: AiDotNet.NeuralNetworks.Tasks.Graph
Assembly: AiDotNet.dll
Implements a complete neural network model for link prediction tasks on graphs.
public class LinkPredictionModel<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable
Type Parameters
T: The numeric type used for calculations, typically float or double.
Inheritance: NeuralNetworkBase<T> → LinkPredictionModel<T>
Remarks
Link prediction predicts whether an edge should exist between a pair of nodes using:
- Node features
- Graph structure
- Learned node embeddings
For Beginners: This model predicts connections between nodes.
How it works:
1. Encode: Learn embeddings for all nodes using GNN layers.
   - Input: Node features + graph structure
   - Process: Stack of graph convolutional layers
   - Output: Node embeddings [num_nodes, embedding_dim]
2. Decode: Score node pairs to predict edges.
   - Input: Node pair (i, j)
   - Compute: score = f(z_i, z_j), where z_i = embedding[i]
   - Common choices of f:
     - Dot product: z_i · z_j
     - Concatenation + MLP: MLP([z_i || z_j])
     - Distance-based: -||z_i - z_j||^2
3. Train: Learn to score existing edges high and non-existing edges low.
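The three decoder functions listed above can be sketched on plain arrays as follows. This is a minimal illustration, not AiDotNet's implementation: the class name `EdgeDecoders` is hypothetical, embeddings are `double[]` rather than `Tensor<T>`, and the "MLP" decoder is reduced to a single linear layer whose weights `w` and bias `b` are supplied by the caller (a real MLP decoder would have hidden layers and learned weights).

```csharp
using System;

// Hypothetical helper illustrating common link prediction decoders.
public static class EdgeDecoders
{
    // Dot product: score = z_i · z_j
    public static double DotProduct(double[] zi, double[] zj)
    {
        if (zi.Length != zj.Length) throw new ArgumentException("Embedding sizes differ.");
        double sum = 0.0;
        for (int k = 0; k < zi.Length; k++) sum += zi[k] * zj[k];
        return sum;
    }

    // Distance-based: score = -||z_i - z_j||^2 (closer embeddings score higher).
    public static double NegativeSquaredDistance(double[] zi, double[] zj)
    {
        if (zi.Length != zj.Length) throw new ArgumentException("Embedding sizes differ.");
        double sum = 0.0;
        for (int k = 0; k < zi.Length; k++)
        {
            double d = zi[k] - zj[k];
            sum += d * d;
        }
        return -sum;
    }

    // Concatenation + single linear layer (simplified MLP): score = w · [z_i || z_j] + b
    public static double ConcatLinear(double[] zi, double[] zj, double[] w, double b)
    {
        if (w.Length != zi.Length + zj.Length)
            throw new ArgumentException("Weight size must equal the concatenated embedding size.");
        double sum = b;
        for (int k = 0; k < zi.Length; k++) sum += w[k] * zi[k];
        for (int k = 0; k < zj.Length; k++) sum += w[zi.Length + k] * zj[k];
        return sum;
    }

    // Logistic sigmoid maps any raw score to a probability in (0, 1).
    public static double Sigmoid(double x) => 1.0 / (1.0 + Math.Exp(-x));
}
```

The dot-product and distance decoders are symmetric in i and j, while the concatenation decoder generally is not, which can matter for directed graphs.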
Example:
Friend recommendation:
- Encode users as embeddings using friend network
- For user pair (Alice, Bob):
* Compute score from their embeddings
* High score -> Likely to be friends
* Low score -> Unlikely to be friends
Constructors
LinkPredictionModel(NeuralNetworkArchitecture<T>, int, int, int, double, LinkPredictionDecoder, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, ILossFunction<T>?, double)
Initializes a new instance of the LinkPredictionModel<T> class.
public LinkPredictionModel(NeuralNetworkArchitecture<T> architecture, int hiddenDim = 64, int embeddingDim = 32, int numLayers = 2, double dropoutRate = 0.5, LinkPredictionModel<T>.LinkPredictionDecoder decoderType = LinkPredictionDecoder.DotProduct, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? optimizer = null, ILossFunction<T>? lossFunction = null, double maxGradNorm = 1)
Parameters
architecture (NeuralNetworkArchitecture<T>): The neural network architecture defining input/output sizes and layers.
hiddenDim (int): Hidden dimension for intermediate layers (default: 64).
embeddingDim (int): Dimension of the node embeddings (default: 32).
numLayers (int): Number of graph convolutional layers (default: 2).
dropoutRate (double): Dropout rate for regularization (default: 0.5).
decoderType (LinkPredictionModel<T>.LinkPredictionDecoder): Method for combining node embeddings into edge scores (default: DotProduct).
optimizer (IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>): Optional optimizer for training (default: null).
lossFunction (ILossFunction<T>): Optional loss function for training (default: null).
maxGradNorm (double): Maximum gradient norm for clipping (default: 1.0).
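`maxGradNorm` bounds the size of each update. A common technique is clipping by global L2 norm: if the combined norm of all gradients exceeds the threshold, every component is scaled down by the same factor. The sketch below assumes that variant on a flat `double[]`; this page does not say whether AiDotNet clips globally or per layer, and `GradientClipping.ClipByGlobalNorm` is a hypothetical name.

```csharp
using System;

public static class GradientClipping
{
    // Scales the gradient in place so that its L2 norm is at most maxNorm.
    // Returns the norm measured before clipping (useful for logging).
    public static double ClipByGlobalNorm(double[] gradient, double maxNorm)
    {
        if (maxNorm <= 0) throw new ArgumentOutOfRangeException(nameof(maxNorm));
        double sumSquares = 0.0;
        foreach (double g in gradient) sumSquares += g * g;
        double norm = Math.Sqrt(sumSquares);
        if (norm > maxNorm)
        {
            double scale = maxNorm / norm;
            for (int i = 0; i < gradient.Length; i++) gradient[i] *= scale;
        }
        return norm;
    }
}
```

Scaling every component by one factor preserves the gradient's direction, which is why norm-based clipping is usually preferred over clamping each component independently.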
Remarks
For Beginners: Creating a link prediction model:
```csharp
// Create architecture for friend prediction
var architecture = new NeuralNetworkArchitecture<double>(
    InputType.OneDimensional,
    NeuralNetworkTaskType.BinaryClassification,
    NetworkComplexity.Simple,
    inputSize: 128,  // User features
    outputSize: 1);  // Edge score

// Create model with default layers
var model = new LinkPredictionModel<double>(architecture);

// Train on a link prediction task.
// `task` is a LinkPredictionTask<double> holding the graph data and edge splits.
var history = model.TrainOnTask(task, epochs: 100, learningRate: 0.01);

// Evaluate on the task's test edges
double testScore = model.EvaluateOnTask(task);
```
Properties
DropoutRate
Gets the dropout rate for regularization.
public double DropoutRate { get; }
Property Value: double
EmbeddingDim
Gets the embedding dimension.
public int EmbeddingDim { get; }
Property Value: int
HiddenDim
Gets the hidden dimension size.
public int HiddenDim { get; }
Property Value: int
InputFeatures
Gets the number of input features per node.
public int InputFeatures { get; }
Property Value: int
NumLayers
Gets the number of GNN layers.
public int NumLayers { get; }
Property Value: int
Methods
Backward(Tensor<T>)
Performs a backward pass through the network.
public Tensor<T> Backward(Tensor<T> outputGradient)
Parameters
outputGradient (Tensor<T>): Gradient of the loss with respect to the output.
Returns
- Tensor<T>
Gradient of the loss with respect to the input.
CreateNewInstance()
Creates a new instance of this network type for cloning or deserialization.
protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
DeserializeNetworkSpecificData(BinaryReader)
Deserializes network-specific data from a binary reader.
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
Parameters
reader (BinaryReader): The binary reader to read network-specific data from.
EvaluateOnTask(LinkPredictionTask<T>)
Evaluates the model on test edges.
public double EvaluateOnTask(LinkPredictionTask<T> task)
Parameters
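task (LinkPredictionTask<T>): The link prediction task whose test edges are evaluated.
Returns
- double
The evaluation score on the test edges.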
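The return value is a single `double`, but this page does not say which metric it represents. ROC-AUC is a common choice for link prediction: it is the probability that a randomly chosen true edge scores higher than a randomly chosen non-edge. A minimal sketch, assuming scores for held-out positive edges and sampled negative edges are already available (`LinkMetrics.RocAuc` is a hypothetical name):

```csharp
using System;

public static class LinkMetrics
{
    // ROC-AUC by pairwise comparison: the fraction of (positive, negative) pairs
    // in which the positive edge scores higher; ties count as half a win.
    // This is O(P * N); sort-based methods scale better for large test sets.
    public static double RocAuc(double[] positiveScores, double[] negativeScores)
    {
        if (positiveScores.Length == 0 || negativeScores.Length == 0)
            throw new ArgumentException("Need at least one positive and one negative score.");
        double wins = 0.0;
        foreach (double p in positiveScores)
        {
            foreach (double n in negativeScores)
            {
                if (p > n) wins += 1.0;
                else if (p == n) wins += 0.5;
            }
        }
        return wins / ((double)positiveScores.Length * negativeScores.Length);
    }
}
```

An AUC of 0.5 means the scores separate edges from non-edges no better than chance; 1.0 means every true edge outranks every non-edge.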
Forward(Tensor<T>)
Performs a forward pass through the encoder network.
public Tensor<T> Forward(Tensor<T> nodeFeatures)
Parameters
nodeFeatures (Tensor<T>): Node feature tensor [num_nodes, input_features].
Returns
- Tensor<T>
Node embeddings [num_nodes, embedding_dim].
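The encoder stacks graph convolutional layers, but this page does not say which convolution is used. As an assumption, the sketch below shows one layer in the style of the GCN of Kipf and Welling: H' = ReLU(Â H W), where Â = D^(-1/2) (A + I) D^(-1/2) and D is the degree matrix of A + I. `GcnSketch.Layer` is a hypothetical name, matrices are `double[,]`, and bias and dropout are omitted.

```csharp
using System;

public static class GcnSketch
{
    // One GCN-style layer: H' = ReLU(Â H W), with Â = D^(-1/2) (A + I) D^(-1/2).
    // adjacency: [n, n], features: [n, fIn], weights: [fIn, fOut].
    public static double[,] Layer(double[,] adjacency, double[,] features, double[,] weights)
    {
        int n = adjacency.GetLength(0);
        int fIn = features.GetLength(1);
        int fOut = weights.GetLength(1);
        if (adjacency.GetLength(1) != n || features.GetLength(0) != n || weights.GetLength(0) != fIn)
            throw new ArgumentException("Shape mismatch.");

        // Degrees of A + I (the 1.0 accounts for the added self-loop).
        var invSqrtDeg = new double[n];
        for (int i = 0; i < n; i++)
        {
            double deg = 1.0;
            for (int j = 0; j < n; j++) deg += adjacency[i, j];
            invSqrtDeg[i] = 1.0 / Math.Sqrt(deg);
        }

        // Linear transform: HW = H * W
        var hw = new double[n, fOut];
        for (int i = 0; i < n; i++)
            for (int k = 0; k < fIn; k++)
                for (int o = 0; o < fOut; o++)
                    hw[i, o] += features[i, k] * weights[k, o];

        // Aggregate neighbours (and self) with symmetric normalization.
        var output = new double[n, fOut];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
            {
                double a = adjacency[i, j] + (i == j ? 1.0 : 0.0);
                if (a == 0.0) continue;
                double norm = a * invSqrtDeg[i] * invSqrtDeg[j];
                for (int o = 0; o < fOut; o++) output[i, o] += norm * hw[j, o];
            }

        // ReLU activation.
        for (int i = 0; i < n; i++)
            for (int o = 0; o < fOut; o++)
                output[i, o] = Math.Max(0.0, output[i, o]);
        return output;
    }
}
```

Stacking `numLayers` such layers lets each node's embedding draw on neighbours up to that many hops away.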
GetModelMetadata()
Gets metadata about this model for serialization and identification.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
GetParameters()
Gets all parameters as a vector.
public override Vector<T> GetParameters()
Returns
- Vector<T>
InitializeLayers()
Initializes the layers of the neural network based on the provided architecture.
protected override void InitializeLayers()
Predict(Tensor<T>)
Makes a prediction using the trained network.
public override Tensor<T> Predict(Tensor<T> input)
Parameters
input (Tensor<T>): The input tensor containing node features.
Returns
- Tensor<T>
The node embeddings tensor.
PredictEdges(Tensor<T>)
Computes edge scores for given node pairs.
public Tensor<T> PredictEdges(Tensor<T> edges)
Parameters
edges (Tensor<T>): Edge tensor of shape [num_edges, 2], where each row is [source, target].
Returns
- Tensor<T>
Edge scores of shape [num_edges].
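Conceptually, `PredictEdges` is a gather-then-decode step: look up the embeddings of each row's source and target nodes, then apply the decoder. The sketch below assumes the dot-product decoder (the constructor default) and uses plain arrays in place of `Tensor<T>`, with node indices stored as `int`. `EdgeScoring.ScoreEdges` is a hypothetical name, and whether the library applies a sigmoid to the scores is not stated on this page.

```csharp
using System;

public static class EdgeScoring
{
    // embeddings: [num_nodes, embedding_dim]; edges: [num_edges, 2], rows are [source, target].
    // Returns one raw dot-product score per edge: scores[e] = z_source · z_target.
    public static double[] ScoreEdges(double[,] embeddings, int[,] edges)
    {
        if (edges.GetLength(1) != 2) throw new ArgumentException("edges must have shape [num_edges, 2].");
        int numNodes = embeddings.GetLength(0);
        int dim = embeddings.GetLength(1);
        int numEdges = edges.GetLength(0);
        var scores = new double[numEdges];
        for (int e = 0; e < numEdges; e++)
        {
            int src = edges[e, 0];
            int dst = edges[e, 1];
            if (src < 0 || src >= numNodes || dst < 0 || dst >= numNodes)
                throw new ArgumentOutOfRangeException(nameof(edges), $"Edge {e} references a missing node.");
            double s = 0.0;
            for (int k = 0; k < dim; k++) s += embeddings[src, k] * embeddings[dst, k];
            scores[e] = s;
        }
        return scores;
    }
}
```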
SerializeNetworkSpecificData(BinaryWriter)
Serializes network-specific data to a binary writer.
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
Parameters
writer (BinaryWriter): The binary writer to write network-specific data to.
SetAdjacencyMatrix(Tensor<T>)
Sets the adjacency matrix for all graph layers in the model.
public void SetAdjacencyMatrix(Tensor<T> adjacencyMatrix)
Parameters
adjacencyMatrix (Tensor<T>): The graph adjacency matrix.
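`SetAdjacencyMatrix` takes a dense adjacency matrix, whereas graph data often arrives as an edge list. A minimal sketch of that conversion, assuming an unweighted, undirected graph with nodes numbered 0 to numNodes - 1. `AdjacencyBuilder.FromEdges` is a hypothetical helper; copying the result into a `Tensor<T>` is left out because that API is not shown here, and whether the model expects self-loops or a pre-normalized matrix is not documented on this page, so the sketch adds neither.

```csharp
using System;

public static class AdjacencyBuilder
{
    // Builds a dense, symmetric 0/1 adjacency matrix [numNodes, numNodes] from (source, target) pairs.
    public static double[,] FromEdges(int numNodes, (int Source, int Target)[] edges)
    {
        var adjacency = new double[numNodes, numNodes];
        foreach (var (s, t) in edges)
        {
            if (s < 0 || s >= numNodes || t < 0 || t >= numNodes)
                throw new ArgumentOutOfRangeException(nameof(edges));
            adjacency[s, t] = 1.0;
            adjacency[t, s] = 1.0; // undirected graph: mirror each edge
        }
        return adjacency;
    }
}
```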
Train(Tensor<T>, Tensor<T>)
Trains the network on a single batch of data.
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)
Parameters
input (Tensor<T>): The input node features.
expectedOutput (Tensor<T>): The expected output (edge scores).
TrainOnTask(LinkPredictionTask<T>, int, double)
Trains the model on a link prediction task.
public Dictionary<string, List<double>> TrainOnTask(LinkPredictionTask<T> task, int epochs, double learningRate = 0.01)
Parameters
task (LinkPredictionTask<T>): The link prediction task with graph data and edge splits.
epochs (int): Number of training epochs.
learningRate (double): Learning rate for optimization (default: 0.01).
Returns
- Dictionary<string, List<double>>
Training history with loss and metrics per epoch.
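The training objective, scoring existing edges high and non-existing edges low, is commonly implemented as binary cross-entropy over positive edges and randomly sampled negative node pairs. The sketch below assumes that setup; this page does not say which loss or negative-sampling scheme `TrainOnTask` uses, and `LinkLoss` and its methods are hypothetical names.

```csharp
using System;
using System.Collections.Generic;

public static class LinkLoss
{
    // Numerically stable binary cross-entropy on a raw score (logit):
    // max(s, 0) - s * y + log(1 + exp(-|s|)), with y = 1 for an edge, 0 for a non-edge.
    public static double BceWithLogits(double score, double label)
    {
        return Math.Max(score, 0.0) - score * label + Math.Log(1.0 + Math.Exp(-Math.Abs(score)));
    }

    // Mean loss over positive-edge scores and negative-pair scores.
    public static double MeanLoss(double[] positiveScores, double[] negativeScores)
    {
        double total = 0.0;
        foreach (double s in positiveScores) total += BceWithLogits(s, 1.0);
        foreach (double s in negativeScores) total += BceWithLogits(s, 0.0);
        return total / (positiveScores.Length + negativeScores.Length);
    }

    // Uniform negative sampling: random (u, v) pairs that are neither self-loops nor existing
    // edges in either direction. The result may contain duplicates; maxAttempts guards
    // against looping forever on dense graphs.
    public static List<(int, int)> SampleNegatives(
        int numNodes, HashSet<(int, int)> existing, int count, Random rng, int maxAttempts = 100000)
    {
        var negatives = new List<(int, int)>();
        for (int attempt = 0; attempt < maxAttempts && negatives.Count < count; attempt++)
        {
            int u = rng.Next(numNodes);
            int v = rng.Next(numNodes);
            if (u == v || existing.Contains((u, v)) || existing.Contains((v, u))) continue;
            negatives.Add((u, v));
        }
        return negatives;
    }
}
```

Sampling fresh negatives each epoch is a common design choice: the model sees many different non-edges without materializing all O(n^2) of them.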
UpdateParameters(Vector<T>)
Updates the parameters of all layers in the network.
public override void UpdateParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): A vector containing all parameters for the network.