Class RBMLayer<T>
Namespace: AiDotNet.NeuralNetworks.Layers
Assembly: AiDotNet.dll
Represents a Restricted Boltzmann Machine (RBM) layer for neural networks.
public class RBMLayer<T> : LayerBase<T>, ILayer<T>, IJitCompilable<T>, IDiagnosticsProvider, IWeightLoadable<T>, IDisposable
Type Parameters
T: The numeric type used for calculations, typically float or double.
Inheritance
LayerBase<T> → RBMLayer<T>
Implements
ILayer<T>, IJitCompilable<T>, IDiagnosticsProvider, IWeightLoadable<T>, IDisposable
Remarks
An RBM layer is a stochastic neural network layer that learns a probability distribution over its inputs. It consists of a visible layer and a hidden layer with no connections between nodes within the same layer.
For Beginners: An RBM layer is like a feature detector that can learn patterns in data.
Imagine you have a set of movie ratings:
- The visible layer represents the actual ratings
- The hidden layer represents abstract features (e.g., "likes action", "prefers comedy")
- The RBM learns to connect ratings to these abstract features
RBM layers are useful for:
- Finding underlying patterns in data
- Reducing the dimensionality of data
- Initializing weights for deep neural networks (see the construction sketch below)
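For example, a minimal construction sketch (the dimensions and the double numeric type are illustrative, not part of the API):

```csharp
using System;
using AiDotNet.NeuralNetworks.Layers;

// Illustrative dimensions: 784 visible units (e.g., 28x28-pixel images)
// compressed into 128 hidden feature detectors.
var rbm = new RBMLayer<double>(visibleUnits: 784, hiddenUnits: 128);

// ParameterCount covers the weight matrix plus both bias vectors:
// 784 * 128 + 784 + 128 = 101,264 trainable values.
Console.WriteLine(rbm.ParameterCount);
```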
Constructors
RBMLayer(int, int, IActivationFunction<T>?)
public RBMLayer(int visibleUnits, int hiddenUnits, IActivationFunction<T>? scalarActivation = null)
Parameters
- visibleUnits int
- hiddenUnits int
- scalarActivation IActivationFunction<T>?
RBMLayer(int, int, IVectorActivationFunction<T>?)
public RBMLayer(int visibleUnits, int hiddenUnits, IVectorActivationFunction<T>? vectorActivation = null)
Parameters
- visibleUnits int
- hiddenUnits int
- vectorActivation IVectorActivationFunction<T>?
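A brief sketch of the two overloads (the SigmoidActivation<double> type name is an assumption for illustration; substitute a concrete IActivationFunction<T> implementation from the library):

```csharp
// Scalar overload: an element-wise activation applied per hidden unit.
// SigmoidActivation<double> is a hypothetical type name, used only to
// show where an IActivationFunction<T> implementation would go.
var rbmScalar = new RBMLayer<double>(64, 16,
    scalarActivation: new SigmoidActivation<double>());

// Omitting the activation argument passes null and uses the layer's default.
var rbmDefault = new RBMLayer<double>(64, 16);
```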
Properties
ParameterCount
Gets the total number of trainable parameters in the layer.
public override int ParameterCount { get; }
Property Value
- int
SupportsGpuExecution
Gets a value indicating whether this layer supports GPU execution.
protected override bool SupportsGpuExecution { get; }
Property Value
- bool
SupportsJitCompilation
Gets a value indicating whether this layer supports JIT compilation.
public override bool SupportsJitCompilation { get; }
Property Value
- bool
Always true. RBM uses mean-field inference for JIT compilation.
Remarks
JIT compilation for RBM uses mean-field inference instead of stochastic sampling. This provides a deterministic forward pass where hidden probabilities are computed directly using sigmoid(W*v + b) without sampling. Training still uses Contrastive Divergence with sampling, but inference/forward pass can be JIT compiled.
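As a standalone illustration of the mean-field computation described above (plain C#, not the layer's internal code):

```csharp
using System;

// Mean-field inference: hidden probabilities are computed deterministically
// as sigmoid(W * v + b), with no stochastic sampling.
static double[] HiddenProbabilities(double[,] W, double[] v, double[] b)
{
    int hidden = W.GetLength(0), visible = W.GetLength(1);
    var p = new double[hidden];
    for (int j = 0; j < hidden; j++)
    {
        double activation = b[j];
        for (int i = 0; i < visible; i++)
            activation += W[j, i] * v[i];
        p[j] = 1.0 / (1.0 + Math.Exp(-activation)); // sigmoid
    }
    return p;
}
```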
SupportsTraining
Indicates whether this layer supports training.
public override bool SupportsTraining { get; }
Property Value
- bool
Methods
Backward(Tensor<T>)
Computes the backward pass of the RBM layer.
public override Tensor<T> Backward(Tensor<T> outputGradient)
Parameters
- outputGradient Tensor<T>: The gradient of the loss with respect to the output.
Returns
- Tensor<T>
The gradient of the loss with respect to the input.
Remarks
This method implements the backward pass of the RBM layer. In the context of RBM training, it is used to compute a reconstruction of the visible units given the hidden unit activations. This is part of the contrastive divergence training algorithm.
For Beginners: This method reconstructs the input based on the detected patterns.
During the backward pass:
- The RBM takes the pattern detections (hidden units)
- It tries to recreate the original input that would produce these patterns
- The result is the RBM's "imagination" of what the input should look like
- Both the reconstruction and its corresponding pattern detections are saved for training
This is like the RBM saying "if these patterns exist, this is what the input should look like." The difference between this reconstruction and the original input drives the learning process.
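A sketch of the reconstruction round trip, reusing the rbm layer from the construction example (input is assumed to be a Tensor<double> of visible activations):

```csharp
// Forward: visible activations -> hidden pattern detections.
Tensor<double> hidden = rbm.Forward(input);

// Backward: hidden detections -> the RBM's "imagined" input.
// The gap between reconstruction and input is what training shrinks.
Tensor<double> reconstruction = rbm.Backward(hidden);
```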
ExportComputationGraph(List<ComputationNode<T>>)
Exports the layer's computation graph for JIT compilation.
public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
Parameters
- inputNodes List<ComputationNode<T>>: List to populate with input computation nodes.
Returns
- ComputationNode<T>
The output computation node representing the layer's operation.
Remarks
This method constructs a computation graph representation of the layer's forward pass that can be JIT compiled for faster inference. All layers MUST implement this method to support JIT compilation.
For Beginners: JIT (Just-In-Time) compilation converts the layer's operations into optimized native code for 5-10x faster inference.
To support JIT compilation, a layer must:
- Implement this method to export its computation graph
- Set SupportsJitCompilation to true
- Use ComputationNode and TensorOperations to build the graph
All layers are required to implement this method, even if they set SupportsJitCompilation = false.
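A rough sketch of the implementation pattern for a custom layer (the TensorOperations helper names and the input-node setup below are assumptions for illustration; consult the library for the actual graph-building API):

```csharp
public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
{
    // Register the visible-unit input so the JIT compiler can bind data to it.
    var visible = CreateInputNode();            // hypothetical helper
    inputNodes.Add(visible);

    // Mean-field forward pass, mirroring sigmoid(W * v + b) from the
    // SupportsJitCompilation remarks above.
    var weighted = TensorOperations.MatMul(weightsNode, visible); // hypothetical
    var biased = TensorOperations.Add(weighted, hiddenBiasNode);  // hypothetical
    return TensorOperations.Sigmoid(biased);                      // hypothetical
}
```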
Forward(Tensor<T>)
Computes the forward pass of the RBM layer.
public override Tensor<T> Forward(Tensor<T> input)
Parameters
- input Tensor<T>: The input tensor containing visible unit activations.
Returns
- Tensor<T>
The output tensor containing hidden unit activations.
Remarks
This method implements the forward pass of the RBM layer. It takes a tensor of visible unit activations, computes the probability of each hidden unit being active, and returns these probabilities as the output tensor. It also stores the input and output for use in training.
For Beginners: This method calculates what patterns the RBM detects in the input data.
During the forward pass:
- The RBM receives input data (like an image or set of ratings)
- It calculates how strongly each hidden unit (pattern detector) should activate
- The result shows which patterns were detected in the input
- The original input and pattern detections are saved for training
This is like the RBM saying "based on this input, here are the patterns I can see in it."
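Continuing the movie-ratings analogy, a short sketch of reading the detected patterns (ratings is an assumed Tensor<double> of visible activations):

```csharp
// Each output value is a probability in [0, 1] that a hidden
// pattern detector fires for this input. A value near 1 means the
// corresponding learned feature (e.g., "likes action") is strongly present.
Tensor<double> patterns = rbm.Forward(ratings);
```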
ForwardGpu(params IGpuTensor<T>[])
Performs the forward pass on GPU using FusedLinearGpu with Sigmoid activation.
public override IGpuTensor<T> ForwardGpu(params IGpuTensor<T>[] inputs)
Parameters
- inputs IGpuTensor<T>[]: The GPU input tensors.
Returns
- IGpuTensor<T>
The GPU output tensor.
GetParameters()
Gets all trainable parameters of the layer as a single vector.
public override Vector<T> GetParameters()
Returns
- Vector<T>
A vector containing all trainable parameters.
Remarks
This method collects all trainable parameters of the RBM (weights, visible biases, and hidden biases) into a single vector. The parameters are arranged in the order: weights (row-major), visible biases, hidden biases.
For Beginners: This method packs all the RBM's learnable values into one list.
The returned vector contains:
- All weight values (connections between visible and hidden units)
- All visible bias values (default preferences for visible units)
- All hidden bias values (default sensitivities for hidden units)
This is useful for:
- Saving the RBM's state to a file
- Loading a previously trained RBM
- Using optimization algorithms that work on all parameters at once
Think of it as taking a snapshot of everything the RBM has learned.
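A sketch of the layout using the dimensions from the construction example (Vector<T>.Length is assumed to report the element count):

```csharp
Vector<double> parameters = rbm.GetParameters();

// With 784 visible and 128 hidden units the vector is laid out as:
//   [0 .. 100351]       weights, row-major (784 * 128 = 100,352 values)
//   [100352 .. 101135]  visible biases (784 values)
//   [101136 .. 101263]  hidden biases (128 values)
Console.WriteLine(parameters.Length); // 101,264 == rbm.ParameterCount
```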
ResetState()
Resets the internal state of the layer.
public override void ResetState()
Remarks
This method clears the cached data used during training (both CD and backprop). While RBMs don't maintain state between passes in the same way as recurrent networks, this implementation does cache intermediate values for training purposes.
For Beginners: This method clears the RBM's "memory" of previous inputs.
When you call this method:
- The RBM forgets the last data it saw
- It clears internal storage used during training
- This ensures each batch of data is processed independently
This is useful when you want to start fresh, such as when:
- Beginning training on a new dataset
- Switching from training mode to evaluation mode
- Processing unrelated sequences of data
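For example (a minimal sketch):

```csharp
// Start fresh before switching datasets or moving from training
// to evaluation: cached CD/backprop values are cleared.
rbm.ResetState();
```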
TrainWithContrastiveDivergence(Vector<T>, T, int)
Trains the RBM using contrastive divergence with the given data.
public void TrainWithContrastiveDivergence(Vector<T> input, T learningRate, int kSteps = 1)
Parameters
- input Vector<T>: The input data vector.
- learningRate T: The learning rate for parameter updates.
- kSteps int: The number of Gibbs sampling steps (typically 1).
Remarks
This method implements the contrastive divergence (CD) algorithm for training the RBM. It performs 'kSteps' of Gibbs sampling to approximate the model's distribution and updates weights and biases to make the model distribution closer to the data distribution.
For Beginners: This method teaches the RBM to recognize patterns in data.
Contrastive divergence works by comparing:
- How units activate with real data ("reality")
- How units activate with the RBM's own generated data ("imagination")
The training process:
- First computes probabilities and samples from the visible to hidden layer
- Then reconstructs the visible layer from the hidden layer
- Repeats this Gibbs sampling chain for kSteps iterations
- Finally updates the weights and biases based on the difference between the original data correlations and the generated data correlations
The standard approach uses CD-1 (k=1), which works surprisingly well in practice despite being a rough approximation.
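A sketch of a typical CD-1 training loop (the epoch count, learning rate, and trainingData collection are illustrative):

```csharp
double learningRate = 0.01;

for (int epoch = 0; epoch < 10; epoch++)
{
    foreach (Vector<double> sample in trainingData) // e.g., one user's rating vector
    {
        // CD-1: one Gibbs step comparing data-driven ("reality") and
        // model-driven ("imagination") correlations, then nudging the
        // weights and biases toward the data distribution.
        rbm.TrainWithContrastiveDivergence(sample, learningRate, kSteps: 1);
    }
}
```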
UpdateParameters(T)
Updates the layer's parameters using either standard backpropagation or contrastive divergence.
public override void UpdateParameters(T learningRate)
Parameters
- learningRate T: The learning rate for parameter updates.
Remarks
This method handles two training modes:
- Discriminative training (backprop): Uses gradients computed by BackwardViaAutodiff.
- Generative training (CD-k): Uses statistics from the Gibbs sampling chain.
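For the discriminative mode, a hedged sketch of the usual sequence (the outputGradient tensor is assumed to come from the next layer's backward pass):

```csharp
// Backprop mode: propagate gradients, then apply one update step
// with the given learning rate.
Tensor<double> inputGradient = rbm.Backward(outputGradient);
rbm.UpdateParameters(0.01);
```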