Interface IPruningMask<T>
- Namespace: AiDotNet.Interfaces
- Assembly: AiDotNet.dll
Represents a binary mask for pruning weights in a neural network layer.
public interface IPruningMask<T>
Type Parameters
- T
Numeric type for mask values
Remarks
A pruning mask is a binary matrix that determines which weights to keep (1) and which to remove (0) during model compression. It enables selective removal of network parameters while maintaining the ability to restore the network structure.
For Beginners: Think of a pruning mask as a stencil or template.
Imagine you're painting a picture and want to cover certain areas:
- The mask has holes (1s) where paint should go through (weights to keep)
- The mask is solid (0s) where paint should be blocked (weights to prune/remove)
In neural networks:
- A pruning mask helps you selectively remove less important connections
- This makes your model smaller and faster without losing too much accuracy
- The mask can be applied to weight matrices to zero out pruned weights
Properties
Pattern
Gets the sparsity pattern type.
SparsityPattern Pattern { get; }
Property Value
- SparsityPattern
Shape
Gets the mask dimensions matching the weight matrix shape.
int[] Shape { get; }
Property Value
- int[]
Methods
Apply(Matrix<T>)
Applies the mask to a weight matrix (element-wise multiplication).
Matrix<T> Apply(Matrix<T> weights)
Parameters
- weights Matrix<T>
Weight matrix to prune
Returns
- Matrix<T>
Pruned weights (zeros where mask is zero)
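As an illustration of the element-wise multiplication described above, here is a minimal sketch that uses plain `double[,]` arrays in place of `Matrix<T>`. The class `MaskApplySketch` and its `Apply` helper are hypothetical names chosen for this example and are not part of AiDotNet; an actual implementation may differ.

```csharp
using System;

// Hypothetical helper, not part of AiDotNet: plain arrays stand in for Matrix<T>.
public static class MaskApplySketch
{
    // Multiplies each weight by the matching mask entry (1 keeps, 0 prunes).
    public static double[,] Apply(double[,] weights, double[,] mask)
    {
        int rows = weights.GetLength(0);
        int cols = weights.GetLength(1);
        if (mask.GetLength(0) != rows || mask.GetLength(1) != cols)
            throw new ArgumentException("Mask shape must match the weight matrix shape.");

        var pruned = new double[rows, cols];
        for (int r = 0; r < rows; r++)
        {
            for (int c = 0; c < cols; c++)
            {
                pruned[r, c] = weights[r, c] * mask[r, c];
            }
        }
        return pruned;
    }
}
```

The input matrix is left untouched and a new array is returned, so the original weights remain available if the mask changes later.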
Apply(Tensor<T>)
Applies the mask to a weight tensor (for convolutional layers).
Tensor<T> Apply(Tensor<T> weights)
Parameters
- weights Tensor<T>
Returns
- Tensor<T>
Apply(Vector<T>)
Applies the mask to a vector.
Vector<T> Apply(Vector<T> weights)
Parameters
- weights Vector<T>
Returns
- Vector<T>
CombineWith(IPruningMask<T>)
Combines this mask with another mask (logical AND).
IPruningMask<T> CombineWith(IPruningMask<T> otherMask)
Parameters
- otherMask IPruningMask<T>
Returns
- IPruningMask<T>
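To show what a logical AND of two masks means in practice, the sketch below combines two keep/prune flag arrays of equal length. It assumes masks are represented as flat `bool[]` arrays, which is an assumption made for illustration; `MaskCombineSketch` and `And` are hypothetical names, not AiDotNet members.

```csharp
using System;

// Hypothetical helper, not part of AiDotNet: flat bool arrays stand in for masks.
public static class MaskCombineSketch
{
    // A position is kept only if both masks keep it.
    public static bool[] And(bool[] first, bool[] second)
    {
        if (first.Length != second.Length)
            throw new ArgumentException("Masks must have the same number of elements.");

        var combined = new bool[first.Length];
        for (int i = 0; i < first.Length; i++)
        {
            combined[i] = first[i] && second[i];
        }
        return combined;
    }
}
```

Under this reading, combining masks can only keep the same number of weights or fewer: a weight pruned by either mask stays pruned in the result.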
GetKeptIndices()
Gets indices of non-zero (kept) elements.
int[] GetKeptIndices()
Returns
- int[]
GetMaskData()
Gets the raw mask data as a flat array.
T[] GetMaskData()
Returns
- T[]
GetPrunedIndices()
Gets indices of zero (pruned) elements.
int[] GetPrunedIndices()
Returns
- int[]
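The kept and pruned index lists partition the mask positions between them. The sketch below derives both from a flat mask array, assuming that indices refer to positions in the flat data returned by `GetMaskData()`; that ordering is an assumption, and `MaskIndexSketch` and `Split` are hypothetical names, not AiDotNet members.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper, not part of AiDotNet: a flat double[] stands in for T[] mask data.
public static class MaskIndexSketch
{
    // Returns the positions of non-zero (kept) and zero (pruned) mask entries.
    public static (int[] Kept, int[] Pruned) Split(double[] flatMask)
    {
        var kept = new List<int>();
        var pruned = new List<int>();
        for (int i = 0; i < flatMask.Length; i++)
        {
            if (flatMask[i] != 0.0)
            {
                kept.Add(i);
            }
            else
            {
                pruned.Add(i);
            }
        }
        return (kept.ToArray(), pruned.ToArray());
    }
}
```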
GetSparsity()
Gets the sparsity ratio (proportion of zeros).
double GetSparsity()
Returns
- double
Value between 0 (dense) and 1 (fully pruned)
Remarks
For Beginners: Sparsity measures how many weights have been removed.
- 0.0 means no weights removed (0% sparse, 100% dense)
- 0.5 means half the weights removed (50% sparse)
- 0.9 means 90% of weights removed (90% sparse)
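The sparsity ratio can be sketched as the number of zero entries divided by the total number of entries. The example uses a flat `double[]` as a stand-in for the mask data; `SparsitySketch` and `Sparsity` are hypothetical names, and returning 0 for an empty mask is an assumption made here, not documented behavior.

```csharp
using System;

// Hypothetical helper, not part of AiDotNet: a flat double[] stands in for T[] mask data.
public static class SparsitySketch
{
    // Proportion of zero entries: 0.0 is fully dense, 1.0 is fully pruned.
    public static double Sparsity(double[] flatMask)
    {
        if (flatMask.Length == 0)
        {
            return 0.0; // Assumption: an empty mask is treated as dense.
        }

        int zeros = 0;
        foreach (double value in flatMask)
        {
            if (value == 0.0)
            {
                zeros++;
            }
        }
        return (double)zeros / flatMask.Length;
    }
}
```

For a mask of `{ 1, 0, 0, 1 }`, two of the four entries are zero, so the ratio is 0.5.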
UpdateMask(Array)
Updates the mask with new N-D keep/prune decisions.
void UpdateMask(Array keepIndices)
Parameters
- keepIndices Array
UpdateMask(bool[,])
Updates the mask based on new pruning criteria.
void UpdateMask(bool[,] keepIndices)
Parameters
- keepIndices bool[,]
Per-element flags marking the weights to keep (not prune)
UpdateMask(bool[])
Updates the mask with new keep/prune decisions.
void UpdateMask(bool[] keepIndices)
Parameters
- keepIndices bool[]
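One way keep/prune decisions for `UpdateMask` might be produced is by thresholding weight magnitudes. The sketch below uses that criterion purely as an example; the interface does not prescribe any particular criterion. `MaskUpdateSketch`, `KeepFlags`, and `ToMask` are hypothetical names, not AiDotNet members, and it is assumed that `true` means "keep".

```csharp
using System;

// Hypothetical helpers, not part of AiDotNet.
public static class MaskUpdateSketch
{
    // Example criterion (an assumption): keep weights whose magnitude reaches the threshold.
    public static bool[] KeepFlags(double[] weights, double threshold)
    {
        var keep = new bool[weights.Length];
        for (int i = 0; i < weights.Length; i++)
        {
            keep[i] = Math.Abs(weights[i]) >= threshold;
        }
        return keep;
    }

    // Converts keep flags into binary mask values: 1 for keep, 0 for prune.
    public static double[] ToMask(bool[] keep)
    {
        var mask = new double[keep.Length];
        for (int i = 0; i < keep.Length; i++)
        {
            mask[i] = keep[i] ? 1.0 : 0.0;
        }
        return mask;
    }
}
```

A flag array built this way has the shape expected by the `bool[]` overload; the two-dimensional and N-D overloads would take the same decisions arranged to match the weight shape.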