Class SparseNeuralNetwork<T>
Namespace: AiDotNet.NeuralNetworks
Assembly: AiDotNet.dll
Represents a Sparse Neural Network with efficient sparse weight matrices.
public class SparseNeuralNetwork<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable
Type Parameters
T: The numeric type used for calculations (typically float or double).
Inheritance: NeuralNetworkBase<T> → SparseNeuralNetwork<T>
Remarks
A Sparse Neural Network uses sparse weight matrices where most values are zero. This provides significant memory and computational savings for large networks, especially when combined with network pruning techniques.
For Beginners: In a regular neural network, every neuron in one layer is connected to every neuron in the next layer. In a sparse network, many of these connections are removed (set to zero), keeping only the most important ones. This has several benefits:
- Uses less memory (only stores non-zero values)
- Runs faster (skips multiplications with zero)
- Can prevent overfitting (acts as regularization)
- Enables very large networks to fit in limited memory
Common use cases include:
- Network compression for mobile/edge deployment
- Recommender systems with sparse user-item matrices
- Graph neural networks with sparse adjacency matrices
- Pruned networks from neural architecture search
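To make the memory benefit concrete, here is a rough back-of-the-envelope sketch (standalone, not part of the library) comparing dense storage against a CSR-style sparse layout for a 90%-sparse float32 weight matrix:

```csharp
using System;

class SparseMemorySketch
{
    static void Main()
    {
        // A 1024 x 1024 float32 weight matrix at 90% sparsity.
        long rows = 1024, cols = 1024;
        double sparsity = 0.9;
        long nnz = (long)(rows * cols * (1 - sparsity)); // non-zero entries

        // Dense: one 4-byte float per entry.
        long denseBytes = rows * cols * 4;

        // CSR: 4-byte value + 4-byte column index per non-zero,
        // plus one 4-byte row pointer per row (+1).
        long csrBytes = nnz * (4 + 4) + (rows + 1) * 4;

        Console.WriteLine($"dense: {denseBytes} bytes, CSR: {csrBytes} bytes");
        Console.WriteLine($"ratio: {(double)csrBytes / denseBytes:F2}");
    }
}
```

At 90% sparsity the CSR layout needs roughly a fifth of the dense memory; the index overhead is why savings only appear at high sparsity levels.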
Constructors
SparseNeuralNetwork(NeuralNetworkArchitecture<T>, double, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, ILossFunction<T>?, double)
Initializes a new instance of the SparseNeuralNetwork class.
public SparseNeuralNetwork(NeuralNetworkArchitecture<T> architecture, double sparsity = 0.9, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? optimizer = null, ILossFunction<T>? lossFunction = null, double maxGradNorm = 1)
Parameters
architecture (NeuralNetworkArchitecture<T>): The architecture defining the structure of the neural network.
sparsity (double): The fraction of weights that should be zero (0.0 to 1.0). Default is 0.9 (90% sparse).
optimizer (IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?): The optimization algorithm to use for training. If null, the Adam optimizer is used.
lossFunction (ILossFunction<T>?): The loss function to use for training. If null, mean squared error (MSE) is used.
maxGradNorm (double): The maximum gradient norm for gradient clipping during training. Default is 1.
Remarks
Higher sparsity values mean fewer connections and faster computation, but may reduce the network's capacity to learn complex patterns. A sparsity of 0.9 (90% zeros) is a good starting point for most applications.
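A sketch of typical construction, using only the documented constructor parameters (the NeuralNetworkArchitecture<T> setup is illustrative; its actual API may differ):

```csharp
using AiDotNet.NeuralNetworks;

// Hypothetical architecture setup; consult NeuralNetworkArchitecture<T>
// for the real construction options.
var architecture = new NeuralNetworkArchitecture<double>(/* layer configuration */);

// 95% of weights zero; optimizer and loss are left null, so the
// documented defaults (Adam, MSE) apply, with maxGradNorm = 1.
var network = new SparseNeuralNetwork<double>(
    architecture,
    sparsity: 0.95);
```

Raising sparsity from the default 0.9 trades capacity for speed and memory, as described above.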
Properties
SupportsTraining
Indicates whether this network supports training.
public override bool SupportsTraining { get; }
Property Value
- bool
Methods
Backward(Tensor<T>)
Performs a backward pass through the network to calculate gradients.
public Tensor<T> Backward(Tensor<T> outputGradient)
Parameters
outputGradient (Tensor<T>): The gradient of the loss with respect to the network's output.
Returns
- Tensor<T>
The gradient of the loss with respect to the network's input.
Remarks
During backpropagation, gradients are only computed for non-zero weights, maintaining the sparsity pattern throughout training.
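The sparsity-preserving update can be illustrated with a small standalone sketch (plain arrays, not the library's internals): updates are applied only where a weight is already non-zero, so zero weights stay zero.

```csharp
using System;

class SparseGradientSketch
{
    static void Main()
    {
        // Weights with a fixed sparsity pattern (zeros must stay zero).
        double[] weights = { 0.5, 0.0, -0.3, 0.0 };
        double[] grads   = { 0.1, 0.2,  0.4, 0.3 };
        double lr = 0.1;

        for (int i = 0; i < weights.Length; i++)
        {
            // Mask the update: only non-zero weights are trained.
            if (weights[i] != 0.0)
                weights[i] -= lr * grads[i];
        }

        Console.WriteLine(string.Join(", ", weights)); // zeros are preserved
    }
}
```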
CreateNewInstance()
Creates a new instance of the SparseNeuralNetwork with the same configuration.
protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
DeserializeNetworkSpecificData(BinaryReader)
Deserializes sparse neural network-specific data from a binary reader.
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
Parameters
reader (BinaryReader): The binary reader to read the sparse network data from.
Forward(Tensor<T>)
Performs a forward pass through the network with the given input tensor.
public Tensor<T> Forward(Tensor<T> input)
Parameters
input (Tensor<T>): The input tensor to process.
Returns
- Tensor<T>
The output tensor after processing through all layers.
Remarks
The forward pass uses sparse matrix-vector multiplication (SpMV) for efficiency. Only non-zero weights are used in computation, significantly reducing the number of operations for highly sparse networks.
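A minimal CSR-style sparse matrix-vector product, sketched standalone (the library's internal representation may differ):

```csharp
using System;

class CsrSpmvSketch
{
    static void Main()
    {
        // 3x3 matrix [[1,0,2],[0,0,3],[4,0,0]] in CSR form.
        double[] values  = { 1, 2, 3, 4 };   // non-zero entries, row by row
        int[] colIndices = { 0, 2, 2, 0 };   // column of each value
        int[] rowPtr     = { 0, 2, 3, 4 };   // start of each row in values

        double[] x = { 1, 2, 3 };
        double[] y = new double[3];

        // Only non-zero entries are touched: O(nnz) work
        // instead of O(rows * cols).
        for (int row = 0; row < 3; row++)
            for (int k = rowPtr[row]; k < rowPtr[row + 1]; k++)
                y[row] += values[k] * x[colIndices[k]];

        Console.WriteLine(string.Join(", ", y)); // 7, 9, 4
    }
}
```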
GetModelMetadata()
Retrieves metadata about the sparse neural network model.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
A ModelMetadata<T> object containing information about the network.
InitializeLayers()
Initializes the layers of the sparse neural network based on the provided architecture.
protected override void InitializeLayers()
IsValidInputLayer(ILayer<T>)
Determines if a layer can serve as a valid input layer for this network.
protected override bool IsValidInputLayer(ILayer<T> layer)
Parameters
layer (ILayer<T>): The layer to validate.
Returns
- bool
True if the layer is a valid input layer for this network; otherwise, false.
IsValidOutputLayer(ILayer<T>)
Determines if a layer can serve as a valid output layer for this network.
protected override bool IsValidOutputLayer(ILayer<T> layer)
Parameters
layer (ILayer<T>): The layer to validate.
Returns
- bool
True if the layer is a valid output layer for this network; otherwise, false.
Predict(Tensor<T>)
Makes a prediction using the sparse neural network for the given input tensor.
public override Tensor<T> Predict(Tensor<T> input)
Parameters
input (Tensor<T>): The input tensor to make a prediction for.
Returns
- Tensor<T>
The predicted output tensor.
SerializeNetworkSpecificData(BinaryWriter)
Serializes sparse neural network-specific data to a binary writer.
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
Parameters
writer (BinaryWriter): The binary writer to write the sparse network data to.
Train(Tensor<T>, Tensor<T>)
Trains the sparse neural network using the provided input and expected output.
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)
Parameters
input (Tensor<T>): The input tensor for training.
expectedOutput (Tensor<T>): The expected output tensor for the given input.
Remarks
Training maintains the sparsity pattern - only non-zero weights are updated. This means the network structure is fixed after initialization; use dynamic sparsity techniques if you need the sparsity pattern to evolve during training.
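A hedged usage sketch of a training loop (assumes an architecture and in-memory tensors prepared elsewhere; the variable names here are illustrative, and only the documented calls are shown):

```csharp
// `architecture`, `trainInputs`, `trainTargets`, and `testInput` are
// assumed to exist; see NeuralNetworkArchitecture<T> and Tensor<T>.
var network = new SparseNeuralNetwork<double>(architecture, sparsity: 0.9);

for (int epoch = 0; epoch < 100; epoch++)
{
    // Each call updates only the non-zero weights;
    // the sparsity pattern chosen at construction stays fixed.
    network.Train(trainInputs, trainTargets);
}

Tensor<double> prediction = network.Predict(testInput);
```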
UpdateParameters(Vector<T>)
Updates the parameters of all layers in the network.
public override void UpdateParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): A vector containing all parameters for the network.