Class BayesianDenseLayer<T>
Namespace: AiDotNet.UncertaintyQuantification.Layers
Assembly: AiDotNet.dll
Implements a Bayesian dense (fully-connected) layer using variational inference.
public class BayesianDenseLayer<T> : LayerBase<T>, ILayer<T>, IJitCompilable<T>, IDiagnosticsProvider, IWeightLoadable<T>, IDisposable, IBayesianLayer<T>
Type Parameters
T: The numeric type used for computations (e.g., float, double).
- Inheritance
- LayerBase<T> → BayesianDenseLayer<T>
- Implements
- ILayer<T>, IJitCompilable<T>, IDiagnosticsProvider, IWeightLoadable<T>, IDisposable, IBayesianLayer<T>
Remarks
For Beginners: A Bayesian Dense Layer is similar to a regular dense layer, but instead of having fixed weights, it learns probability distributions over weights.
This is based on the "Bayes by Backprop" algorithm which uses variational inference to approximate the true posterior distribution of weights.
The layer maintains two sets of parameters for each weight:
- Mean (μ): The average value of the weight
- Standard deviation (σ): How much the weight varies
During forward passes, weights are sampled from these distributions, allowing the network to express uncertainty in its predictions.
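For example, here is a minimal sketch of Monte Carlo prediction (assuming T = float and a Tensor<float> input constructed elsewhere; the exact tensor API may differ in your build):
```csharp
using System.Collections.Generic;
using AiDotNet.UncertaintyQuantification.Layers;

// Sketch only: Tensor<T> construction and exact namespaces may differ in AiDotNet.
static List<Tensor<float>> MonteCarloPredict(Tensor<float> input, int samples = 100)
{
    var layer = new BayesianDenseLayer<float>(inputSize: 4, outputSize: 2, priorSigma: 1.0, randomSeed: 42);

    var outputs = new List<Tensor<float>>(samples);
    for (int i = 0; i < samples; i++)
    {
        layer.SampleWeights();             // draw a fresh weight sample from the learned distributions
        outputs.Add(layer.Forward(input)); // each pass therefore produces a slightly different output
    }
    return outputs; // the spread across these outputs reflects the model's uncertainty
}
```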
Constructors
BayesianDenseLayer(int, int, IActivationFunction<T>?, double, int?)
Initializes a new instance of the BayesianDenseLayer<T> class with a custom activation.
public BayesianDenseLayer(int inputSize, int outputSize, IActivationFunction<T>? scalarActivation, double priorSigma = 1, int? randomSeed = null)
Parameters
inputSize (int): The number of input features.
outputSize (int): The number of output features.
scalarActivation (IActivationFunction<T>?): The activation function to apply.
priorSigma (double): The standard deviation of the prior distribution (default: 1.0).
randomSeed (int?): Optional random seed for reproducible sampling.
Remarks
For Beginners: This overload lets you choose what activation the layer uses, while still keeping Bayesian uncertainty for the weights.
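A hedged construction sketch; ReLUActivation<double> is a hypothetical name standing in for whatever IActivationFunction<double> implementations your build of AiDotNet provides:
```csharp
// Hypothetical: substitute any IActivationFunction<double> implementation you have available.
IActivationFunction<double> relu = new ReLUActivation<double>();

var layer = new BayesianDenseLayer<double>(
    inputSize: 128,
    outputSize: 64,
    scalarActivation: relu,
    priorSigma: 0.5,
    randomSeed: 7);
```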
BayesianDenseLayer(int, int, double, int?)
Initializes a new instance of the BayesianDenseLayer<T> class.
public BayesianDenseLayer(int inputSize, int outputSize, double priorSigma = 1, int? randomSeed = null)
Parameters
inputSize (int): The number of input features.
outputSize (int): The number of output features.
priorSigma (double): The standard deviation of the prior distribution (default: 1.0).
randomSeed (int?): Optional random seed for reproducible sampling.
Remarks
For Beginners: The prior sigma controls how spread out the initial weight distributions are. A larger value means more initial uncertainty; a smaller value means the model starts out more confident.
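For instance, a sketch contrasting the two extremes:
```csharp
// Wider prior: the layer starts with more uncertainty about every weight.
var uncertain = new BayesianDenseLayer<float>(inputSize: 10, outputSize: 5, priorSigma: 2.0);

// Narrower prior: the layer starts out more confident.
var confident = new BayesianDenseLayer<float>(inputSize: 10, outputSize: 5, priorSigma: 0.1);
```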
Properties
ParameterCount
Gets the total number of parameters in this layer.
public override int ParameterCount { get; }
Property Value
- int
The total number of trainable parameters.
Remarks
This property returns the total number of trainable parameters in the layer. By default, it returns the length of the Parameters vector, but derived classes can override this to calculate the number of parameters differently.
For Beginners: This tells you how many learnable values the layer has.
The parameter count:
- Shows how complex the layer is
- Indicates how many values need to be learned during training
- Can help estimate memory usage and computational requirements
Layers with more parameters can potentially learn more complex patterns but may also require more data to train effectively.
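A sketch of reading the count; the arithmetic in the comment is an expectation based on the mean/σ parameterization described above, not a guarantee of this implementation's exact bookkeeping:
```csharp
var layer = new BayesianDenseLayer<float>(inputSize: 100, outputSize: 10);

// Each weight (and typically each bias) carries two learnable values, a mean and a
// standard-deviation parameter, so a count near 2 * (100 * 10 + 10) is plausible here.
// ParameterCount is the authoritative value for this implementation.
Console.WriteLine(layer.ParameterCount);
```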
SupportsJitCompilation
Gets whether this layer supports JIT compilation.
public override bool SupportsJitCompilation { get; }
Property Value
- bool
True if the layer can be JIT compiled, false otherwise.
Remarks
This property indicates whether the layer has implemented ExportComputationGraph() and can benefit from JIT compilation. All layers MUST implement this property.
For Beginners: JIT compilation can make inference 5-10x faster by converting the layer's operations into optimized native code.
Layers should return false if they:
- Have not yet implemented a working ExportComputationGraph()
- Use dynamic operations that change based on input data
- Are too simple to benefit from JIT compilation
When false, the layer will use the standard Forward() method instead.
SupportsTraining
Gets a value indicating whether this layer supports training mode.
public override bool SupportsTraining { get; }
Property Value
- bool
True if the layer supports training; otherwise, false.
Methods
AddKLDivergenceGradients(T)
Adds the KL divergence gradients (regularization term) into the layer's accumulated gradients.
public void AddKLDivergenceGradients(T klScale)
Parameters
klScale (T): Scaling applied to the KL term (e.g., 1/N for dataset size).
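A sketch of where this call fits in a single training step; ComputeLossGradient, input, target, and datasetSize are hypothetical stand-ins for your own loss function and data pipeline:
```csharp
var layer = new BayesianDenseLayer<float>(inputSize: 4, outputSize: 2);

layer.ClearGradients();                      // start from zeroed gradients
layer.SampleWeights();                       // draw weights for this pass
Tensor<float> output = layer.Forward(input);

Tensor<float> lossGradient = ComputeLossGradient(output, target); // hypothetical helper
layer.Backward(lossGradient);                // accumulate data-fit gradients

layer.AddKLDivergenceGradients(1.0f / datasetSize); // add the scaled KL regularization term
layer.UpdateParameters(0.01f);               // apply the combined gradients
```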
Backward(Tensor<T>)
Performs the backward pass and accumulates gradients.
public override Tensor<T> Backward(Tensor<T> outputGradient)
Parameters
outputGradient (Tensor<T>): The gradient of the loss with respect to the layer's output.
Returns
- Tensor<T>
The gradient of the loss with respect to the layer's input.
ClearGradients()
Clears all parameter gradients in this layer.
public override void ClearGradients()
Remarks
This method sets all parameter gradients to zero. This is typically called at the beginning of each batch during training to ensure that gradients from previous batches don't affect the current batch.
For Beginners: This method resets all adjustment values to zero to start fresh.
Clearing gradients:
- Erases all previous adjustment information
- Prepares the layer for a new training batch
- Prevents old adjustments from interfering with new ones
This is typically done at the start of processing each batch of training data to ensure clean, accurate gradient calculations.
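A minimal per-batch sketch (layer and batches are assumed to exist in your training code; batches is a hypothetical sequence of input/output-gradient pairs):
```csharp
foreach (var (batchInput, batchGradient) in batches)
{
    layer.ClearGradients();        // erase gradients left over from the previous batch
    layer.Forward(batchInput);
    layer.Backward(batchGradient);
    layer.UpdateParameters(0.01f);
}
```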
ExportComputationGraph(List<ComputationNode<T>>)
Exports the layer's computation graph for JIT compilation.
public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
Parameters
inputNodes (List<ComputationNode<T>>): List to populate with input computation nodes.
Returns
- ComputationNode<T>
The output computation node representing the layer's operation.
Remarks
This method constructs a computation graph representation of the layer's forward pass that can be JIT compiled for faster inference. All layers MUST implement this method to support JIT compilation.
For Beginners: JIT (Just-In-Time) compilation converts the layer's operations into optimized native code for 5-10x faster inference.
To support JIT compilation, a layer must:
- Implement this method to export its computation graph
- Set SupportsJitCompilation to true
- Use ComputationNode and TensorOperations to build the graph
All layers are required to implement this method, even if they set SupportsJitCompilation = false.
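A hedged usage sketch (layer and input are assumed to exist); what a caller does with the returned node is up to the JIT infrastructure, which is not documented on this page:
```csharp
var inputNodes = new List<ComputationNode<float>>(); // populated by the layer
if (layer.SupportsJitCompilation)
{
    ComputationNode<float> graphOutput = layer.ExportComputationGraph(inputNodes);
    // Hand inputNodes and graphOutput to the JIT compiler (that API is not shown here).
}
else
{
    Tensor<float> result = layer.Forward(input); // standard forward pass fallback
}
```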
Forward(Tensor<T>)
Performs the forward pass using sampled weights.
public override Tensor<T> Forward(Tensor<T> input)
Parameters
input (Tensor<T>): The input tensor.
Returns
- Tensor<T>
The output tensor computed with the sampled weights.
GetKLDivergence()
Computes the KL divergence between the weight distribution and the prior.
public T GetKLDivergence()
Returns
- T
The KL divergence value.
Remarks
For Beginners: This measures how different the learned weight distributions are from a simple Gaussian prior. This is added to the loss during training to regularize the network and prevent overfitting.
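A sketch of folding the KL term into an ELBO-style objective; layer, ComputeDataLoss, input, target, and datasetSize are hypothetical stand-ins for your own loss and data pipeline:
```csharp
float dataLoss = ComputeDataLoss(layer.Forward(input), target); // hypothetical helper
float kl = layer.GetKLDivergence();            // T = float here
float totalLoss = dataLoss + kl / datasetSize; // same 1/N scaling as AddKLDivergenceGradients
```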
GetParameters()
Gets all trainable parameters.
public override Vector<T> GetParameters()
Returns
- Vector<T>
A vector containing all trainable parameters.
ResetState()
Resets the internal state of the layer.
public override void ResetState()
SampleWeights()
Samples weights from the learned distributions.
public void SampleWeights()
SetParameters(Vector<T>)
Sets all trainable parameters.
public override void SetParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): A vector containing all parameters to set.
UpdateParameters(Vector<T>)
Updates the parameters of the layer with the given vector of parameter values.
public override void UpdateParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): A vector containing all parameters to set.
Remarks
This method sets all the parameters of the layer from a single vector of parameters. The parameters vector must have the correct length to match the total number of parameters in the layer.
For Beginners: This method updates all the learnable values in the layer at once.
When updating parameters:
- The input must be a vector with the correct length
- This replaces all the current parameters with the new ones
- Throws an error if the input doesn't match the expected number of parameters
This is useful for:
- Optimizers that work with all parameters at once
- Applying parameters from another source
- Setting parameters to specific values for testing
Exceptions
- ArgumentException
Thrown when the parameters vector has incorrect length.
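A round-trip sketch using GetParameters() together with this overload (layer is assumed to exist):
```csharp
// Snapshot the current parameters, then restore them later.
Vector<float> snapshot = layer.GetParameters();

// ... train or experiment ...

layer.UpdateParameters(snapshot); // length must equal layer.ParameterCount,
                                  // or an ArgumentException is thrown
```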
UpdateParameters(T)
Updates parameters using the accumulated gradients.
public override void UpdateParameters(T learningRate)
Parameters
learningRate (T): The learning rate applied to the accumulated gradients.