Class CapsuleNetwork<T>
- Namespace: AiDotNet.NeuralNetworks
- Assembly: AiDotNet.dll
Represents a Capsule Network, a type of neural network that preserves spatial relationships between features.
public class CapsuleNetwork<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable, IAuxiliaryLossLayer<T>, IDiagnosticsProvider
Type Parameters
T: The numeric type used for calculations, typically float or double.
- Inheritance
- NeuralNetworkBase<T> → CapsuleNetwork<T>
Remarks
A Capsule Network is a neural network architecture designed to address limitations of traditional convolutional neural networks. Instead of using scalar-output feature detectors (neurons), Capsule Networks use vector-output capsules. Each capsule's output vector represents the presence of an entity and its instantiation parameters (like position, orientation, and scale). This architecture helps to preserve hierarchical relationships between features, making it particularly effective for tasks requiring understanding of spatial relationships.
For Beginners: A Capsule Network is like a more advanced version of traditional neural networks.
Think of it this way:
- Traditional networks detect features like edges or textures, but lose information about how these features relate to each other
- Capsule Networks not only detect features, but also understand their relationships, orientations, and positions
- This is like the difference between recognizing individual puzzle pieces versus understanding how they fit together
For example, a traditional network might recognize an eye, a nose, and a mouth separately, but a Capsule Network can better understand that these features need to be in a specific arrangement to make a face. This makes Capsule Networks particularly good at recognizing objects from different angles or when parts are arranged differently.
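The vector-output idea described above is usually paired with the "squashing" nonlinearity from Sabour et al. (2017), which preserves a capsule vector's direction while compressing its length into [0, 1), so the length can be read as the probability that the entity is present. The following is an illustrative Python sketch of that function, not AiDotNet's implementation:

```python
import math

def squash(v):
    """Squash a capsule vector: keep its direction, map its length into [0, 1).

    squash(v) = (|v|^2 / (1 + |v|^2)) * (v / |v|)
    """
    norm_sq = sum(x * x for x in v)
    norm = math.sqrt(norm_sq)
    if norm == 0.0:
        return [0.0] * len(v)
    scale = norm_sq / (1.0 + norm_sq)
    return [scale * x / norm for x in v]

# A long vector squashes to length near 1; a short one to length near 0.
long_vec = squash([3.0, 4.0])   # |v| = 5  -> squashed length 25/26
short_vec = squash([0.1, 0.0])  # |v| = 0.1 -> squashed length 0.01/1.01
```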
Constructors
CapsuleNetwork(NeuralNetworkArchitecture<T>, ILossFunction<T>?)
Initializes a new instance of the CapsuleNetwork<T> class with the specified architecture.
public CapsuleNetwork(NeuralNetworkArchitecture<T> architecture, ILossFunction<T>? lossFunction = null)
Parameters
architecture: NeuralNetworkArchitecture<T>. The neural network architecture configuration.
lossFunction: ILossFunction<T>. Optional. If null, a default margin loss is used.
Remarks
This constructor creates a new Capsule Network with the specified architecture. The architecture defines the structure of the network, including the input dimensions, number and types of layers, and output dimensions. The initialization process sets up the layers based on the provided architecture or creates default capsule network layers if none are specified.
For Beginners: This creates a new Capsule Network with your chosen settings.
When you create a Capsule Network:
- You provide an "architecture" that defines how the network is structured
- This includes information like how large the input is and what kinds of layers to use
- The constructor sets up the basic structure, but doesn't actually train the network yet
Think of it like setting up a blank canvas and easel before you start painting - you're just getting everything ready to use.
Properties
AuxiliaryLossWeight
Gets or sets the weight for reconstruction loss. Default is 0.0005 (standard value from original CapsNet paper).
public T AuxiliaryLossWeight { get; set; }
Property Value
- T
UseAuxiliaryLoss
Gets or sets whether to use auxiliary loss (reconstruction regularization) during training. Default is true as per Sabour et al. (2017) - required for proper CapsNet functionality.
public bool UseAuxiliaryLoss { get; set; }
Property Value
- bool
Methods
ComputeAuxiliaryLoss()
Computes the auxiliary loss for the CapsuleNetwork, which is the reconstruction regularization.
public T ComputeAuxiliaryLoss()
Returns
- T
The reconstruction loss value.
Remarks
The reconstruction loss encourages the digit capsules to encode instantiation parameters of the input. A decoder network uses the activity vector of the correct DigitCaps to reconstruct the input image. The reconstruction loss is the sum of squared differences between the input and reconstruction. This is scaled down by a factor (typically 0.0005) so it doesn't dominate the margin loss during training.
For Beginners: This calculates how well the network can reconstruct the original input from its capsule representation.
Reconstruction regularization:
- Takes the capsule outputs (compressed representation)
- Tries to recreate the original input from them
- Measures how different the reconstruction is from the original
- Encourages capsules to preserve important information about the input
Why this is important:
- Ensures capsules learn meaningful representations
- Prevents the network from learning arbitrary encodings
- Acts as a regularizer to improve generalization
- Helps capsules encode pose/instantiation parameters
This is similar to how an autoencoder works, but specifically designed for capsule networks.
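The computation described in the remarks (sum of squared differences between input and reconstruction, scaled by a small weight) can be sketched in a few lines. This is an illustrative Python example, not the AiDotNet code:

```python
def reconstruction_loss(original, reconstruction, weight=0.0005):
    """Sum of squared differences between input and reconstruction,
    scaled down so it does not dominate the margin loss."""
    sse = sum((o - r) ** 2 for o, r in zip(original, reconstruction))
    return weight * sse

# sse = (1.0-0.8)^2 + (0.0-0.1)^2 + 0 = 0.05 -> loss = 0.0005 * 0.05
loss = reconstruction_loss([1.0, 0.0, 0.5], [0.8, 0.1, 0.5])
```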
ComputeReconstructionLoss(Tensor<T>, int?)
Computes the reconstruction loss for capsule network regularization.
public T ComputeReconstructionLoss(Tensor<T> input, int? trueLabel = null)
Parameters
input: Tensor<T>. The original input tensor.
trueLabel: int?. The true class label for masking. If null, uses argmax of capsule outputs.
Returns
- T
The reconstruction loss (MSE between reconstruction and input).
Remarks
This method implements the reconstruction loss from the original Capsule Networks paper (Sabour et al., 2017). During training, only the capsule corresponding to the true class is used for reconstruction (others are masked to zero). During inference, the capsule with the highest activation is used. The reconstruction helps regularize the network and ensures that capsule vectors encode meaningful instantiation parameters.
For Beginners: This method helps the network learn better by making it reconstruct the input.
How it works:
- Forward pass through the network to get capsule outputs
- Mask: Zero out all capsules except the correct one (or most active one)
- Reconstruction: Pass masked capsules through decoder layers
- Loss: Measure how different the reconstruction is from the original input
Why this helps:
- Forces capsules to encode useful information (position, rotation, etc.)
- Acts as regularization to prevent overfitting
- Improves interpretability of what capsules represent
The reconstruction loss is typically weighted much lower than the main classification loss (e.g., 0.0005 * reconstruction_loss) to avoid overwhelming the primary objective.
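The masking step described above (zero out every capsule except the true one during training, or the most active one during inference) can be sketched as follows. This is an illustrative Python example of the idea, not AiDotNet's implementation:

```python
def mask_capsules(capsule_outputs, true_label=None):
    """Zero out every capsule vector except the target one.

    During training the target is the true class; at inference it is the
    capsule with the largest vector length (argmax), as described above.
    """
    lengths = [sum(x * x for x in caps) ** 0.5 for caps in capsule_outputs]
    keep = true_label if true_label is not None else lengths.index(max(lengths))
    return [caps if i == keep else [0.0] * len(caps)
            for i, caps in enumerate(capsule_outputs)]

caps = [[0.2, 0.1], [0.8, 0.5], [0.1, 0.0]]
masked = mask_capsules(caps)          # inference: keeps capsule 1 (longest)
masked_true = mask_capsules(caps, 0)  # training: keeps capsule 0 (true label)
```

The masked capsules would then be passed through the decoder layers to produce the reconstruction.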
CreateNewInstance()
Creates a new instance of the capsule network model.
protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
A new instance of the capsule network model with the same configuration.
Remarks
This method creates a new instance of the capsule network model with the same configuration as the current instance. It is used internally during serialization/deserialization processes to create a fresh instance that can be populated with the serialized data. The new instance will have the same architecture and loss function as the original.
For Beginners: This method creates a copy of the network structure without copying the learned data.
Think of it like creating a blueprint of the capsule network:
- It copies the same overall design (architecture)
- It uses the same loss function to measure performance
- But it doesn't copy any of the learned values or weights
This is primarily used when saving or loading models, creating a framework that the saved parameters can be loaded into later. It's like creating an empty duplicate of the network's structure that can later be filled with the knowledge from the original network.
DeserializeNetworkSpecificData(BinaryReader)
Deserializes Capsule Network-specific data from a binary reader.
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
Parameters
reader: BinaryReader. The BinaryReader to read the data from.
Remarks
This method loads the loss function used by the network. If deserialization fails, it defaults to using a MarginLoss.
GetAuxiliaryLossDiagnostics()
Gets diagnostic information about the auxiliary losses.
public Dictionary<string, string> GetAuxiliaryLossDiagnostics()
Returns
- Dictionary<string, string>
A dictionary containing diagnostic information about CapsuleNetwork training.
Remarks
This method provides insights into CapsuleNetwork training dynamics, including:
- Margin loss (primary classification loss)
- Reconstruction loss (auxiliary regularization)
- Total loss
- Reconstruction weight
For Beginners: This gives you information to monitor CapsuleNetwork training health.
The diagnostics include:
- Margin Loss: The main classification loss from the capsule network
- Reconstruction Loss: How well the network can recreate inputs from capsules
- Total Loss: Combined loss used for training
- Reconstruction Weight: How much reconstruction influences training (usually small)
These values help you:
- Monitor training convergence
- Balance classification and reconstruction objectives
- Detect overfitting or underfitting
- Tune the reconstruction weight hyperparameter
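How these diagnostics relate can be made concrete. In the original paper, the margin loss for capsule k is T_k · max(0, 0.9 − ‖v_k‖)² + 0.5 · (1 − T_k) · max(0, ‖v_k‖ − 0.1)², and the total loss adds the weighted reconstruction loss. The sketch below is an illustrative Python example of that arithmetic (the 0.9/0.1/0.5 constants are from Sabour et al. 2017), not AiDotNet's code:

```python
def margin_loss(lengths, true_label, m_pos=0.9, m_neg=0.1, lam=0.5):
    """Margin loss computed over capsule vector lengths (one per class)."""
    total = 0.0
    for k, v in enumerate(lengths):
        if k == true_label:
            total += max(0.0, m_pos - v) ** 2   # true class should be long
        else:
            total += lam * max(0.0, v - m_neg) ** 2  # others should be short
    return total

lengths = [0.95, 0.2, 0.05]      # capsule lengths; class 0 is correct
margin = margin_loss(lengths, 0)  # only class 1 exceeds m_neg: 0.5 * 0.1^2
total = margin + 0.0005 * 4.2     # margin + reconstruction weight * recon loss
```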
GetDiagnostics()
Gets diagnostic information about this component's state and behavior. Overrides GetDiagnostics() to include auxiliary loss diagnostics.
public Dictionary<string, string> GetDiagnostics()
Returns
- Dictionary<string, string>
A dictionary containing diagnostic metrics including both base layer diagnostics and auxiliary loss diagnostics from GetAuxiliaryLossDiagnostics().
GetModelMetadata()
Retrieves metadata about the Capsule Network model.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
A ModelMetadata<T> object containing information about the network.
Remarks
This method collects and returns various pieces of information about the network's structure and configuration. It includes details such as the input and output dimensions, the number of layers, and the types of layers used.
For Beginners: This is like creating a summary or overview of the network's structure.
Think of it as a quick reference guide that tells you:
- What kind of network it is (a Capsule Network)
- How big the input and output are
- How many layers the network has
- What types of layers are used
This information is useful for understanding the network's capabilities and for saving/loading the network.
InitializeLayers()
Initializes the layers of the Capsule Network based on the architecture.
protected override void InitializeLayers()
Remarks
This method sets up the layers of the Capsule Network. If custom layers are provided in the architecture, those layers are used. Otherwise, default capsule network layers are created based on the architecture's specifications. After adding the layers, the method validates that the custom layers are properly configured.
For Beginners: This method builds the actual structure of the network.
When initializing the layers:
- If you've specified your own custom layers, the network will use those
- If not, the network will create a standard set of layers that work well for most cases
- The method also checks that all layers are compatible with each other
This is like assembling the different sections of a factory production line - each layer processes the data and passes it to the next layer.
Predict(Tensor<T>)
Performs a forward pass through the Capsule Network to make a prediction.
public override Tensor<T> Predict(Tensor<T> input)
Parameters
input: Tensor<T>. The input tensor to the network.
Returns
- Tensor<T>
The output tensor (prediction) from the network.
Remarks
This method passes the input tensor through each layer of the network in sequence. Each layer processes the output from the previous layer (or the input for the first layer) and produces an output that becomes the input for the next layer.
For Beginners: This is like passing a piece of information through a series of processing stations.
Imagine an assembly line:
- The input is the raw material
- Each layer is a workstation that modifies or processes the material
- The output is the final product after it has passed through all stations
In a Capsule Network, this process preserves and processes spatial relationships, allowing the network to understand complex structures in the input data.
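The assembly-line behavior described above — each layer consuming the previous layer's output — reduces to a simple loop. An illustrative Python sketch with toy stand-in layers (not AiDotNet code):

```python
def predict(layers, x):
    """Pass the input through each layer in sequence; each layer's
    output becomes the next layer's input."""
    for layer in layers:
        x = layer(x)
    return x

# Toy "layers": double every value, then zero out negatives (ReLU-like).
double = lambda xs: [2 * v for v in xs]
relu = lambda xs: [max(0.0, v) for v in xs]
out = predict([double, relu], [1.0, -2.0, 3.0])  # -> [2.0, 0.0, 6.0]
```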
SerializeNetworkSpecificData(BinaryWriter)
Serializes Capsule Network-specific data to a binary writer.
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
Parameters
writer: BinaryWriter. The BinaryWriter to write the data to.
Remarks
This method saves the loss function used by the network, allowing it to be reconstructed when the network is deserialized.
Train(Tensor<T>, Tensor<T>)
Trains the Capsule Network using the provided input and expected output.
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)
Parameters
input: Tensor<T>. The input tensor for training.
expectedOutput: Tensor<T>. The expected output tensor for the given input.
Remarks
This method performs one training iteration:
1. It makes a prediction using the current network parameters.
2. Calculates the loss between the prediction and the expected output.
3. Computes the gradient of the loss with respect to the network parameters.
4. Updates the network parameters based on the computed gradient.
For Beginners: This is like a practice session where the network learns from its mistakes.
The process is similar to learning a new skill:
- You try to perform the task (make a prediction)
- You see how far off you were (calculate the loss)
- You figure out what you need to change to do better (compute the gradient)
- You adjust your approach based on what you learned (update parameters)
This process is repeated many times with different inputs to improve the network's performance.
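The four steps above (predict, measure loss, compute gradient, update) are the standard gradient-descent loop. The following is a deliberately tiny Python illustration on a one-weight model y_hat = w * x, not the CapsuleNetwork code:

```python
def train_step(w, x, y, lr=0.1):
    """One training iteration for the one-weight model y_hat = w * x."""
    y_hat = w * x                # 1. make a prediction
    loss = (y_hat - y) ** 2      # 2. calculate the loss (squared error)
    grad = 2 * (y_hat - y) * x   # 3. gradient of the loss w.r.t. w
    return w - lr * grad, loss   # 4. update the parameter

w, losses = 0.0, []
for _ in range(20):
    w, loss = train_step(w, x=1.0, y=2.0)
    losses.append(loss)
# The loss shrinks each iteration as w approaches the target value 2.0.
```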
UpdateParameters(Vector<T>)
Updates the parameters of all layers in the Capsule Network.
public override void UpdateParameters(Vector<T> parameters)
Parameters
parameters: Vector<T>. A vector containing the parameters to update all layers with.
Remarks
This method distributes the provided parameter vector among all the layers in the network. Each layer receives a portion of the parameter vector corresponding to its number of parameters. The method keeps track of the starting index for each layer's parameters in the input vector.
For Beginners: This method updates the network's internal values during training.
When updating parameters:
- The input is a long list of numbers representing all values in the entire network
- The method divides this list into smaller chunks
- Each layer gets its own chunk of values
- The layers use these values to adjust their internal settings
Think of it like giving each department in a company their specific budget allocations from the overall company budget.
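The chunking described above — slicing one flat parameter vector into per-layer portions while tracking each layer's starting index — can be sketched as follows. An illustrative Python example, not AiDotNet's implementation:

```python
def distribute_parameters(flat_params, layer_sizes):
    """Split one flat parameter list into per-layer chunks, keeping
    track of the start index for each layer's portion."""
    chunks, start = [], 0
    for size in layer_sizes:
        chunks.append(flat_params[start:start + size])
        start += size
    if start != len(flat_params):
        raise ValueError("parameter vector length does not match the layers")
    return chunks

params = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
chunks = distribute_parameters(params, [2, 3, 1])
# -> [[0.1, 0.2], [0.3, 0.4, 0.5], [0.6]]
```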