Class DigitCapsuleLayer<T>
- Namespace: AiDotNet.NeuralNetworks.Layers
- Assembly: AiDotNet.dll
Represents a digit capsule layer that implements the dynamic routing algorithm between capsules.
public class DigitCapsuleLayer<T> : LayerBase<T>, ILayer<T>, IJitCompilable<T>, IDiagnosticsProvider, IWeightLoadable<T>, IDisposable
Type Parameters
T: The numeric type used for calculations, typically float or double.
- Inheritance
- LayerBase<T> → DigitCapsuleLayer<T>
- Implements
- ILayer<T>
- IJitCompilable<T>
- IDiagnosticsProvider
- IWeightLoadable<T>
- IDisposable
Remarks
A digit capsule layer extends the concept of traditional neural networks by using groups of neurons (capsules) that encapsulate various properties of entities. This implementation is based on the CapsNet architecture proposed by Hinton et al., which uses a dynamic routing algorithm to determine how lower-level capsules should send their output to higher-level capsules.
For Beginners: A capsule layer is a special type of neural network layer that groups neurons together.
Think of regular neural networks as looking at individual puzzle pieces (like detecting edges or corners). A capsule network looks at how these pieces fit together to form objects.
For example, in image recognition:
- Regular neurons might detect a wheel, a window, and a door
- Capsules understand that these parts can make up a car, and how those parts relate to each other
This layer specifically handles digit recognition, taking information from previous capsule layers and determining which digit is most likely present in the input.
Constructors
DigitCapsuleLayer(int, int, int, int, int)
Initializes a new instance of the DigitCapsuleLayer<T> class.
public DigitCapsuleLayer(int inputCapsules, int inputCapsuleDimension, int numClasses, int outputCapsuleDimension, int routingIterations)
Parameters
inputCapsules (int): The number of capsules in the input layer.
inputCapsuleDimension (int): The dimension of each input capsule vector.
numClasses (int): The number of classes (output capsules) that this layer can identify.
outputCapsuleDimension (int): The dimension of each output capsule vector.
routingIterations (int): The number of iterations to use in the dynamic routing algorithm.
Remarks
This constructor creates a new digit capsule layer with the specified parameters. It initializes the weight tensor with small random values scaled according to the dimensions of the input and output capsules. The layer uses a squash activation function, which is specific to capsule networks.
For Beginners: This sets up the capsule layer with the specific details it needs.
When creating a digit capsule layer, you need to specify:
- How many input feature groups there are (inputCapsules)
- How detailed each input feature is (inputCapsuleDimension)
- How many categories to classify into (numClasses)
- How detailed each output prediction should be (outputCapsuleDimension)
- How carefully to analyze connections between inputs and outputs (routingIterations)
For example, to recognize handwritten digits, you might use 10 output classes (digits 0-9) with a moderate number of routing iterations (3-5) for good performance.
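The sketch below wires the layer up for that MNIST-style scenario. The 1152 × 8 input shape is an assumption borrowed from the primary capsule stage of the original CapsNet paper, not something this page prescribes; match it to whatever your preceding capsule layer actually emits.

```csharp
using AiDotNet.NeuralNetworks.Layers;

// Digit capsule layer for 10-class digit recognition.
var digitCaps = new DigitCapsuleLayer<float>(
    inputCapsules: 1152,          // capsules coming from the primary capsule layer
    inputCapsuleDimension: 8,     // each input capsule is an 8-dimensional vector
    numClasses: 10,               // digits 0-9
    outputCapsuleDimension: 16,   // each digit capsule is a 16-dimensional vector
    routingIterations: 3);        // 3 iterations is a common default
```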
Properties
SupportsGpuExecution
Gets whether this layer has a GPU execution implementation for inference.
protected override bool SupportsGpuExecution { get; }
Property Value
- bool
Remarks
Override this to return true when the layer implements ForwardGpu(params IGpuTensor<T>[]). The actual CanExecuteOnGpu property combines this with engine availability.
For Beginners: This flag indicates if the layer has GPU code for the forward pass. Set this to true in derived classes that implement ForwardGpu.
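As a hypothetical sketch of that pattern (the derived class name and constructor below are illustrative only, not part of the library):

```csharp
// A derived layer that ships its own GPU forward pass advertises it by
// overriding this flag; the engine combines the flag with GPU
// availability (CanExecuteOnGpu) to decide whether ForwardGpu is used.
public class MyGpuCapsuleLayer<T> : DigitCapsuleLayer<T>
{
    public MyGpuCapsuleLayer(int inCaps, int inDim, int classes, int outDim, int iters)
        : base(inCaps, inDim, classes, outDim, iters) { }

    protected override bool SupportsGpuExecution => true;
}
```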
SupportsJitCompilation
Gets a value indicating whether this layer supports JIT compilation.
public override bool SupportsJitCompilation { get; }
Property Value
- bool
true because DigitCapsuleLayer<T> uses dynamic routing with a fixed number of iterations that can be unrolled into a static computation graph.
SupportsTraining
Gets a value indicating whether this layer supports training.
public override bool SupportsTraining { get; }
Property Value
- bool
Always true because this layer has trainable parameters (weights).
Remarks
This property indicates that the digit capsule layer supports training through backpropagation. The layer has trainable weights that are updated during the training process.
For Beginners: This property tells you that this layer can learn from data.
A value of true means:
- The layer can adjust its internal values during training
- It will improve its performance as it sees more data
- It has weights that are updated to make better predictions over time
Methods
Backward(Tensor<T>)
Performs the backward pass of the digit capsule layer to compute gradients.
public override Tensor<T> Backward(Tensor<T> outputGradient)
Parameters
outputGradient (Tensor<T>): The gradient of the loss with respect to the layer's output.
Returns
- Tensor<T>
The gradient of the loss with respect to the layer's input.
Remarks
This method implements the backward pass (backpropagation) of the digit capsule layer. It computes the gradients of the loss with respect to the layer's weights and inputs, which are used to update the weights during training.
For Beginners: This is where the layer learns from its mistakes during training.
The backward pass:
- Receives information about how the network's prediction was wrong
- Calculates how each weight contributed to this error
- Determines how to adjust the weights to reduce the error next time
- Passes error information back to previous layers
It's like figuring out which ingredients in a recipe need to be adjusted after tasting the finished dish, then sharing that feedback with those who prepared the ingredients.
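A minimal call-order sketch, continuing the construction example above (the loss-gradient tensor is assumed to come from your loss function, which this page does not cover):

```csharp
// Forward must run before Backward on the same instance, otherwise
// Backward throws InvalidOperationException (see below).
Tensor<float> output = digitCaps.Forward(input);

// 'lossGradient' is dLoss/dOutput, computed by your loss function
// (for CapsNet, typically a margin loss over capsule lengths).
Tensor<float> inputGradient = digitCaps.Backward(lossGradient);
```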
Exceptions
- InvalidOperationException
Thrown when backward is called before forward.
ExportComputationGraph(List<ComputationNode<T>>)
Exports the layer's computation graph for JIT compilation.
public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
Parameters
inputNodes (List<ComputationNode<T>>): List to populate with input computation nodes.
Returns
- ComputationNode<T>
The output computation node representing the layer's operation.
Remarks
This method constructs a computation graph representation of the layer's forward pass that can be JIT compiled for faster inference. All layers MUST implement this method to support JIT compilation.
For Beginners: JIT (Just-In-Time) compilation converts the layer's operations into optimized native code for 5-10x faster inference.
To support JIT compilation, a layer must:
- Implement this method to export its computation graph
- Set SupportsJitCompilation to true
- Use ComputationNode and TensorOperations to build the graph
All layers are required to implement this method, even if they set SupportsJitCompilation = false.
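A usage sketch built only from the signature above; what you then do with the returned root node depends on the JIT pipeline, which this page does not detail:

```csharp
using System.Collections.Generic;

// The layer populates 'inputNodes' with the graph's input nodes and
// returns the root node of the unrolled routing computation.
var inputNodes = new List<ComputationNode<float>>();
ComputationNode<float> root = digitCaps.ExportComputationGraph(inputNodes);
```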
Forward(Tensor<T>)
Performs the forward pass of the digit capsule layer using dynamic routing.
public override Tensor<T> Forward(Tensor<T> input)
Parameters
input (Tensor<T>): The input tensor to process.
Returns
- Tensor<T>
The output tensor after capsule routing.
Remarks
This method implements the forward pass of the digit capsule layer using the dynamic routing algorithm. It transforms input capsules into predictions for each class, then iteratively refines the routing coefficients to determine how strongly each input capsule should connect to each output capsule.
For Beginners: This is where the layer makes its predictions based on the input data.
The forward pass works in these steps:
- Transform each input feature into predictions for each possible class
- Start with equal connection strengths between all inputs and outputs
- For several iterations:
  - Calculate how strongly each input connects to each output
  - Update these connections based on how well inputs agree with outputs
  - Recalculate the output predictions using the updated connections
This process is like having experts (input capsules) vote on different outcomes (classes), then gradually giving more weight to experts who agree with the consensus for each outcome.
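To make those steps concrete, here is a self-contained sketch of routing-by-agreement (Sabour, Frosst and Hinton, 2017) on plain arrays. It mirrors the algorithm described above but is not the library's implementation:

```csharp
using System;
using System.Linq;

static class RoutingSketch
{
    // uHat[i][j]: input capsule i's prediction vector for output capsule j.
    public static double[][] Route(double[][][] uHat, int iterations)
    {
        int numIn = uHat.Length, numOut = uHat[0].Length, dim = uHat[0][0].Length;
        var b = new double[numIn, numOut];   // routing logits; zeros = equal connections
        var v = new double[numOut][];        // output capsule vectors

        for (int it = 0; it < iterations; it++)
        {
            for (int j = 0; j < numOut; j++)
            {
                var s = new double[dim];     // weighted sum of predictions for class j
                for (int i = 0; i < numIn; i++)
                {
                    // Coupling coefficient c_ij: softmax of logits over output capsules.
                    double denom = 0;
                    for (int k = 0; k < numOut; k++) denom += Math.Exp(b[i, k]);
                    double c = Math.Exp(b[i, j]) / denom;
                    for (int d = 0; d < dim; d++) s[d] += c * uHat[i][j][d];
                }
                v[j] = Squash(s);
            }
            // Agreement step: inputs whose predictions align with the output
            // get larger logits, and therefore a larger share next iteration.
            for (int i = 0; i < numIn; i++)
                for (int j = 0; j < numOut; j++)
                    b[i, j] += v[j].Zip(uHat[i][j], (a, u) => a * u).Sum();
        }
        return v;
    }

    // squash(s) = (|s|^2 / (1 + |s|^2)) * s / |s|: keeps the direction,
    // maps the length into (0, 1) so it can act as a probability.
    static double[] Squash(double[] s)
    {
        double sq = s.Sum(x => x * x);
        double scale = sq / (1 + sq) / Math.Sqrt(sq + 1e-9);
        return s.Select(x => x * scale).ToArray();
    }
}
```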
ForwardGpu(params IGpuTensor<T>[])
Performs GPU-accelerated forward pass through the digit capsule layer.
public override IGpuTensor<T> ForwardGpu(params IGpuTensor<T>[] inputs)
Parameters
inputs (IGpuTensor<T>[]): GPU-resident input tensors.
Returns
- IGpuTensor<T>
GPU-resident output tensor after capsule routing.
Remarks
This method implements the forward pass using GPU-resident operations for the dynamic routing iterations. Uses CapsulePredictionsGpu for the prediction transform and DynamicRoutingGpu for the routing iterations, keeping all data on GPU.
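A call sketch from the signature above; how 'gpuInput' gets uploaded to the device is engine-specific and not covered on this page:

```csharp
// 'gpuInput' is assumed to be an IGpuTensor<float> already resident on
// the device; the returned output stays GPU-resident as well.
IGpuTensor<float> gpuOutput = digitCaps.ForwardGpu(gpuInput);
```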
GetParameters()
Gets all trainable parameters of the layer as a single vector.
public override Vector<T> GetParameters()
Returns
- Vector<T>
A vector containing all trainable parameters.
Remarks
This method retrieves all trainable parameters (weights) of the layer as a single vector. This is useful for optimization algorithms that operate on all parameters at once, or for saving and loading model weights.
For Beginners: This method collects all the layer's learnable values into a single list.
The parameters:
- Are all the weight values that the network learns
- Are flattened into a single long list (vector)
- Can be saved to disk or loaded from a previous training session
This allows you to:
- Save a trained model for later use
- Transfer a model's knowledge to another identical model
- Share trained models with others
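For instance, transferring knowledge to an identically configured layer is a one-liner (the constructor arguments are the illustrative MNIST-style values from above):

```csharp
// Copy trained weights into a second, identically shaped layer.
var twin = new DigitCapsuleLayer<float>(1152, 8, 10, 16, 3);
twin.SetParameters(digitCaps.GetParameters());
```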
ResetState()
Resets the internal state of the layer.
public override void ResetState()
Remarks
This method resets the internal state of the layer by clearing all cached values from forward and backward passes. This is useful when starting to process a new batch of data or when implementing stateful recurrent networks.
For Beginners: This method clears the layer's memory to start fresh.
When resetting the state:
- All stored inputs, outputs, and intermediate values are cleared
- The layer forgets previous data it processed
- This prepares it for processing new, unrelated data
It's like wiping a whiteboard clean before starting a new calculation. This ensures that information from one batch of data doesn't affect the processing of another, unrelated batch.
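Usage is a single call between unrelated batches:

```csharp
// Drop cached forward/backward state before processing unrelated data.
digitCaps.ResetState();
```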
SetParameters(Vector<T>)
Sets the trainable parameters of the layer from a single vector.
public override void SetParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): A vector containing all parameters to set.
Remarks
This method sets all trainable parameters (weights) of the layer from a single vector. This is useful for loading saved model weights or for implementing optimization algorithms that operate on all parameters at once.
For Beginners: This method updates all the layer's learnable values from a provided list.
When setting parameters:
- The input must be a vector with the exact right length
- Each value in the vector corresponds to a specific weight in the layer
- This allows loading previously trained weights
Use cases include:
- Restoring a saved model
- Using pre-trained weights
- Testing specific weight configurations
The method throws an ArgumentException if the provided vector doesn't contain exactly the right number of values.
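A restore sketch; the loader below is hypothetical and stands in for whatever persistence mechanism you use:

```csharp
// The vector length must match the layer's parameter count exactly,
// otherwise SetParameters throws ArgumentException (see below).
Vector<float> restored = LoadDigitCapsWeights();   // hypothetical loader
digitCaps.SetParameters(restored);
```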
Exceptions
- ArgumentException
Thrown when the parameters vector has incorrect length.
UpdateParameters(T)
Updates the layer's weights using the calculated gradients and the specified learning rate.
public override void UpdateParameters(T learningRate)
Parameters
learningRate (T): The learning rate to use for the parameter updates.
Remarks
This method updates the layer's weights based on the gradients calculated during the backward pass. The learning rate determines the size of the parameter updates. Smaller learning rates lead to more stable but slower training, while larger learning rates can lead to faster but potentially unstable training.
For Beginners: This method actually adjusts the weights based on what was learned.
After figuring out what changes need to be made:
- The network adjusts each weight by a small amount
- The learning rate controls how big this adjustment is
- Too small: learning happens very slowly
- Too large: learning becomes unstable
It's like adjusting a recipe based on taste - you don't want to add too much salt at once, but you also don't want to add just one grain at a time.
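Putting the three training calls together (the loss-gradient helper below is hypothetical, and the learning rate is illustrative):

```csharp
// One training step: predict, backpropagate, then apply the updates.
// UpdateParameters throws InvalidOperationException if Backward has
// not run first (see below).
Tensor<float> prediction = digitCaps.Forward(input);
Tensor<float> lossGradient = ComputeLossGradient(prediction, target); // hypothetical
digitCaps.Backward(lossGradient);
digitCaps.UpdateParameters(0.001f);   // small rate = stable but slower learning
```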
Exceptions
- InvalidOperationException
Thrown when update is called before backward.