Class GraphGenerationModel<T>
- Namespace
- AiDotNet.NeuralNetworks
- Assembly
- AiDotNet.dll
Represents a graph generation model that uses a Variational Autoencoder (VAE) architecture.
public class GraphGenerationModel<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable
Type Parameters
T: The numeric type used for calculations, typically float or double.
- Inheritance
- GraphGenerationModel<T>
Remarks
Graph generation models learn to generate new graph structures from latent representations. This implementation uses a Variational Graph Autoencoder (VGAE) approach that learns a latent distribution of graph structures and can sample new valid graphs.
For Beginners: Graph generation creates new graphs that resemble the training data.
How it works:
- Encoder: Compress graph structure into latent space using GNN
- Latent space: Learn a probabilistic representation (mean and log-variance)
- Decoder: Reconstruct graph from latent representation
- Sampling: Generate new graphs by sampling from latent space
Example - Drug Discovery:
- Train on known drug molecules
- Learn latent representation of valid molecular structures
- Generate new candidate molecules by sampling
- Filter candidates by predicted properties
Key Components:
- GNN Encoder: Maps node features to latent space
- Variational Layer: Learns mean and log-variance for each node
- Inner Product Decoder: Reconstructs adjacency matrix
- Reparameterization: Enables gradient flow through sampling
Loss Function:
- Reconstruction Loss: Measures how well the adjacency matrix is reconstructed
- KL Divergence: Regularizes the latent distribution so the latent space stays well-structured
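The two terms can be written out as follows. This is a sketch of the standard VGAE objective, not a statement of what the library computes: it assumes a binary cross-entropy reconstruction term, a standard normal prior, and that the KL term is scaled by a weight $\beta$ (taken here to be klWeight). The symbols $A$ (original adjacency), $\hat{A}$ (reconstructed edge probabilities), $\mu$ and $\log\sigma^2$ (per-node latent mean and log-variance), $N$ (number of nodes) and $d$ (latent dimension) are introduced for illustration, and the library may normalize the sums differently.

```latex
\mathcal{L}
  = \underbrace{-\sum_{i=1}^{N}\sum_{j=1}^{N}
      \Bigl[ A_{ij}\,\log \hat{A}_{ij} + (1 - A_{ij})\,\log\bigl(1 - \hat{A}_{ij}\bigr) \Bigr]}_{\text{reconstruction loss}}
  \;+\; \beta \,
    \underbrace{\Bigl( -\tfrac{1}{2} \sum_{i=1}^{N}\sum_{k=1}^{d}
      \bigl( 1 + \log\sigma_{ik}^{2} - \mu_{ik}^{2} - \sigma_{ik}^{2} \bigr) \Bigr)}_{\text{KL divergence to } \mathcal{N}(0, I)}
```

With $\beta = 0$ the model behaves like a plain autoencoder; larger values pull the latent distribution closer to the prior at the cost of reconstruction accuracy.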
Applications:
- Molecular design and drug discovery
- Social network generation
- Circuit design
- Protein structure generation
Constructors
GraphGenerationModel(int, int, int, int, int, GraphGenerationType, double, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, ILossFunction<T>?, double)
Initializes a new instance of the GraphGenerationModel<T> class.
public GraphGenerationModel(int inputFeatures = 16, int hiddenDim = 32, int latentDim = 16, int numEncoderLayers = 2, int maxNodes = 100, GraphGenerationType generationType = GraphGenerationType.VariationalAutoencoder, double klWeight = 1, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? optimizer = null, ILossFunction<T>? lossFunction = null, double maxGradNorm = 1)
Parameters
inputFeatures (int): Number of input features per node (default: 16).
hiddenDim (int): Hidden dimension for encoder layers (default: 32).
latentDim (int): Dimension of latent space (default: 16).
numEncoderLayers (int): Number of GNN encoder layers (default: 2).
maxNodes (int): Maximum number of nodes for graph generation (default: 100).
generationType (GraphGenerationType): Type of graph generation approach (default: VariationalAutoencoder).
klWeight (double): Weight for KL divergence term (default: 1.0).
optimizer (IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>): Optional optimizer for training (default: AdamOptimizer).
lossFunction (ILossFunction<T>): Optional loss function for training (default: BinaryCrossEntropyLoss).
maxGradNorm (double): Maximum gradient norm for clipping (default: 1.0).
Remarks
For Beginners: Creating a graph generation model:
// Create a model with all defaults
var defaultModel = new GraphGenerationModel<double>();
// Create a model for molecular generation with custom settings
var model = new GraphGenerationModel<double>(
    inputFeatures: 9,      // Atom features
    hiddenDim: 32,         // Hidden layer size
    latentDim: 16,         // Latent space dimension
    numEncoderLayers: 2,   // 2 GNN encoder layers
    maxNodes: 50,          // Maximum 50 atoms per molecule
    klWeight: 0.5);        // KL divergence weight
// Train on molecular graphs (molecules and adjacencyMatrices are Tensor<double> values prepared by the caller)
model.Train(molecules, adjacencyMatrices, epochs: 100);
// Generate new molecules; passing threshold explicitly selects the List<Tensor<T>> overload
var newMolecules = model.Generate(numNodes: 20, numSamples: 10, threshold: 0.5);
Properties
GenerationType
Gets the type of graph generation.
public GraphGenerationType GenerationType { get; }
Property Value
- GraphGenerationType
HiddenDim
Gets the hidden dimension for encoder layers.
public int HiddenDim { get; }
Property Value
- int
KLWeight
Gets or sets the KL divergence weight used to balance reconstruction against regularization.
public double KLWeight { get; set; }
Property Value
- double
LatentDim
Gets the latent dimension size.
public int LatentDim { get; }
Property Value
- int
MaxNodes
Gets the maximum number of nodes for graph generation.
public int MaxNodes { get; }
Property Value
- int
NumEncoderLayers
Gets the number of encoder layers.
public int NumEncoderLayers { get; }
Property Value
- int
NumLayers
Gets the number of layers in the model.
public int NumLayers { get; }
Property Value
- int
Methods
Backward(Tensor<T>)
Performs backward pass through the model.
public Tensor<T> Backward(Tensor<T> outputGradient)
Parameters
outputGradient (Tensor<T>): Gradient of the loss with respect to the reconstructed adjacency matrix.
Returns
- Tensor<T>
Gradient with respect to input features.
ComputeLoss(Tensor<T>, Tensor<T>)
Computes the ELBO loss (reconstruction + KL divergence).
public T ComputeLoss(Tensor<T> reconstructed, Tensor<T> original)
Parameters
reconstructed (Tensor<T>): Reconstructed adjacency matrix.
original (Tensor<T>): Original adjacency matrix.
Returns
- T
Total ELBO loss value.
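A minimal numeric sketch of an ELBO-style loss may help make the two terms concrete. It is not the library's implementation: plain `double[,]` arrays stand in for Tensor<T>, the helper names (`ReconstructionLoss`, `KlDivergence`, `ElboLoss`) are hypothetical, and the choices of mean binary cross-entropy, a standard normal prior, and per-node averaging of the KL term are assumptions.

```csharp
using System;

// Hypothetical helpers, not part of AiDotNet: plain arrays stand in for Tensor<T>.

// Mean binary cross-entropy between edge probabilities and a 0/1 adjacency matrix.
static double ReconstructionLoss(double[,] reconstructed, double[,] original)
{
    const double tiny = 1e-12; // keeps Math.Log finite when a probability is exactly 0 or 1
    double total = 0.0;
    int rows = original.GetLength(0);
    int cols = original.GetLength(1);
    for (int i = 0; i < rows; i++)
    {
        for (int j = 0; j < cols; j++)
        {
            double p = Math.Min(Math.Max(reconstructed[i, j], tiny), 1.0 - tiny);
            total -= original[i, j] * Math.Log(p) + (1.0 - original[i, j]) * Math.Log(1.0 - p);
        }
    }
    return total / (rows * cols);
}

// KL divergence between N(mean, exp(logVar)) and the standard normal prior, averaged over nodes.
static double KlDivergence(double[,] mean, double[,] logVar)
{
    double total = 0.0;
    int numNodes = mean.GetLength(0);
    int latentDim = mean.GetLength(1);
    for (int i = 0; i < numNodes; i++)
        for (int k = 0; k < latentDim; k++)
            total += 1.0 + logVar[i, k] - mean[i, k] * mean[i, k] - Math.Exp(logVar[i, k]);
    return -0.5 * total / numNodes;
}

// Total loss: reconstruction term plus the weighted KL term.
static double ElboLoss(double[,] reconstructed, double[,] original,
                       double[,] mean, double[,] logVar, double klWeight)
    => ReconstructionLoss(reconstructed, original) + klWeight * KlDivergence(mean, logVar);

var target = new double[,] { { 0.0, 1.0 }, { 1.0, 0.0 } };
var predicted = new double[,] { { 0.1, 0.8 }, { 0.8, 0.1 } };
var mu = new double[,] { { 0.2, -0.1 }, { 0.0, 0.3 } };
var logSigmaSq = new double[,] { { 0.0, 0.0 }, { 0.0, 0.0 } };
Console.WriteLine(ElboLoss(predicted, target, mu, logSigmaSq, klWeight: 0.5));
```

Note that the documented ComputeLoss takes only the reconstructed and original adjacency matrices; the sketch passes the latent statistics explicitly so that it stays self-contained.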
CreateNewInstance()
Creates a new instance of this model type.
protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
Decode(Tensor<T>)
Decodes latent representations to reconstruct the adjacency matrix.
public Tensor<T> Decode(Tensor<T> latent)
Parameters
latent (Tensor<T>): Latent representation tensor of shape [numNodes, latentDim].
Returns
- Tensor<T>
Reconstructed adjacency matrix of shape [numNodes, numNodes].
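The class remarks name an inner product decoder. In the usual VGAE formulation the probability of an edge between nodes i and j is the sigmoid of the inner product of their latent vectors. The sketch below illustrates that formulation under stated assumptions: plain `double[,]` arrays stand in for Tensor<T>, `InnerProductDecode` is a hypothetical helper, and whether the library applies the sigmoid inside Decode is not stated in this page.

```csharp
using System;

// Hypothetical helper, not part of AiDotNet: plain arrays stand in for Tensor<T>.
// latent has shape [numNodes, latentDim]; the result has shape [numNodes, numNodes].
static double[,] InnerProductDecode(double[,] latent)
{
    int numNodes = latent.GetLength(0);
    int latentDim = latent.GetLength(1);
    var probabilities = new double[numNodes, numNodes];

    for (int i = 0; i < numNodes; i++)
    {
        for (int j = i; j < numNodes; j++)
        {
            // Inner product of the latent vectors of nodes i and j.
            double dot = 0.0;
            for (int k = 0; k < latentDim; k++)
            {
                dot += latent[i, k] * latent[j, k];
            }

            // Sigmoid maps the score to an edge probability in (0, 1).
            double p = 1.0 / (1.0 + Math.Exp(-dot));
            probabilities[i, j] = p;
            probabilities[j, i] = p; // the inner product is symmetric
        }
    }

    return probabilities;
}

// Two nodes with 2-dimensional latent vectors.
var z = new double[,] { { 0.0, 0.0 }, { 1.0, 0.0 } };
var adjacency = InnerProductDecode(z);
Console.WriteLine(adjacency[1, 1]); // sigmoid(1), roughly 0.731
```

Because the decoder has no parameters of its own in this formulation, all learned structure lives in the latent vectors produced by the encoder.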
DeserializeNetworkSpecificData(BinaryReader)
Deserializes network-specific data from a binary reader.
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
Parameters
reader (BinaryReader)
Encode(Tensor<T>, Tensor<T>)
Encodes node features into latent space representations.
public (Tensor<T> mean, Tensor<T> logVar) Encode(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix)
Parameters
nodeFeatures (Tensor<T>): Node feature tensor of shape [numNodes, inputFeatures].
adjacencyMatrix (Tensor<T>): Adjacency matrix of shape [numNodes, numNodes].
Returns
- (Tensor<T> mean, Tensor<T> logVar)
Tuple of (mean, log_variance) tensors for the latent distribution.
Forward(Tensor<T>, Tensor<T>)
Performs a complete forward pass: encode, sample, decode.
public Tensor<T> Forward(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix)
Parameters
nodeFeatures (Tensor<T>): Node feature tensor.
adjacencyMatrix (Tensor<T>): Adjacency matrix.
Returns
- Tensor<T>
Reconstructed adjacency matrix.
Generate(int, int, Tensor<T>?, double)
Generates new graphs by sampling from the latent space with optional conditioning.
public Tensor<T> Generate(int numSamples, int numNodes, Tensor<T>? conditioningInput = null, double threshold = 0.5)
Parameters
numSamples (int): Number of graphs to generate.
numNodes (int): Number of nodes in generated graphs.
conditioningInput (Tensor<T>): Optional conditioning input tensor (default: null).
threshold (double): Edge probability threshold (default: 0.5).
Returns
- Tensor<T>
Generated adjacency matrix tensor.
Generate(int, int, double)
Generates new graphs by sampling from the latent space.
public List<Tensor<T>> Generate(int numNodes, int numSamples = 1, double threshold = 0.5)
Parameters
numNodes (int): Number of nodes in generated graphs.
numSamples (int): Number of graphs to generate (default: 1).
threshold (double): Edge probability threshold (default: 0.5).
Returns
- List<Tensor<T>>
List of generated adjacency matrices.
Remarks
For Beginners: Generating new graphs:
// Generate 10 new molecular graphs with 20 atoms each
var newGraphs = model.Generate(numNodes: 20, numSamples: 10, threshold: 0.5);
// Each graph is an adjacency matrix where 1 indicates an edge
foreach (var adj in newGraphs)
{
// Process generated molecule structure
}
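The threshold parameter controls how decoded edge probabilities become a binary adjacency matrix. The following sketch shows one way that step could look. It is an illustration only: `ThresholdEdges` is a hypothetical helper, plain arrays stand in for Tensor<T>, and the strict `>` comparison is an assumption (the library may use `>=`).

```csharp
using System;

// Hypothetical helper, not part of AiDotNet: plain arrays stand in for Tensor<T>.
// Entries strictly above the threshold become edges (1); all others become 0.
static int[,] ThresholdEdges(double[,] edgeProbabilities, double threshold = 0.5)
{
    int rows = edgeProbabilities.GetLength(0);
    int cols = edgeProbabilities.GetLength(1);
    var adjacencyMatrix = new int[rows, cols];
    for (int i = 0; i < rows; i++)
        for (int j = 0; j < cols; j++)
            adjacencyMatrix[i, j] = edgeProbabilities[i, j] > threshold ? 1 : 0;
    return adjacencyMatrix;
}

var probs = new double[,] { { 0.9, 0.2 }, { 0.2, 0.7 } };
var edges = ThresholdEdges(probs, threshold: 0.5);
Console.WriteLine($"{edges[0, 0]} {edges[0, 1]} {edges[1, 1]}"); // prints "1 0 1"
```

Raising the threshold yields sparser graphs, while lowering it yields denser ones.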
GetModelMetadata()
Gets metadata about this model.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
GetParameterCount()
Gets the total number of trainable parameters in the network.
public int GetParameterCount()
Returns
- int
GetParameters()
Gets all parameters as a vector.
public override Vector<T> GetParameters()
Returns
- Vector<T>
InitializeLayers()
Initializes the layers of the neural network based on the provided architecture.
protected override void InitializeLayers()
Interpolate(Tensor<T>, Tensor<T>, Tensor<T>, Tensor<T>, int)
Interpolates between two graphs in latent space.
public List<Tensor<T>> Interpolate(Tensor<T> graph1Features, Tensor<T> graph1Adj, Tensor<T> graph2Features, Tensor<T> graph2Adj, int numSteps = 5)
Parameters
graph1Features (Tensor<T>): Node features of first graph.
graph1Adj (Tensor<T>): Adjacency matrix of first graph.
graph2Features (Tensor<T>): Node features of second graph.
graph2Adj (Tensor<T>): Adjacency matrix of second graph.
numSteps (int): Number of interpolation steps (default: 5).
Returns
- List<Tensor<T>>
List of interpolated adjacency matrices.
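Latent-space interpolation is commonly done by blending the two latent representations linearly and decoding each blend. The sketch below covers only the blending step, under stated assumptions: `InterpolateLatents` is a hypothetical helper, plain arrays stand in for Tensor<T>, both graphs are assumed to have the same number of nodes, and the steps are assumed to be evenly spaced with both endpoints included. The library's actual spacing is not documented here.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper, not part of AiDotNet: plain arrays stand in for Tensor<T>.
// Both latent matrices must have the same shape [numNodes, latentDim].
static List<double[,]> InterpolateLatents(double[,] start, double[,] end, int numSteps)
{
    if (numSteps < 2)
        throw new ArgumentOutOfRangeException(nameof(numSteps), "At least two steps are required.");
    if (start.GetLength(0) != end.GetLength(0) || start.GetLength(1) != end.GetLength(1))
        throw new ArgumentException("Latent matrices must have the same shape.");

    int rows = start.GetLength(0);
    int cols = start.GetLength(1);
    var steps = new List<double[,]>(numSteps);

    for (int s = 0; s < numSteps; s++)
    {
        // t runs from 0 (first graph) to 1 (second graph), endpoints included.
        double t = (double)s / (numSteps - 1);
        var blended = new double[rows, cols];
        for (int i = 0; i < rows; i++)
            for (int j = 0; j < cols; j++)
                blended[i, j] = (1.0 - t) * start[i, j] + t * end[i, j];
        steps.Add(blended);
    }

    return steps;
}

var latentA = new double[,] { { 0.0, 0.0 } };
var latentB = new double[,] { { 2.0, 4.0 } };
var path = InterpolateLatents(latentA, latentB, numSteps: 5);
Console.WriteLine(path[2][0, 1]); // midpoint of 0 and 4: prints 2
```

Each blended latent matrix would then be passed through the decoder to obtain one adjacency matrix per step.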
Predict(Tensor<T>)
Makes a prediction (generates a graph) using the trained model.
public override Tensor<T> Predict(Tensor<T> input)
Parameters
input (Tensor<T>)
Returns
- Tensor<T>
Reparameterize(Tensor<T>, Tensor<T>)
Samples from the latent distribution using the reparameterization trick.
public Tensor<T> Reparameterize(Tensor<T> mean, Tensor<T> logVar)
Parameters
mean (Tensor<T>): Mean of the latent distribution.
logVar (Tensor<T>): Log-variance of the latent distribution.
Returns
- Tensor<T>
Sampled latent representation.
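The reparameterization trick rewrites a sample from N(mean, variance) as a deterministic function of the parameters plus independent noise: z = mean + exp(0.5 * logVar) * epsilon, with epsilon drawn from a standard normal. Gradients can then flow through mean and logVar. The sketch below illustrates the arithmetic only: `Reparameterize` and `SampleStandardNormal` are hypothetical helpers over plain arrays, and the Box-Muller sampler is a swapped-in choice, not necessarily what the library uses.

```csharp
using System;

// Hypothetical helpers, not part of AiDotNet: plain arrays stand in for Tensor<T>.

// Draws one standard normal sample with the Box-Muller transform.
static double SampleStandardNormal(Random rng)
{
    double u1 = 1.0 - rng.NextDouble(); // in (0, 1], avoids Math.Log(0)
    double u2 = rng.NextDouble();
    return Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Cos(2.0 * Math.PI * u2);
}

// z = mean + exp(0.5 * logVar) * epsilon, applied element-wise.
static double[] Reparameterize(double[] mean, double[] logVar, double[] epsilon)
{
    var z = new double[mean.Length];
    for (int i = 0; i < mean.Length; i++)
    {
        double std = Math.Exp(0.5 * logVar[i]); // log-variance to standard deviation
        z[i] = mean[i] + std * epsilon[i];
    }
    return z;
}

var random = new Random(42);
var mu = new[] { 0.5, -1.0 };
var logSigmaSq = new[] { 0.0, Math.Log(4.0) }; // standard deviations 1 and 2
var noise = new[] { SampleStandardNormal(random), SampleStandardNormal(random) };
var sample = Reparameterize(mu, logSigmaSq, noise);
Console.WriteLine(string.Join(", ", sample));
```

Passing the noise in as an argument keeps the function deterministic, which makes the computation easy to check by hand.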
SerializeNetworkSpecificData(BinaryWriter)
Serializes network-specific data to a binary writer.
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
Parameters
writer (BinaryWriter)
Train(Tensor<T>, Tensor<T>)
Trains the model on a single batch of data.
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)
Parameters
input (Tensor<T>)
expectedOutput (Tensor<T>)
Train(Tensor<T>, Tensor<T>, int, double)
Trains the model on graph data.
public void Train(Tensor<T> nodeFeatures, Tensor<T> adjacencyMatrix, int epochs = 200, double learningRate = 0.01)
Parameters
nodeFeatures (Tensor<T>): Node feature tensor.
adjacencyMatrix (Tensor<T>): Adjacency matrix (target for reconstruction).
epochs (int): Number of training epochs (default: 200).
learningRate (double): Learning rate (default: 0.01).
UpdateParameters(Vector<T>)
Updates the parameters of all layers in the network.
public override void UpdateParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): A vector containing all parameters for the network.