Class SuperNet<T>
- Namespace
- AiDotNet.NeuralNetworks
- Assembly
- AiDotNet.dll
SuperNet implementation for gradient-based neural architecture search (DARTS). Implements a differentiable architecture search by maintaining architecture parameters (alpha) and network weights simultaneously.
public class SuperNet<T> : IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>
Type Parameters
- T
The numeric type for calculations
- Inheritance
- object
SuperNet<T>
Constructors
SuperNet(SearchSpaceBase<T>, int, ILossFunction<T>?)
Initializes a new SuperNet for differentiable architecture search.
public SuperNet(SearchSpaceBase<T> searchSpace, int numNodes = 4, ILossFunction<T>? lossFunction = null)
Parameters
- searchSpace SearchSpaceBase<T>
The search space defining available operations.
- numNodes int
Number of nodes in the architecture.
- lossFunction ILossFunction<T>
Optional loss function to use for training. If null, Mean Squared Error (MSE) is used.
Fields
NumOps
Provides numeric operations for type T.
protected readonly INumericOperations<T> NumOps
Field Value
- INumericOperations<T>
Properties
DefaultLossFunction
Gets the default loss function used by this model for gradient computation.
public ILossFunction<T> DefaultLossFunction { get; }
Property Value
- ILossFunction<T>
Remarks
For SuperNet (Neural Architecture Search), the default loss function is Mean Squared Error (MSE), which is used for computing both architecture and weight gradients.
FeatureNames
public string[] FeatureNames { get; set; }
Property Value
- string[]
ParameterCount
Gets the number of parameters in the model.
public int ParameterCount { get; }
Property Value
- int
Remarks
This property returns the total count of trainable parameters in the model. It's useful for understanding model complexity and memory requirements.
SupportsJitCompilation
Gets whether this SuperNet supports JIT compilation.
public bool SupportsJitCompilation { get; }
Property Value
- bool
true after at least one forward pass has been performed to initialize weights.
Remarks
SuperNet implements Differentiable Architecture Search (DARTS), which is specifically designed to be differentiable. The softmax-weighted operation mixing that defines DARTS is a fully differentiable computation that can be exported as a computation graph.
Key Insight: While the architecture parameters (alpha) are learned during training, at inference time they are fixed values. The computation graph includes:
- Softmax over architecture parameters for each node
- All operation outputs computed in parallel
- Weighted sum of operation outputs using softmax weights
This is exactly what makes DARTS "differentiable" - the entire forward pass can be expressed as continuous, differentiable operations that are JIT-compilable.
For Beginners: DARTS uses a clever trick called "continuous relaxation":
Instead of choosing ONE operation at each step (which would be discrete and non-differentiable), DARTS computes ALL operations and combines them with softmax weights. This weighted combination IS differentiable and CAN be JIT compiled.
The JIT-compiled SuperNet will:
- Use the current architecture parameters (alpha values)
- Compute softmax weights over operations
- Evaluate all operations
- Combine outputs using the computed weights
After architecture search is complete, you can also call DeriveArchitecture() to create a simpler, discrete architecture that uses only the best operations.
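The continuous relaxation described above can be sketched in a few lines. This is a minimal, library-independent illustration that uses plain double arrays instead of Tensor<T>; MixedOpSketch, Softmax, and Mix are hypothetical names introduced for this sketch, and the actual SuperNet internals may differ.

```csharp
using System;
using System.Linq;

public static class MixedOpSketch
{
    // Numerically stable softmax: subtract the maximum before exponentiating.
    public static double[] Softmax(double[] alpha)
    {
        double max = alpha.Max();
        double[] exp = alpha.Select(a => Math.Exp(a - max)).ToArray();
        double sum = exp.Sum();
        return exp.Select(e => e / sum).ToArray();
    }

    // Evaluates every candidate operation on x and returns the softmax-weighted sum.
    // Assumption: each operation returns an array of the same length as x.
    public static double[] Mix(double[] x, double[] alpha, Func<double[], double[]>[] ops)
    {
        if (alpha.Length != ops.Length)
            throw new ArgumentException("One alpha value is required per operation.");

        double[] weights = Softmax(alpha);
        double[] output = new double[x.Length];
        for (int k = 0; k < ops.Length; k++)
        {
            double[] y = ops[k](x);
            for (int i = 0; i < x.Length; i++)
                output[i] += weights[k] * y[i];
        }
        return output;
    }
}
```

With two operations (identity and "zero") and equal alpha values, each operation receives weight 0.5, so the mixed output is half of the input. Every step is a smooth function of alpha, which is why the mixture can be differentiated and exported as a graph.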
Type
public ModelType Type { get; }
Property Value
- ModelType
Methods
ApplyGradients(Vector<T>, T)
Applies pre-computed gradients to update the model parameters.
public void ApplyGradients(Vector<T> gradients, T learningRate)
Parameters
- gradients Vector<T>
The gradient vector to apply.
- learningRate T
The learning rate for the update.
Remarks
Updates both architecture and weight parameters using: θ = θ - learningRate * gradients
For Beginners: This method applies the gradient updates to both:
- Architecture parameters (which operations are selected)
- Weight parameters (the neural network weights)
In DARTS, you typically call this with different learning rates for architecture and weight parameters.
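The update rule and the documented exceptions can be illustrated with a small standalone sketch. It operates on double arrays rather than Vector<T>, and GradientStepSketch.Apply is a hypothetical helper, not the library's implementation.

```csharp
using System;

public static class GradientStepSketch
{
    // In-place gradient descent step: parameters[i] -= learningRate * gradients[i].
    public static void Apply(double[] parameters, double[] gradients, double learningRate)
    {
        if (parameters is null) throw new ArgumentNullException(nameof(parameters));
        if (gradients is null) throw new ArgumentNullException(nameof(gradients));
        if (gradients.Length != parameters.Length)
            throw new ArgumentException("Gradient length must match parameter count.", nameof(gradients));

        for (int i = 0; i < parameters.Length; i++)
            parameters[i] -= learningRate * gradients[i];
    }
}
```

Calling such a helper once per parameter group, each time with its own learning rate, is one way to realize the "different learning rates" pattern mentioned above.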
Exceptions
- ArgumentNullException
If gradients is null.
- ArgumentException
If gradient vector length doesn't match parameter count.
BackwardArchitecture(Tensor<T>, Tensor<T>)
Backward pass to compute gradients for architecture parameters.
public void BackwardArchitecture(Tensor<T> input, Tensor<T> target)
Parameters
- input Tensor<T>
- target Tensor<T>
BackwardWeights(Tensor<T>, Tensor<T>, ILossFunction<T>)
Backward pass to compute gradients for network weights using the specified loss function.
public void BackwardWeights(Tensor<T> input, Tensor<T> target, ILossFunction<T> lossFunction)
Parameters
- input Tensor<T>
The input tensor.
- target Tensor<T>
The target tensor.
- lossFunction ILossFunction<T>
The loss function to use for gradient computation.
Clone()
Creates a shallow copy of this object.
public IFullModel<T, Tensor<T>, Tensor<T>> Clone()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
ComputeGradients(Tensor<T>, Tensor<T>, ILossFunction<T>?)
Computes gradients of the loss function with respect to model parameters WITHOUT updating parameters.
public Vector<T> ComputeGradients(Tensor<T> input, Tensor<T> target, ILossFunction<T>? lossFunction = null)
Parameters
- input Tensor<T>
The input tensor.
- target Tensor<T>
The target/expected output tensor.
- lossFunction ILossFunction<T>
The loss function to use. If null, uses the model's default loss function.
Returns
- Vector<T>
A vector containing gradients with respect to all model parameters (both architecture and weights).
Remarks
For SuperNet, this computes gradients for weight parameters only (not architecture parameters). Architecture parameters are updated separately in DARTS using validation data. The method uses the existing BackwardWeights method and collects gradients from all layers.
For Beginners: SuperNet has two types of parameters:
- Architecture parameters (α): which operations to use
- Weight parameters (w): the actual neural network weights
This method computes gradients for the weight parameters based on training data. In DARTS, architecture parameters are optimized separately on validation data.
Exceptions
- ArgumentNullException
If input or target is null.
ComputeTrainingLoss(Tensor<T>, Tensor<T>)
Computes the training loss used for weight updates.
public T ComputeTrainingLoss(Tensor<T> trainData, Tensor<T> trainLabels)
Parameters
- trainData Tensor<T>
- trainLabels Tensor<T>
Returns
- T
ComputeValidationLoss(Tensor<T>, Tensor<T>)
Computes the validation loss used for architecture parameter updates.
public T ComputeValidationLoss(Tensor<T> valData, Tensor<T> valLabels)
Parameters
- valData Tensor<T>
- valLabels Tensor<T>
Returns
- T
ConfigureFairness(Vector<int>, params FairnessMetric[])
Configures fairness evaluation settings.
public virtual void ConfigureFairness(Vector<int> sensitiveFeatures, params FairnessMetric[] fairnessMetrics)
Parameters
- sensitiveFeatures Vector<int>
- fairnessMetrics FairnessMetric[]
DeepCopy()
Creates a deep copy of this object.
public IFullModel<T, Tensor<T>, Tensor<T>> DeepCopy()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
DeriveArchitecture()
Derives a discrete architecture from the continuous architecture parameters using argmax selection.
public Architecture<T> DeriveArchitecture()
Returns
- Architecture<T>
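Argmax selection can be sketched as follows. This is only an illustration on jagged double arrays; the layout alpha[e][k] (one row per edge, one column per operation) and the name ArgmaxSketch are assumptions, and the real method returns an Architecture<T> rather than an index array.

```csharp
using System;

public static class ArgmaxSketch
{
    // For each row (edge), returns the index of the operation with the largest alpha.
    // Softmax is monotonic, so the argmax over raw alpha values equals the argmax
    // over the softmax weights.
    public static int[] SelectOperations(double[][] alpha)
    {
        int[] chosen = new int[alpha.Length];
        for (int e = 0; e < alpha.Length; e++)
        {
            if (alpha[e].Length == 0)
                throw new ArgumentException("Each edge needs at least one candidate operation.");

            int best = 0;
            for (int k = 1; k < alpha[e].Length; k++)
                if (alpha[e][k] > alpha[e][best]) best = k;
            chosen[e] = best;
        }
        return chosen;
    }
}
```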
Deserialize(byte[])
Loads a previously serialized model from binary data.
public void Deserialize(byte[] data)
Parameters
- data byte[]
The byte array containing the serialized model data.
Remarks
This method takes binary data created by the Serialize method and uses it to restore a model to its previous state.
For Beginners: This is like opening a saved file to continue your work.
When you call this method:
- You provide the binary data (bytes) that was previously created by Serialize
- The model rebuilds itself using this data
- After deserializing, the model is exactly as it was when serialized
- It's ready to make predictions without needing to be trained again
For example:
- You download a pre-trained model file for detecting spam emails
- You deserialize this file into your application
- Immediately, your application can detect spam without any training
- The model has all the knowledge that was built into it by its original creator
This is particularly useful when:
- You want to use a model that took days to train
- You need to deploy the same model across multiple devices
- You're creating an application that non-technical users will use
Think of it like installing the brain of a trained expert directly into your application.
EnableMethod(params InterpretationMethod[])
Enables specific interpretation methods.
public virtual void EnableMethod(params InterpretationMethod[] methods)
Parameters
- methods InterpretationMethod[]
ExportComputationGraph(List<ComputationNode<T>>)
Exports the model's computation graph for JIT compilation.
public ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
Parameters
- inputNodes List<ComputationNode<T>>
List to populate with input computation nodes (parameters).
Returns
- ComputationNode<T>
The output computation node representing the SuperNet forward pass.
Remarks
Exports the DARTS continuous relaxation as a computation graph. The graph includes:
- Input tensor variable
- Architecture parameters embedded as constants
- Softmax computation over architecture parameters
- All operation outputs
- Weighted sum using softmax weights
For Beginners: This exports the current state of the SuperNet as a JIT-compilable graph. The architecture parameters (alpha values) are baked into the graph as constants, so the exported graph represents the current "snapshot" of the architecture search.
You can export at different points during training to capture the evolving architecture, or export after search completes to get the final continuous relaxation.
Exceptions
- InvalidOperationException
Thrown if called before any forward pass has initialized the weights.
Forward(Tensor<T>)
Performs forward pass through the model (required by IJitCompilable).
public Tensor<T> Forward(Tensor<T> input)
Parameters
- input Tensor<T>
The input tensor.
Returns
- Tensor<T>
The output tensor.
GenerateTextExplanationAsync(Tensor<T>, Tensor<T>)
Generates a text explanation for a prediction. Provides a description of which operations are most important in the SuperNet.
public virtual Task<string> GenerateTextExplanationAsync(Tensor<T> input, Tensor<T> prediction)
Parameters
- input Tensor<T>
- prediction Tensor<T>
Returns
- Task<string>
GetActiveFeatureIndices()
Gets the indices of features that are actively used by this model.
public IEnumerable<int> GetActiveFeatureIndices()
Returns
- IEnumerable<int>
GetAnchorExplanationAsync(Tensor<T>, T)
Gets anchor explanation for a given input. Not supported for SuperNet architecture search models.
public virtual Task<AnchorExplanation<T>> GetAnchorExplanationAsync(Tensor<T> input, T threshold)
Parameters
- input Tensor<T>
- threshold T
Returns
- Task<AnchorExplanation<T>>
GetArchitectureGradients()
Gets the architecture gradients.
public List<Matrix<T>> GetArchitectureGradients()
Returns
- List<Matrix<T>>
GetArchitectureParameters()
Gets the architecture parameters for optimization.
public List<Matrix<T>> GetArchitectureParameters()
Returns
- List<Matrix<T>>
GetCounterfactualAsync(Tensor<T>, Tensor<T>, int)
Gets counterfactual explanation for a given input and desired output. Not supported for SuperNet architecture search models.
public virtual Task<CounterfactualExplanation<T>> GetCounterfactualAsync(Tensor<T> input, Tensor<T> desiredOutput, int maxChanges = 5)
Parameters
- input Tensor<T>
- desiredOutput Tensor<T>
- maxChanges int
Returns
- Task<CounterfactualExplanation<T>>
GetFeatureImportance()
Gets the feature importance scores.
public Dictionary<string, T> GetFeatureImportance()
Returns
- Dictionary<string, T>
GetFeatureInteractionAsync(int, int)
Gets feature interaction effects between two features. Analyzes interactions between operations based on architecture parameter correlations.
public virtual Task<T> GetFeatureInteractionAsync(int feature1Index, int feature2Index)
Parameters
- feature1Index int
- feature2Index int
Returns
- Task<T>
GetGlobalFeatureImportanceAsync(Tensor<T>)
Gets the operation importance for SuperNet architecture search. Returns importance scores for architectural operations rather than input features.
public virtual Task<Dictionary<int, T>> GetGlobalFeatureImportanceAsync(Tensor<T> inputs)
Parameters
- inputs Tensor<T>
Input tensor (required for interface compliance; not used in this implementation)
Returns
- Task<Dictionary<int, T>>
Dictionary mapping operation indices to their importance scores
Remarks
Note: SuperNet reinterprets "feature importance" as "operation importance" in the context of Neural Architecture Search (NAS). The returned dictionary maps operation indices (0=identity, 1=conv3x3, 2=conv5x5, etc.) to their importance scores, calculated by aggregating the absolute values of architecture parameters across all nodes.
The 'inputs' parameter is required for IInterpretableModel interface compliance but is not used. SuperNet analyzes operation importance based on learned architecture parameters rather than input data.
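The aggregation described in the remarks can be sketched as a column-wise sum of absolute values. This is an illustration only: the alpha[n][k] layout (node n, operation k) and the name OperationImportanceSketch are assumptions, and whether the library normalizes the resulting scores is not stated here.

```csharp
using System;
using System.Collections.Generic;

public static class OperationImportanceSketch
{
    // Sums |alpha[n][k]| over all nodes n for each operation index k.
    public static Dictionary<int, double> Aggregate(double[][] alpha)
    {
        var importance = new Dictionary<int, double>();
        foreach (double[] row in alpha)
        {
            for (int k = 0; k < row.Length; k++)
            {
                // TryGetValue leaves 'current' at 0.0 when k has not been seen yet.
                importance.TryGetValue(k, out double current);
                importance[k] = current + Math.Abs(row[k]);
            }
        }
        return importance;
    }
}
```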
GetLimeExplanationAsync(Tensor<T>, int)
Gets LIME explanation for a specific input. Not supported for SuperNet architecture search models.
public virtual Task<LimeExplanation<T>> GetLimeExplanationAsync(Tensor<T> input, int numFeatures = 10)
Parameters
- input Tensor<T>
- numFeatures int
Returns
- Task<LimeExplanation<T>>
GetLocalFeatureImportanceAsync(Tensor<T>)
Gets the local feature importance for a specific input. Provides importance based on softmax weights, analyzing which operations are most active.
public virtual Task<Dictionary<int, T>> GetLocalFeatureImportanceAsync(Tensor<T> input)
Parameters
- input Tensor<T>
Returns
- Task<Dictionary<int, T>>
GetModelMetadata()
Retrieves metadata and performance metrics about the trained model.
public ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
An object containing metadata and performance metrics about the trained model.
Remarks
This method provides information about the model's structure, parameters, and performance metrics.
For Beginners: Model metadata is like a report card for your machine learning model.
Just as a report card shows how well a student is performing in different subjects, model metadata shows how well your model is performing and provides details about its structure.
This information typically includes:
- Accuracy measures: How well do the model's predictions match actual values?
- Error metrics: How far off are the model's predictions on average?
- Model parameters: What patterns did the model learn from the data?
- Training information: How long did training take? How many iterations were needed?
For example, in a house price prediction model, metadata might include:
- Average prediction error (e.g., off by $15,000 on average)
- How strongly each feature (bedrooms, location) influences the prediction
- How well the model fits the training data
This information helps you understand your model's strengths and weaknesses, and decide if it's ready to use or needs more training.
GetModelSpecificInterpretabilityAsync()
Gets model-specific interpretability information for SuperNet. Returns architecture parameters and their importance.
public virtual Task<Dictionary<string, object>> GetModelSpecificInterpretabilityAsync()
Returns
- Task<Dictionary<string, object>>
GetParameters()
Gets the parameters that can be optimized.
public Vector<T> GetParameters()
Returns
- Vector<T>
GetPartialDependenceAsync(Vector<int>, int)
Gets partial dependence data for specified features. Not supported for SuperNet architecture search models.
public virtual Task<PartialDependenceData<T>> GetPartialDependenceAsync(Vector<int> featureIndices, int gridResolution = 20)
Parameters
- featureIndices Vector<int>
- gridResolution int
Returns
- Task<PartialDependenceData<T>>
GetShapValuesAsync(Tensor<T>)
Gets SHAP values for the given inputs. Not supported for SuperNet architecture search models.
public virtual Task<Matrix<T>> GetShapValuesAsync(Tensor<T> inputs)
Parameters
- inputs Tensor<T>
Returns
- Task<Matrix<T>>
GetWeightGradients()
Gets the weight gradients.
public Dictionary<string, Vector<T>> GetWeightGradients()
Returns
- Dictionary<string, Vector<T>>
GetWeightParameters()
Gets the weight parameters for optimization.
public Dictionary<string, Vector<T>> GetWeightParameters()
Returns
- Dictionary<string, Vector<T>>
IsFeatureUsed(int)
Checks if a specific feature is used by this model.
public bool IsFeatureUsed(int featureIndex)
Parameters
- featureIndex int
Returns
- bool
LoadModel(string)
Loads the model from a file.
public void LoadModel(string filePath)
Parameters
- filePath string
The path to the file containing the saved model.
Remarks
This method provides a convenient way to load a model directly from disk. It combines file I/O operations with deserialization.
For Beginners: This is like clicking "Open" in a document editor. Instead of manually reading from a file and then calling Deserialize(), this method does both steps for you.
Exceptions
- FileNotFoundException
Thrown when the specified file does not exist.
- IOException
Thrown when an I/O error occurs while reading from the file or when the file contains corrupted or invalid model data.
LoadState(Stream)
Loads the SuperNet's state (architecture parameters and weights) from a stream.
public virtual void LoadState(Stream stream)
Parameters
- stream Stream
The stream to read the model state from.
Remarks
This method deserializes SuperNet state that was previously saved with SaveState, restoring all architecture parameters, operation weights, and configuration. It uses the existing Deserialize method after reading data from the stream.
For Beginners: This is like loading a saved snapshot of your neural architecture search model.
When you call LoadState:
- All architecture parameters (alpha values) are read from the stream
- All operation weights are restored
- The model is configured to match the saved state
After loading, the model can:
- Continue architecture search from where it left off
- Make predictions using the restored architecture
- Be used for further optimization or deployment
This is essential for:
- Resuming interrupted architecture search
- Loading the best architecture found during search
- Deploying searched architectures to production
- Knowledge distillation workflows
Exceptions
- ArgumentNullException
Thrown when stream is null.
- IOException
Thrown when there's an error reading from the stream.
- InvalidOperationException
Thrown when the stream contains invalid or incompatible data.
Predict(Tensor<T>)
Forward pass through the SuperNet with mixed operations.
public Tensor<T> Predict(Tensor<T> input)
Parameters
- input Tensor<T>
Returns
- Tensor<T>
SaveModel(string)
Saves the model to a file.
public void SaveModel(string filePath)
Parameters
- filePath string
The path where the model should be saved.
Remarks
This method provides a convenient way to save the model directly to disk. It combines serialization with file I/O operations.
For Beginners: This is like clicking "Save As" in a document editor. Instead of manually calling Serialize() and then writing to a file, this method does both steps for you.
Exceptions
- IOException
Thrown when an I/O error occurs while writing to the file.
- UnauthorizedAccessException
Thrown when the caller does not have the required permission to write to the specified file path.
SaveState(Stream)
Saves the SuperNet's current state (architecture parameters and weights) to a stream.
public virtual void SaveState(Stream stream)
Parameters
- stream Stream
The stream to write the model state to.
Remarks
This method serializes all the information needed to recreate the SuperNet's current state, including architecture parameters, operation weights, and model configuration. It uses the existing Serialize method and writes the data to the provided stream.
For Beginners: This is like creating a snapshot of your neural architecture search model.
When you call SaveState:
- All architecture parameters (alpha values) are written to the stream
- All operation weights are saved
- The model's configuration and structure are preserved
This is particularly useful for:
- Checkpointing during neural architecture search
- Saving the best architecture found during search
- Knowledge distillation from SuperNet to final architecture
- Resuming interrupted architecture search
You can later use LoadState to restore the model to this exact state.
Exceptions
- ArgumentNullException
Thrown when stream is null.
- IOException
Thrown when there's an error writing to the stream.
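A stream-based save/load pair layered over a byte-array serializer might look like the sketch below. The length-prefixed layout is an assumption chosen for illustration, not the documented on-disk format, and StateStreamSketch is a hypothetical helper; the byte array stands in for the output of Serialize().

```csharp
using System;
using System.IO;
using System.Text;

public static class StateStreamSketch
{
    // Writes a 4-byte length prefix followed by the serialized bytes.
    public static void Save(Stream stream, byte[] serialized)
    {
        if (stream is null) throw new ArgumentNullException(nameof(stream));
        if (serialized is null) throw new ArgumentNullException(nameof(serialized));

        using var writer = new BinaryWriter(stream, Encoding.UTF8, leaveOpen: true);
        writer.Write(serialized.Length);
        writer.Write(serialized);
    }

    // Reads the length prefix, then exactly that many bytes.
    public static byte[] Load(Stream stream)
    {
        if (stream is null) throw new ArgumentNullException(nameof(stream));

        using var reader = new BinaryReader(stream, Encoding.UTF8, leaveOpen: true);
        int length = reader.ReadInt32();
        if (length < 0)
            throw new InvalidOperationException("Stream contains an invalid length prefix.");

        byte[] data = reader.ReadBytes(length);
        if (data.Length != length)
            throw new InvalidOperationException("Stream ended before the full state was read.");
        return data;
    }
}
```

The leaveOpen flag keeps the caller's stream usable after the helper returns, which matters when several checkpoints are written to the same stream.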
Serialize()
Converts the current state of a machine learning model into a binary format.
public byte[] Serialize()
Returns
- byte[]
A byte array containing the serialized model data.
Remarks
This method captures all the essential information about a trained model and converts it into a sequence of bytes that can be stored or transmitted.
For Beginners: This is like exporting your work to a file.
When you call this method:
- The model's current state (all its learned patterns and parameters) is captured
- This information is converted into a compact binary format (bytes)
- You can then save these bytes to a file, database, or send them over a network
For example:
- After training a model to recognize cats vs. dogs in images
- You can serialize the model to save all its learned knowledge
- Later, you can use this saved data to recreate the model exactly as it was
- The recreated model will make the same predictions as the original
Think of it like taking a snapshot of your model's brain at a specific moment in time.
SetActiveFeatureIndices(IEnumerable<int>)
Sets the active feature indices for this model.
public void SetActiveFeatureIndices(IEnumerable<int> featureIndices)
Parameters
- featureIndices IEnumerable<int>
SetBaseModel(IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>)
Sets the base model for interpretability analysis.
public virtual void SetBaseModel(IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>> model)
Parameters
- model IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>
SetParameters(Vector<T>)
Sets the model parameters.
public void SetParameters(Vector<T> parameters)
Parameters
- parameters Vector<T>
The parameter vector to set.
Remarks
This method allows direct modification of the model's internal parameters. This is useful for optimization algorithms that need to update parameters iteratively. If the length of parameters does not match ParameterCount, an ArgumentException is thrown.
Exceptions
- ArgumentException
Thrown when the length of parameters does not match ParameterCount.
Train(Tensor<T>, Tensor<T>)
Training is handled externally by alternating architecture and weight updates.
public void Train(Tensor<T> input, Tensor<T> expectedOutput)
Parameters
- input Tensor<T>
- expectedOutput Tensor<T>
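The alternating scheme can be illustrated with a toy, first-order loop over two scalar parameters. This is a sketch under stated assumptions: the gradient delegates, AlternatingSearchSketch, and the scalar setup are hypothetical stand-ins for the validation and training gradients an external trainer would obtain from the model, and the second-order term of DARTS is ignored.

```csharp
using System;

public static class AlternatingSearchSketch
{
    // Each iteration takes one architecture step on the validation gradient,
    // then one weight step on the training gradient (first-order approximation).
    public static (double W, double Alpha) Search(
        Func<double, double, double> trainGradW,   // dL_train/dw evaluated at (w, alpha)
        Func<double, double, double> valGradAlpha, // dL_val/dalpha evaluated at (w, alpha)
        double w, double alpha,
        double weightLearningRate, double archLearningRate, int steps)
    {
        for (int t = 0; t < steps; t++)
        {
            alpha -= archLearningRate * valGradAlpha(w, alpha);
            w -= weightLearningRate * trainGradW(w, alpha);
        }
        return (w, alpha);
    }
}
```

Keeping the two learning rates as separate arguments reflects the common practice of tuning architecture and weight updates independently.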
ValidateFairnessAsync(Tensor<T>, int)
Validates fairness metrics for the given inputs. Not supported for SuperNet architecture search models.
public virtual Task<FairnessMetrics<T>> ValidateFairnessAsync(Tensor<T> inputs, int sensitiveFeatureIndex)
Parameters
- inputs Tensor<T>
- sensitiveFeatureIndex int
Returns
- Task<FairnessMetrics<T>>
WithParameters(Vector<T>)
Creates a new instance with the specified parameters.
public IFullModel<T, Tensor<T>, Tensor<T>> WithParameters(Vector<T> parameters)
Parameters
- parameters Vector<T>
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>