Class SpiralNet<T>
- Namespace
- AiDotNet.NeuralNetworks
- Assembly
- AiDotNet.dll
Implements the SpiralNet++ architecture for mesh-based deep learning.
public class SpiralNet<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable
Type Parameters
- T
The numeric type used for calculations (typically float or double).
- Inheritance
- NeuralNetworkBase<T>
SpiralNet<T>
Remarks
SpiralNet++ processes 3D meshes by applying convolutions along spiral sequences of vertex neighbors. This creates translation-equivariant operations on irregular mesh structures without requiring mesh registration.
For Beginners: SpiralNet++ is designed for learning from 3D mesh data.
Key concepts:
- Mesh: A 3D surface made of vertices connected by edges/triangles
- Spiral ordering: A consistent way to visit vertex neighbors (like a clock hand)
- Spiral convolution: Apply weights to neighbors in spiral order
How it works:
- For each vertex, define a spiral ordering of its neighbors
- Gather neighbor features in spiral order
- Apply learned weights to the gathered features
- Pool vertices to create hierarchical representations
- Classify or segment the mesh
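The gather and weighting steps above can be sketched in a few lines. This is a minimal illustration on plain `double` arrays, not the library's implementation: the class `SpiralConvSketch`, its `Apply` method, and the weight layout `[spiralLength * inChannels, outChannels]` are all assumptions made for the example. Pooling and classification are left out.

```csharp
using System;

// Hypothetical helper for illustration only; not part of AiDotNet.
public static class SpiralConvSketch
{
    // features:      [numVertices, inChannels]
    // spiralIndices: [numVertices, spiralLength], neighbor indices in spiral order
    // weights:       [spiralLength * inChannels, outChannels] (assumed layout)
    // bias:          [outChannels]
    public static double[,] Apply(double[,] features, int[,] spiralIndices,
                                  double[,] weights, double[] bias)
    {
        int numVertices = spiralIndices.GetLength(0);
        int spiralLength = spiralIndices.GetLength(1);
        int inChannels = features.GetLength(1);
        int outChannels = weights.GetLength(1);

        if (weights.GetLength(0) != spiralLength * inChannels)
            throw new ArgumentException("weights must have spiralLength * inChannels rows.");
        if (bias.Length != outChannels)
            throw new ArgumentException("bias must have one entry per output channel.");

        var output = new double[numVertices, outChannels];
        for (int v = 0; v < numVertices; v++)
        {
            for (int o = 0; o < outChannels; o++)
            {
                double sum = bias[o];
                // Visit the neighbors of vertex v in spiral order and
                // weight each gathered feature by its spiral position.
                for (int s = 0; s < spiralLength; s++)
                {
                    int neighbor = spiralIndices[v, s];
                    for (int c = 0; c < inChannels; c++)
                        sum += features[neighbor, c] * weights[s * inChannels + c, o];
                }
                output[v, o] = sum;
            }
        }
        return output;
    }
}
```

Because the weight applied to a neighbor depends only on its position in the spiral, the same weights are shared by every vertex, which is what makes the spiral ordering need to be consistent across the mesh.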
Applications:
- 3D face reconstruction and expression recognition
- Human body shape analysis
- Medical surface analysis (organs, bones)
- CAD model classification
Reference: "SpiralNet++: A Fast and Highly Efficient Mesh Convolution Operator" by Gong et al.
Constructors
SpiralNet()
Initializes a new instance of the SpiralNet<T> class with default options.
public SpiralNet()
Remarks
Creates a SpiralNet with default configuration suitable for common mesh tasks.
SpiralNet(SpiralNetOptions, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, ILossFunction<T>?)
Initializes a new instance of the SpiralNet<T> class with specified options.
public SpiralNet(SpiralNetOptions options, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? optimizer = null, ILossFunction<T>? lossFunction = null)
Parameters
- options SpiralNetOptions
Configuration options for the SpiralNet.
- optimizer IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>
The optimizer for training. Defaults to Adam if null.
- lossFunction ILossFunction<T>
The loss function. Defaults based on task type if null.
Exceptions
- ArgumentNullException
Thrown when options is null.
SpiralNet(int, int, int, ILossFunction<T>?)
Initializes a new instance of the SpiralNet<T> class with simple parameters.
public SpiralNet(int numClasses, int inputFeatures = 3, int spiralLength = 9, ILossFunction<T>? lossFunction = null)
Parameters
- numClasses int
Number of output classes for classification.
- inputFeatures int
Number of input features per vertex. Default is 3.
- spiralLength int
Length of spiral sequences. Default is 9.
- lossFunction ILossFunction<T>
The loss function. Defaults based on task type if null.
Properties
ConvChannels
Gets the channel configuration for spiral convolution layers.
public int[] ConvChannels { get; }
Property Value
- int[]
InputFeatures
Gets the number of input features per vertex.
public int InputFeatures { get; }
Property Value
- int
NumClasses
Gets the number of output classes for classification.
public int NumClasses { get; }
Property Value
- int
SpiralLength
Gets the spiral sequence length.
public int SpiralLength { get; }
Property Value
- int
Methods
Backward(Tensor<T>)
Performs a backward pass to compute gradients.
public Tensor<T> Backward(Tensor<T> lossGradient)
Parameters
- lossGradient Tensor<T>
Gradient of the loss with respect to network output.
Returns
- Tensor<T>
Gradient with respect to input.
CreateNewInstance()
Creates a new instance for cloning.
protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
New SpiralNet instance.
Remarks
For Beginners: This creates a blank version of the same type of neural network.
It's used internally by methods like DeepCopy and Clone to create the right type of network before copying the data into it.
DeserializeNetworkSpecificData(BinaryReader)
Deserializes network-specific data.
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
Parameters
- reader BinaryReader
Binary reader.
Remarks
This method is called at the end of the general deserialization process to allow derived classes to read any additional data specific to their implementation.
For Beginners: In the suitcase analogy used for SerializeNetworkSpecificData(BinaryWriter), this is like unpacking the special compartment. After the main deserialization method has unpacked the common items (layers, parameters), this method lets each specific type of neural network unpack the unique items it stored during serialization.
Forward(Tensor<T>)
Performs a forward pass through the network.
public Tensor<T> Forward(Tensor<T> input)
Parameters
- input Tensor<T>
Vertex features tensor with shape [numVertices, InputFeatures].
Returns
- Tensor<T>
Classification logits with shape [NumClasses].
Exceptions
- InvalidOperationException
Thrown when spiral indices have not been set (see SetSpiralIndices(int[,])).
GetModelMetadata()
Gets metadata about this model.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
Model metadata.
InitializeLayers()
Initializes the layers of the SpiralNet network.
protected override void InitializeLayers()
Remarks
If the architecture provides custom layers, those are used. Otherwise, default layers are created using CreateDefaultSpiralNetLayers(NeuralNetworkArchitecture<T>, int, int, int[]?, double[]?, int[]?, bool, double, bool).
Predict(Tensor<T>)
Generates predictions for the given input.
public override Tensor<T> Predict(Tensor<T> input)
Parameters
- input Tensor<T>
Vertex features tensor.
Returns
- Tensor<T>
Classification logits.
Remarks
For Beginners: This is the main method you'll use to get results from your trained neural network. You provide input data (for SpiralNet, a tensor of per-vertex mesh features), and the network processes it through all its layers to produce an output (here, classification logits).
PredictClass(Tensor<T>, int[,])
Predicts the class for a single mesh.
public int PredictClass(Tensor<T> meshFeatures, int[,] meshSpiralIndices)
Parameters
- meshFeatures Tensor<T>
Vertex features tensor.
- meshSpiralIndices int[,]
Spiral indices for the mesh.
Returns
- int
Predicted class index.
PredictProbabilities(Tensor<T>, int[,])
Computes class probabilities for a single mesh using softmax.
public Vector<T> PredictProbabilities(Tensor<T> meshFeatures, int[,] meshSpiralIndices)
Parameters
- meshFeatures Tensor<T>
Vertex features tensor.
- meshSpiralIndices int[,]
Spiral indices for the mesh.
Returns
- Vector<T>
Probability distribution over classes.
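To make the relationship between logits, probabilities, and the predicted class concrete, here is a minimal sketch on plain `double` arrays. The class `SoftmaxSketch` and its methods are hypothetical names for this example. Subtracting the maximum logit before exponentiating is a common numerical-stability choice; this page does not say whether the library does the same.

```csharp
using System;
using System.Linq;

// Hypothetical helper for illustration only; not part of AiDotNet.
public static class SoftmaxSketch
{
    // Converts raw logits into a probability distribution that sums to 1.
    public static double[] Softmax(double[] logits)
    {
        if (logits == null || logits.Length == 0)
            throw new ArgumentException("logits must contain at least one value.");

        // Shifting by the maximum avoids overflow in Math.Exp for large logits.
        double max = logits.Max();
        double[] exps = logits.Select(z => Math.Exp(z - max)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }

    // Returns the index of the largest value (first index on ties).
    public static int ArgMax(double[] values)
    {
        if (values == null || values.Length == 0)
            throw new ArgumentException("values must contain at least one value.");

        int best = 0;
        for (int i = 1; i < values.Length; i++)
        {
            if (values[i] > values[best]) best = i;
        }
        return best;
    }
}
```

Since softmax preserves the ordering of its inputs, taking the arg max of the logits and of the probabilities gives the same class index.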
SerializeNetworkSpecificData(BinaryWriter)
Serializes network-specific data.
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
Parameters
- writer BinaryWriter
Binary writer.
Remarks
This method is called at the end of the general serialization process to allow derived classes to write any additional data specific to their implementation.
For Beginners: Think of this as packing a special compartment in your suitcase. While the main serialization method packs the common items (layers, parameters), this method allows each specific type of neural network to pack its own unique items that other networks might not have.
SetMultiResolutionSpiralIndices(List<int[,]>)
Sets spiral indices for multiple resolution levels (for hierarchical processing).
public void SetMultiResolutionSpiralIndices(List<int[,]> spiralIndicesPerLevel)
Parameters
- spiralIndicesPerLevel List<int[,]>
Spiral indices for each resolution level.
Exceptions
- ArgumentNullException
Thrown when list is null.
- ArgumentException
Thrown when list is empty.
SetSpiralIndices(int[,])
Sets the spiral indices for the current mesh being processed.
public void SetSpiralIndices(int[,] spiralIndices)
Parameters
- spiralIndices int[,]
A 2D array of shape [numVertices, SpiralLength] containing neighbor vertex indices in spiral order for each vertex.
Remarks
For Beginners: Before processing a mesh, you must define how vertices are connected in spiral order. This method sets that connectivity.
Exceptions
- ArgumentNullException
Thrown when spiralIndices is null.
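The only check documented here is the null check, so it can be worth validating the array yourself before handing it over. The sketch below checks the expected [numVertices, SpiralLength] shape and that every entry is a valid vertex index. `SpiralIndexChecks` and `Validate` are hypothetical names for this example, and the library may or may not perform similar checks internally.

```csharp
using System;

// Hypothetical helper for illustration only; not part of AiDotNet.
public static class SpiralIndexChecks
{
    // Throws if spiralIndices is not [numVertices, spiralLength]
    // or contains an index outside 0..numVertices-1.
    public static void Validate(int[,] spiralIndices, int numVertices, int spiralLength)
    {
        if (spiralIndices == null)
            throw new ArgumentNullException(nameof(spiralIndices));

        if (spiralIndices.GetLength(0) != numVertices ||
            spiralIndices.GetLength(1) != spiralLength)
        {
            throw new ArgumentException(
                $"Expected shape [{numVertices}, {spiralLength}] but got " +
                $"[{spiralIndices.GetLength(0)}, {spiralIndices.GetLength(1)}].");
        }

        for (int v = 0; v < numVertices; v++)
        {
            for (int s = 0; s < spiralLength; s++)
            {
                int index = spiralIndices[v, s];
                if (index < 0 || index >= numVertices)
                    throw new ArgumentException(
                        $"Entry [{v}, {s}] = {index} is not a valid vertex index.");
            }
        }
    }
}
```

Calling such a check with the mesh's vertex count and the network's SpiralLength before SetSpiralIndices turns an out-of-range lookup during the forward pass into an early, descriptive error.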
Train(Tensor<T>, Tensor<T>)
Trains the network on a single batch.
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)
Parameters
- input Tensor<T>
Vertex features tensor.
- expectedOutput Tensor<T>
Ground truth labels.
Remarks
This method performs one training step on the neural network using the provided input and expected output. It updates the network's parameters to reduce the error between the network's prediction and the expected output.
For Beginners: This is how your neural network learns. You provide:
- An input (what the network should process)
- The expected output (what the correct answer should be)
The network then:
- Makes a prediction based on the input
- Compares its prediction to the expected output
- Calculates how wrong it was (the loss)
- Adjusts its internal values to do better next time
After training, you can get the loss value using the GetLastLoss() method to see how well the network is learning.
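The four steps above can be seen in miniature on a one-weight model. This sketch uses a toy model y = w * x with squared-error loss and plain gradient descent. It is an assumption-laden illustration of a training step in general, not of how SpiralNet or its optimizer computes updates, and `TrainingStepSketch` is a hypothetical name.

```csharp
using System;

// Hypothetical helper for illustration only; not part of AiDotNet.
public static class TrainingStepSketch
{
    // Performs one gradient-descent step for the toy model y = w * x.
    public static (double NewWeight, double Loss) Step(
        double w, double x, double target, double learningRate)
    {
        double prediction = w * x;          // 1. make a prediction
        double error = prediction - target; // 2. compare with the expected output
        double loss = error * error;        // 3. squared-error loss
        double gradient = 2.0 * error * x;  //    d(loss)/d(w)
        double newWeight = w - learningRate * gradient; // 4. adjust the weight
        return (newWeight, loss);
    }
}
```

Repeating the step with the returned weight drives the loss toward zero as long as the learning rate is small enough for the problem.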
Train(List<Tensor<T>>, List<int[,]>, List<int>, int, T)
Trains the network on mesh data.
public List<double> Train(List<Tensor<T>> meshFeatures, List<int[,]> spiralIndices, List<int> labels, int epochs, T learningRate)
Parameters
- meshFeatures List<Tensor<T>>
List of vertex feature tensors for training meshes.
- spiralIndices List<int[,]>
List of spiral indices for each training mesh.
- labels List<int>
List of class labels for each mesh.
- epochs int
Number of training epochs.
- learningRate T
Learning rate for optimization.
Returns
- List<double>
Exceptions
- ArgumentException
Thrown when input lists have mismatched lengths.
UpdateParameters(Vector<T>)
Updates network parameters using a flat parameter vector.
public override void UpdateParameters(Vector<T> parameters)
Parameters
- parameters Vector<T>
Vector containing all parameters.
Remarks
For Beginners: During training, a neural network's internal values (parameters) get adjusted to improve its performance. This method allows you to update all those values at once by providing a complete set of new parameters.
This is typically used by optimization algorithms that calculate better parameter values based on training data.
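As a rough picture of what "a flat parameter vector" means, the sketch below copies consecutive slices of one flat array into per-layer arrays. The layer-by-layer ordering, the class `FlatParameterSketch`, and the `Scatter` method are assumptions for illustration. This page does not specify how the library orders parameters inside the vector.

```csharp
using System;
using System.Linq;

// Hypothetical helper for illustration only; not part of AiDotNet.
public static class FlatParameterSketch
{
    // Copies consecutive slices of 'flat' into each layer's parameter array,
    // in the order the layers appear (assumed ordering).
    public static void Scatter(double[] flat, double[][] layerParameters)
    {
        if (flat == null) throw new ArgumentNullException(nameof(flat));
        if (layerParameters == null) throw new ArgumentNullException(nameof(layerParameters));

        int expected = layerParameters.Sum(layer => layer.Length);
        if (flat.Length != expected)
            throw new ArgumentException(
                $"Expected {expected} parameters but got {flat.Length}.");

        int offset = 0;
        foreach (double[] layer in layerParameters)
        {
            Array.Copy(flat, offset, layer, 0, layer.Length);
            offset += layer.Length;
        }
    }
}
```

The length check matters in practice: a vector produced for a network with a different architecture would otherwise be silently split at the wrong boundaries.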
UpdateParameters(T)
Updates network parameters using the optimizer.
public void UpdateParameters(T learningRate)
Parameters
- learningRate T
Learning rate for parameter updates.