Class MeshPoolLayer<T>
- Namespace: AiDotNet.NeuralNetworks.Layers
- Assembly: AiDotNet.dll
Implements mesh pooling via edge collapse for MeshCNN-style networks.
public class MeshPoolLayer<T> : LayerBase<T>, ILayer<T>, IJitCompilable<T>, IDiagnosticsProvider, IWeightLoadable<T>, IDisposable
Type Parameters
T: The numeric type used for calculations, typically float or double.
- Inheritance: LayerBase<T> → MeshPoolLayer<T>
- Implements: ILayer<T>, IJitCompilable<T>, IDiagnosticsProvider, IWeightLoadable<T>, IDisposable
Remarks
MeshPoolLayer reduces the number of edges in a mesh by collapsing edges based on learned importance scores. This is analogous to pooling in image CNNs but operates on the mesh structure.
For Beginners: Just like max pooling shrinks an image by combining pixels, mesh pooling shrinks a mesh by combining edges. The layer learns which edges are less important and removes them, simplifying the mesh while preserving important features.
Key concepts:
- Edge collapse: Remove an edge by merging its two vertices into one
- Importance score: Learned value indicating how important each edge is
- Target edges: Number of edges to keep after pooling
The process:
- Compute importance scores for all edges using current features
- Sort edges by importance (lowest first)
- Collapse least important edges until target count is reached
- Update adjacency information for remaining edges
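The scoring and selection steps above can be sketched on the CPU in plain C#. This is a minimal sketch, not the layer's implementation. It assumes each edge's score is the dot product of its feature row with one weight per channel (matching the `input @ importanceWeights` step listed under ForwardGpu). It also keeps the surviving edges in their original order. The helper names `ScoreEdges` and `SelectKeptEdges` are hypothetical. The sketch only decides which edges survive and does not merge vertices.

```csharp
using System;
using System.Linq;

// Hypothetical sketch: score[e] = dot(features[e, :], weights), one weight per channel.
static double[] ScoreEdges(double[,] features, double[] weights)
{
    int numEdges = features.GetLength(0);
    int channels = features.GetLength(1);
    if (weights.Length != channels)
        throw new ArgumentException("Expected one weight per input channel.", nameof(weights));

    var scores = new double[numEdges];
    for (int e = 0; e < numEdges; e++)
    {
        double sum = 0.0;
        for (int c = 0; c < channels; c++)
            sum += features[e, c] * weights[c];
        scores[e] = sum;
    }
    return scores;
}

// Sort edges by score (lowest first), discard the lowest-scoring edges until
// targetEdges remain, and return the survivors in their original order
// (keeping original order is an assumption of this sketch).
static int[] SelectKeptEdges(double[] scores, int targetEdges)
{
    int numEdges = scores.Length;
    int keep = Math.Min(targetEdges, numEdges);
    int[] lowestFirst = Enumerable.Range(0, numEdges)
        .OrderBy(i => scores[i]) // stable sort: ties keep their original order
        .ToArray();
    return lowestFirst.Skip(numEdges - keep).OrderBy(i => i).ToArray();
}
```

With features {1, 0}, {0, 1}, {2, 2}, {-1, 0} and weights {1, 0.5}, the scores are 1, 0.5, 3 and -1, so keeping two edges retains edges 0 and 2.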
Reference: "MeshCNN: A Network with an Edge" by Hanocka et al., SIGGRAPH 2019
Constructors
MeshPoolLayer(int, int, int)
Initializes a new instance of the MeshPoolLayer<T> class.
public MeshPoolLayer(int inputChannels, int targetEdges, int numNeighbors = 4)
Parameters
inputChannels (int): Number of input feature channels per edge.
targetEdges (int): Target number of edges after pooling.
numNeighbors (int): Number of neighboring edges per edge. Default is 4.
Remarks
For Beginners: Creates a mesh pooling layer that reduces mesh complexity.
Example: If your mesh has 750 edges and you want to reduce it to 450 edges, set targetEdges=450. The layer will learn to remove the 300 least important edges.
Exceptions
- ArgumentOutOfRangeException
Thrown when parameters are non-positive.
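The documented argument rule and the example arithmetic can be written as a small sketch. `ValidatePoolArguments` and `EdgesRemoved` are hypothetical helpers, not members of MeshPoolLayer<T>. The layer's exact checks may differ; for instance, whether targetEdges is also compared against the mesh size is not stated here.

```csharp
using System;

// Hypothetical mirror of the documented rule: every constructor argument must be positive.
static void ValidatePoolArguments(int inputChannels, int targetEdges, int numNeighbors = 4)
{
    if (inputChannels <= 0)
        throw new ArgumentOutOfRangeException(nameof(inputChannels), "Must be positive.");
    if (targetEdges <= 0)
        throw new ArgumentOutOfRangeException(nameof(targetEdges), "Must be positive.");
    if (numNeighbors <= 0)
        throw new ArgumentOutOfRangeException(nameof(numNeighbors), "Must be positive.");
}

// Number of edges pooling removes. Clamping at 0 when the mesh already has
// targetEdges or fewer edges is an assumption, not documented behavior.
static int EdgesRemoved(int numEdges, int targetEdges) => Math.Max(0, numEdges - targetEdges);
```

For the example above, `EdgesRemoved(750, 450)` gives 300.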
Properties
InputChannels
Gets the number of input feature channels per edge.
public int InputChannels { get; }
Property Value
- int
ParameterCount
Gets the total number of trainable parameters.
public override int ParameterCount { get; }
Property Value
- int
RemainingEdgeIndices
Gets the edge indices that remain after pooling.
public int[]? RemainingEdgeIndices { get; }
Property Value
- int[]
Remarks
This is set during the forward pass and can be used to track which edges were preserved through pooling.
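The sketch below shows how a list of surviving indices can be used to gather the kept feature rows. It assumes that row i of the pooled output comes from input edge RemainingEdgeIndices[i]; that ordering is not stated above. `GatherRows` is a hypothetical helper operating on plain arrays rather than Tensor<T>.

```csharp
using System;

// Hypothetical sketch: copy the feature rows named in keptIndices into a
// [keptIndices.Length, channels] array, the shape Forward is documented to return.
static double[,] GatherRows(double[,] features, int[] keptIndices)
{
    int channels = features.GetLength(1);
    var pooled = new double[keptIndices.Length, channels];
    for (int row = 0; row < keptIndices.Length; row++)
    {
        int source = keptIndices[row]; // original edge index for this output row
        for (int c = 0; c < channels; c++)
            pooled[row, c] = features[source, c];
    }
    return pooled;
}
```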
SupportsJitCompilation
Gets a value indicating whether this layer supports JIT compilation.
public override bool SupportsJitCompilation { get; }
Property Value
- bool
false because mesh pooling requires dynamic graph operations.
SupportsTraining
Gets a value indicating whether this layer supports training.
public override bool SupportsTraining { get; }
Property Value
- bool
Always true, because the importance scores are learned.
TargetEdges
Gets the target number of edges after pooling.
public int TargetEdges { get; }
Property Value
- int
Remarks
This determines how much the mesh is simplified. For example, pooling from 750 edges to 450 edges removes 300 edges, which is 40% of the original edges.
UpdatedAdjacency
Gets the updated edge adjacency after pooling.
public int[,]? UpdatedAdjacency { get; }
Property Value
- int[,]
Remarks
After pooling, the edge adjacency must be updated to reflect the new connectivity of the reduced mesh.
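One way to re-index an adjacency table for the reduced mesh is sketched below. It is hypothetical: it assumes -1 marks a missing neighbor and that a neighbor which was pooled away is also written as -1, whereas the actual layer may reconnect such slots to surviving edges. `RemapAdjacency` is not a library method.

```csharp
using System;

// Hypothetical sketch: remap [numEdges, numNeighbors] adjacency to pooled edge positions.
// Assumptions: -1 marks a missing neighbor, and removed neighbors also become -1.
static int[,] RemapAdjacency(int[,] adjacency, int[] keptIndices)
{
    int numEdges = adjacency.GetLength(0);
    int numNeighbors = adjacency.GetLength(1);

    // newIndex[old] = position of the old edge in the pooled mesh, or -1 if removed.
    var newIndex = new int[numEdges];
    Array.Fill(newIndex, -1);
    for (int i = 0; i < keptIndices.Length; i++)
        newIndex[keptIndices[i]] = i;

    var updated = new int[keptIndices.Length, numNeighbors];
    for (int i = 0; i < keptIndices.Length; i++)
    {
        for (int k = 0; k < numNeighbors; k++)
        {
            int oldNeighbor = adjacency[keptIndices[i], k];
            updated[i, k] = oldNeighbor >= 0 ? newIndex[oldNeighbor] : -1;
        }
    }
    return updated;
}
```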
Methods
Backward(Tensor<T>)
Performs the backward pass for mesh pooling using vectorized scatter operations.
public override Tensor<T> Backward(Tensor<T> outputGradient)
Parameters
outputGradient (Tensor<T>): Gradient with respect to the pooled output.
Returns
- Tensor<T>
Gradient with respect to the input. It is sparse: rows are nonzero only at kept edges.
Remarks
Uses Engine.TensorScatterAdd to scatter the gradients back to their original edge positions in one vectorized operation. This avoids per-element loops, which matters most on GPU.
Exceptions
- InvalidOperationException
Thrown when Forward has not been called.
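The scatter step can be illustrated with a plain loop. This sketch is for clarity only and does not reproduce Engine.TensorScatterAdd. It covers only the gradient with respect to the edge features, not the importance-weight gradient, and `ScatterGradient` is a hypothetical name.

```csharp
using System;

// Hypothetical CPU equivalent of the scatter step: each gradient row for a kept edge
// is added back at that edge's original position; removed edges receive zero gradient.
static double[,] ScatterGradient(double[,] outputGradient, int[] keptIndices, int numEdges)
{
    int channels = outputGradient.GetLength(1);
    if (outputGradient.GetLength(0) != keptIndices.Length)
        throw new ArgumentException("Expected one gradient row per kept edge.", nameof(outputGradient));

    var inputGradient = new double[numEdges, channels]; // zero-initialized by the runtime
    for (int row = 0; row < keptIndices.Length; row++)
        for (int c = 0; c < channels; c++)
            inputGradient[keptIndices[row], c] += outputGradient[row, c];
    return inputGradient;
}
```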
BackwardGpu(IGpuTensor<T>)
Performs GPU-accelerated backward pass for mesh pooling.
public override IGpuTensor<T> BackwardGpu(IGpuTensor<T> outputGradient)
Parameters
outputGradient (IGpuTensor<T>): GPU tensor with the gradient from the next layer [numKept, InputChannels].
Returns
- IGpuTensor<T>
GPU tensor with input gradients [numEdges, InputChannels].
Remarks
The backward pass for mesh pooling scatters the output gradients back to their original positions in the input tensor. Uses GPU scatter-add operation for efficiency.
It also computes the gradient for the importance weights using matrix operations:
- Gathers the kept edge features
- Computes the weighted gradient through the selection operation
Exceptions
- InvalidOperationException
Thrown when ForwardGpu has not been called in training mode.
Clone()
Creates a deep copy of the layer.
public override LayerBase<T> Clone()
Returns
- LayerBase<T>
A new instance with identical configuration and parameters.
Deserialize(BinaryReader)
Deserializes the layer from a binary stream.
public override void Deserialize(BinaryReader reader)
Parameters
reader (BinaryReader): The binary reader to deserialize from.
ExportComputationGraph(List<ComputationNode<T>>)
Exports the layer as a computation graph for JIT compilation.
public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
Parameters
inputNodes (List<ComputationNode<T>>): List to populate with input nodes.
Returns
- ComputationNode<T>
The output computation node representing the mesh pooling operation.
Remarks
Mesh pooling is approximated in the computation graph using a learned attention-weighted aggregation. The edge importance scores are computed and used to weight the features before reduction, enabling gradient flow through the pooling operation.
Exceptions
- ArgumentNullException
Thrown when inputNodes is null.
- InvalidOperationException
Thrown when layer is not properly initialized.
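The attention-weighting idea can be sketched as follows. The choice of softmax to turn scores into weights is an assumption; the text above says only "attention-weighted aggregation". The reduction that follows the weighting is not specified here, so it is not shown. `Softmax` and `WeightFeatures` are hypothetical helpers.

```csharp
using System;
using System.Linq;

// Hypothetical sketch: convert per-edge scores into softmax attention weights.
static double[] Softmax(double[] scores)
{
    double max = scores.Max(); // subtract the max for numerical stability
    double[] exps = scores.Select(s => Math.Exp(s - max)).ToArray();
    double total = exps.Sum();
    return exps.Select(e => e / total).ToArray();
}

// Scale each edge's feature row by its attention weight (before any reduction).
static double[,] WeightFeatures(double[,] features, double[] attention)
{
    int numEdges = features.GetLength(0);
    int channels = features.GetLength(1);
    var weighted = new double[numEdges, channels];
    for (int e = 0; e < numEdges; e++)
        for (int c = 0; c < channels; c++)
            weighted[e, c] = attention[e] * features[e, c];
    return weighted;
}
```

Because every step is differentiable, gradients can reach the scores, which is the property the remarks above rely on.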
Forward(Tensor<T>)
Performs the forward pass of mesh pooling.
public override Tensor<T> Forward(Tensor<T> input)
Parameters
input (Tensor<T>): Edge features tensor of shape [numEdges, InputChannels].
Returns
- Tensor<T>
Pooled edge features of shape [TargetEdges, InputChannels].
Remarks
This method requires edge adjacency to be set via SetEdgeAdjacency(int[,]) before calling.
Exceptions
- ArgumentException
Thrown when input has invalid shape.
- InvalidOperationException
Thrown when edge adjacency is not set.
ForwardGpu(params IGpuTensor<T>[])
Performs GPU-accelerated forward pass for mesh pooling.
public override IGpuTensor<T> ForwardGpu(params IGpuTensor<T>[] inputs)
Parameters
inputs (IGpuTensor<T>[]): Input GPU tensors; only the first input is used.
Returns
- IGpuTensor<T>
GPU-resident output tensor with pooled features.
Remarks
Uses the GPU for importance score computation (GEMM) and feature gathering. Sorting remains on the CPU because it relies on data-dependent branching that maps poorly to the GPU. The operation is:
1. scores = input @ importanceWeights (GPU GEMM)
2. sortedIndices = sort(scores) (CPU)
3. output = gather(input, topKIndices) (GPU gather)
GetBiases()
Gets the bias tensor (null for this layer).
public override Tensor<T> GetBiases()
Returns
- Tensor<T>
Null as this layer has no biases.
GetParameters()
Gets all trainable parameters as a single vector.
public override Vector<T> GetParameters()
Returns
- Vector<T>
Vector containing importance weights.
GetWeights()
Gets the importance weights tensor.
public override Tensor<T> GetWeights()
Returns
- Tensor<T>
The importance weights.
ResetState()
Resets the cached state from forward/backward passes.
public override void ResetState()
Serialize(BinaryWriter)
Serializes the layer to a binary stream.
public override void Serialize(BinaryWriter writer)
Parameters
writer (BinaryWriter): The binary writer to serialize to.
SetEdgeAdjacency(int[,])
Sets the edge adjacency information for the current mesh.
public void SetEdgeAdjacency(int[,] edgeAdjacency)
Parameters
edgeAdjacency (int[,]): A 2D array of shape [numEdges, numNeighbors] containing neighbor edge indices, where numNeighbors is the value passed to the constructor.
Exceptions
- ArgumentNullException
Thrown when edgeAdjacency is null.
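For a triangle mesh, a [numEdges, 4] table can be built from the faces, following the MeshCNN convention that an edge's four neighbors are the other two edges of each of its two incident triangles. This is a hypothetical helper, not part of AiDotNet. Filling boundary slots with -1 is an assumption, and the neighbor ordering here is arbitrary and may not match what the layer expects.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper: build an edge list and a [numEdges, 4] adjacency table from
// triangle faces. Neighbors are the other two edges of each incident triangle;
// boundary edges keep -1 in their unused slots (an assumption of this sketch).
static (List<(int, int)> Edges, int[,] Adjacency) BuildEdgeAdjacency(int[][] faces)
{
    var edgeIds = new Dictionary<(int, int), int>();
    var edges = new List<(int, int)>();
    var faceEdges = new List<int[]>();

    // Returns a stable id for the undirected edge (a, b), creating it on first use.
    int EdgeId(int a, int b)
    {
        var key = a < b ? (a, b) : (b, a);
        if (!edgeIds.TryGetValue(key, out int id))
        {
            id = edges.Count;
            edgeIds[key] = id;
            edges.Add(key);
        }
        return id;
    }

    foreach (var f in faces)
        faceEdges.Add(new[] { EdgeId(f[0], f[1]), EdgeId(f[1], f[2]), EdgeId(f[2], f[0]) });

    var adjacency = new int[edges.Count, 4];
    var filled = new int[edges.Count];
    for (int e = 0; e < edges.Count; e++)
        for (int k = 0; k < 4; k++)
            adjacency[e, k] = -1;

    foreach (var fe in faceEdges)
    {
        for (int i = 0; i < 3; i++)
        {
            int e = fe[i];
            if (filled[e] >= 4) continue; // non-manifold edge: ignore faces beyond two
            adjacency[e, filled[e]++] = fe[(i + 1) % 3];
            adjacency[e, filled[e]++] = fe[(i + 2) % 3];
        }
    }
    return (edges, adjacency);
}
```

For two triangles {0, 1, 2} and {0, 2, 3}, the shared edge (0, 2) gets four neighbors, while each outer edge gets two neighbors and two -1 slots.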
SetParameters(Vector<T>)
Sets all trainable parameters from a vector.
public override void SetParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): Vector containing importance weights.
UpdateParameters(T)
Updates the layer parameters using computed gradients.
public override void UpdateParameters(T learningRate)
Parameters
learningRate (T): The learning rate for gradient descent.
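A plain gradient-descent update on the importance weights can be sketched as below. This assumes the simplest rule, w ← w − learningRate · grad. Whether the layer adds momentum, clipping, or other terms is not documented here. `SgdStep` is a hypothetical name operating on plain arrays.

```csharp
using System;

// Hypothetical sketch of a plain gradient-descent step: w <- w - learningRate * grad.
static void SgdStep(double[] weights, double[] gradients, double learningRate)
{
    if (weights.Length != gradients.Length)
        throw new ArgumentException("weights and gradients must have the same length.");
    for (int i = 0; i < weights.Length; i++)
        weights[i] -= learningRate * gradients[i]; // update in place
}
```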