Class SuperNet<T>

Namespace
AiDotNet.NeuralNetworks
Assembly
AiDotNet.dll

SuperNet implementation for gradient-based neural architecture search (DARTS). Implements differentiable architecture search by maintaining architecture parameters (alpha) and network weights side by side.

public class SuperNet<T> : IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>

Type Parameters

T

The numeric type for calculations

Inheritance
SuperNet<T>
Implements
IFullModel<T, Tensor<T>, Tensor<T>>
IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>
IParameterizable<T, Tensor<T>, Tensor<T>>
ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>
IGradientComputable<T, Tensor<T>, Tensor<T>>

Constructors

SuperNet(SearchSpaceBase<T>, int, ILossFunction<T>?)

Initializes a new SuperNet for differentiable architecture search.

public SuperNet(SearchSpaceBase<T> searchSpace, int numNodes = 4, ILossFunction<T>? lossFunction = null)

Parameters

searchSpace SearchSpaceBase<T>

The search space defining available operations

numNodes int

Number of nodes in the architecture

lossFunction ILossFunction<T>

Optional loss function to use for training. If null, uses Mean Squared Error (MSE) for neural architecture search.

Fields

NumOps

Provides numeric operations for type T.

protected readonly INumericOperations<T> NumOps

Field Value

INumericOperations<T>

Properties

DefaultLossFunction

Gets the default loss function used by this model for gradient computation.

public ILossFunction<T> DefaultLossFunction { get; }

Property Value

ILossFunction<T>

Remarks

For SuperNet (Neural Architecture Search), the default loss function is Mean Squared Error (MSE), which is used for computing both architecture and weight gradients.
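
As a rough illustration of what an MSE loss and its gradient compute, here is a minimal sketch over plain double arrays. It does not use the library's ILossFunction<T> or Tensor<T>; the class name MseSketch and its methods are hypothetical and exist only for this example.

```csharp
using System;

// Hypothetical helper, not part of AiDotNet: MSE over plain arrays.
public static class MseSketch
{
    // Mean squared error: the average of (prediction - target)^2.
    public static double Loss(double[] prediction, double[] target)
    {
        if (prediction.Length == 0 || prediction.Length != target.Length)
            throw new ArgumentException("prediction and target must be non-empty and equal in length.");

        double sum = 0.0;
        for (int i = 0; i < prediction.Length; i++)
        {
            double diff = prediction[i] - target[i];
            sum += diff * diff;
        }
        return sum / prediction.Length;
    }

    // Gradient of the loss with respect to each prediction: 2 * (p_i - t_i) / n.
    public static double[] Gradient(double[] prediction, double[] target)
    {
        if (prediction.Length == 0 || prediction.Length != target.Length)
            throw new ArgumentException("prediction and target must be non-empty and equal in length.");

        var grad = new double[prediction.Length];
        for (int i = 0; i < prediction.Length; i++)
            grad[i] = 2.0 * (prediction[i] - target[i]) / prediction.Length;
        return grad;
    }
}
```

In a real model this per-output gradient is what gets propagated backward to reach the weights and the architecture parameters.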

FeatureNames

public string[] FeatureNames { get; set; }

Property Value

string[]

ParameterCount

Gets the number of parameters in the model.

public int ParameterCount { get; }

Property Value

int

Remarks

This property returns the total count of trainable parameters in the model. It's useful for understanding model complexity and memory requirements.

SupportsJitCompilation

Gets whether this SuperNet supports JIT compilation.

public bool SupportsJitCompilation { get; }

Property Value

bool

true after at least one forward pass has been performed to initialize weights.

Remarks

SuperNet implements Differentiable Architecture Search (DARTS), which is specifically designed to be differentiable. The softmax-weighted operation mixing that defines DARTS is a fully differentiable computation that can be exported as a computation graph.

Key Insight: While the architecture parameters (alpha) are learned during training, at inference time they are fixed values. The computation graph includes:

  • Softmax over architecture parameters for each node
  • All operation outputs computed in parallel
  • Weighted sum of operation outputs using softmax weights

This is exactly what makes DARTS "differentiable" - the entire forward pass can be expressed as continuous, differentiable operations that are JIT-compilable.

For Beginners: DARTS uses a clever trick called "continuous relaxation":

Instead of choosing ONE operation at each step (which would be discrete and non-differentiable), DARTS computes ALL operations and combines them with softmax weights. This weighted combination IS differentiable and CAN be JIT compiled.

The JIT-compiled SuperNet will:

  • Use the current architecture parameters (alpha values)
  • Compute softmax weights over operations
  • Evaluate all operations
  • Combine outputs using the computed weights
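
The four steps above can be sketched for a single scalar input. This is an illustration under simplifying assumptions (plain doubles, operations modeled as Func<double, double>), not the library's implementation; MixedOpSketch and its members are hypothetical names.

```csharp
using System;
using System.Linq;

// Hypothetical helper, not part of AiDotNet: one softmax-weighted mixed operation.
public static class MixedOpSketch
{
    // Numerically stable softmax: subtract the maximum before exponentiating.
    public static double[] Softmax(double[] alpha)
    {
        double max = alpha.Max();
        double[] exps = alpha.Select(a => Math.Exp(a - max)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }

    // Evaluates every candidate operation and blends the results by softmax(alpha).
    public static double Mix(double x, double[] alpha, Func<double, double>[] ops)
    {
        if (alpha.Length != ops.Length)
            throw new ArgumentException("Need exactly one alpha per operation.");

        double[] weights = Softmax(alpha);
        double output = 0.0;
        for (int k = 0; k < ops.Length; k++)
            output += weights[k] * ops[k](x);
        return output;
    }
}
```

Because every term is a smooth function of alpha, the output stays differentiable with respect to the architecture parameters, which is the property the remarks above rely on.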

After architecture search is complete, you can also call DeriveArchitecture() to create a simpler, discrete architecture that uses only the best operations.

Type

public ModelType Type { get; }

Property Value

ModelType

Methods

ApplyGradients(Vector<T>, T)

Applies pre-computed gradients to update the model parameters.

public void ApplyGradients(Vector<T> gradients, T learningRate)

Parameters

gradients Vector<T>

The gradient vector to apply.

learningRate T

The learning rate for the update.

Remarks

Updates both architecture and weight parameters using: θ = θ - learningRate * gradients

For Beginners: This method applies the gradient updates to both:

  • Architecture parameters (which operations are selected)
  • Weight parameters (the neural network weights)

In DARTS, you typically call this with different learning rates for architecture and weight parameters.
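
The update rule θ = θ - learningRate * gradients can be sketched on plain arrays as follows. This is an illustration only: it uses double[] instead of Vector<T>, and GradientStepSketch is a hypothetical name. The null and length checks mirror the exceptions documented below.

```csharp
using System;

// Hypothetical helper, not part of AiDotNet: one plain gradient-descent step.
public static class GradientStepSketch
{
    // In-place update: theta[i] = theta[i] - learningRate * gradients[i].
    public static void Apply(double[] theta, double[] gradients, double learningRate)
    {
        if (theta is null) throw new ArgumentNullException(nameof(theta));
        if (gradients is null) throw new ArgumentNullException(nameof(gradients));
        if (gradients.Length != theta.Length)
            throw new ArgumentException("Gradient length must match parameter count.", nameof(gradients));

        for (int i = 0; i < theta.Length; i++)
            theta[i] -= learningRate * gradients[i];
    }
}
```

With separate arrays for architecture parameters and weights, calling Apply once per array with its own learning rate gives the two-rate behaviour mentioned above.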

Exceptions

ArgumentNullException

If gradients is null.

ArgumentException

If gradient vector length doesn't match parameter count.

BackwardArchitecture(Tensor<T>, Tensor<T>)

Runs a backward pass to compute gradients for the architecture parameters.

public void BackwardArchitecture(Tensor<T> input, Tensor<T> target)

Parameters

input Tensor<T>
target Tensor<T>

BackwardWeights(Tensor<T>, Tensor<T>, ILossFunction<T>)

Backward pass to compute gradients for network weights using the specified loss function.

public void BackwardWeights(Tensor<T> input, Tensor<T> target, ILossFunction<T> lossFunction)

Parameters

input Tensor<T>

The input tensor.

target Tensor<T>

The target tensor.

lossFunction ILossFunction<T>

The loss function to use for gradient computation.

Clone()

Creates a shallow copy of this object.

public IFullModel<T, Tensor<T>, Tensor<T>> Clone()

Returns

IFullModel<T, Tensor<T>, Tensor<T>>

ComputeGradients(Tensor<T>, Tensor<T>, ILossFunction<T>?)

Computes gradients of the loss function with respect to model parameters WITHOUT updating parameters.

public Vector<T> ComputeGradients(Tensor<T> input, Tensor<T> target, ILossFunction<T>? lossFunction = null)

Parameters

input Tensor<T>

The input tensor.

target Tensor<T>

The target/expected output tensor.

lossFunction ILossFunction<T>

The loss function to use. If null, uses the model's default loss function.

Returns

Vector<T>

A vector containing the computed gradients. For SuperNet this covers the weight parameters only; architecture parameters are handled separately, as described in the Remarks.

Remarks

For SuperNet, this computes gradients for weight parameters only (not architecture parameters). Architecture parameters are updated separately in DARTS using validation data. The method uses the existing BackwardWeights method and collects gradients from all layers.

For Beginners: SuperNet has two types of parameters:

  • Architecture parameters (α): which operations to use
  • Weight parameters (w): the actual neural network weights

This method computes gradients for the weight parameters based on training data. In DARTS, architecture parameters are optimized separately on validation data.

Exceptions

ArgumentNullException

If input or target is null.

ComputeTrainingLoss(Tensor<T>, Tensor<T>)

Computes the training loss used for weight updates.

public T ComputeTrainingLoss(Tensor<T> trainData, Tensor<T> trainLabels)

Parameters

trainData Tensor<T>
trainLabels Tensor<T>

Returns

T

ComputeValidationLoss(Tensor<T>, Tensor<T>)

Computes the validation loss used for architecture parameter updates.

public T ComputeValidationLoss(Tensor<T> valData, Tensor<T> valLabels)

Parameters

valData Tensor<T>
valLabels Tensor<T>

Returns

T

ConfigureFairness(Vector<int>, params FairnessMetric[])

Configures fairness evaluation settings.

public virtual void ConfigureFairness(Vector<int> sensitiveFeatures, params FairnessMetric[] fairnessMetrics)

Parameters

sensitiveFeatures Vector<int>
fairnessMetrics FairnessMetric[]

DeepCopy()

Creates a deep copy of this object.

public IFullModel<T, Tensor<T>, Tensor<T>> DeepCopy()

Returns

IFullModel<T, Tensor<T>, Tensor<T>>

DeriveArchitecture()

Derives a discrete architecture from the continuous architecture parameters by argmax selection.

public Architecture<T> DeriveArchitecture()

Returns

Architecture<T>
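
Argmax selection can be sketched as picking, for each row of architecture parameters, the index of the largest value. The sketch below assumes one row of alpha values per node and uses jagged double arrays; it does not construct the library's Architecture<T>, and DeriveSketch is a hypothetical name.

```csharp
using System;

// Hypothetical helper, not part of AiDotNet: argmax over each row of alpha.
public static class DeriveSketch
{
    // For each row, returns the index of the operation with the largest alpha.
    // Ties resolve to the lowest index because the comparison is strict.
    public static int[] ArgmaxPerRow(double[][] alpha)
    {
        var chosen = new int[alpha.Length];
        for (int r = 0; r < alpha.Length; r++)
        {
            if (alpha[r].Length == 0)
                throw new ArgumentException("Each row needs at least one operation.");

            int best = 0;
            for (int k = 1; k < alpha[r].Length; k++)
            {
                if (alpha[r][k] > alpha[r][best])
                    best = k;
            }
            chosen[r] = best;
        }
        return chosen;
    }
}
```

Softmax is monotonic, so taking the argmax of the raw alpha values selects the same operation as taking the argmax of the softmax weights.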

Deserialize(byte[])

Loads a previously serialized model from binary data.

public void Deserialize(byte[] data)

Parameters

data byte[]

The byte array containing the serialized model data.

Remarks

This method takes binary data created by the Serialize method and uses it to restore a model to its previous state.

For Beginners: This is like opening a saved file to continue your work.

When you call this method:

  • You provide the binary data (bytes) that was previously created by Serialize
  • The model rebuilds itself using this data
  • After deserializing, the model is exactly as it was when serialized
  • It's ready to make predictions without needing to be trained again

For example:

  • You download a pre-trained model file for detecting spam emails
  • You deserialize this file into your application
  • Immediately, your application can detect spam without any training
  • The model has all the knowledge that was built into it by its original creator

This is particularly useful when:

  • You want to use a model that took days to train
  • You need to deploy the same model across multiple devices
  • You're creating an application that non-technical users will use

Think of it like installing the brain of a trained expert directly into your application.

EnableMethod(params InterpretationMethod[])

Enables specific interpretation methods.

public virtual void EnableMethod(params InterpretationMethod[] methods)

Parameters

methods InterpretationMethod[]

ExportComputationGraph(List<ComputationNode<T>>)

Exports the model's computation graph for JIT compilation.

public ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)

Parameters

inputNodes List<ComputationNode<T>>

List to populate with input computation nodes (parameters).

Returns

ComputationNode<T>

The output computation node representing the SuperNet forward pass.

Remarks

Exports the DARTS continuous relaxation as a computation graph. The graph includes:

  • Input tensor variable
  • Architecture parameters embedded as constants
  • Softmax computation over architecture parameters
  • All operation outputs
  • Weighted sum using softmax weights

For Beginners: This exports the current state of the SuperNet as a JIT-compilable graph. The architecture parameters (alpha values) are baked into the graph as constants, so the exported graph represents the current "snapshot" of the architecture search.

You can export at different points during training to capture the evolving architecture, or export after search completes to get the final continuous relaxation.
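
The "snapshot" behaviour can be illustrated with a closure that freezes the softmax weights at export time. This is a conceptual sketch only: it returns a Func<double, double> rather than a ComputationNode<T>, and SnapshotSketch is a hypothetical name.

```csharp
using System;
using System.Linq;

// Hypothetical helper, not part of AiDotNet: freezes mixing weights at export time.
public static class SnapshotSketch
{
    // Returns a function whose mixing weights are computed once, from the alpha
    // values at the moment of the call, and never read from alpha again.
    public static Func<double, double> Export(double[] alpha, Func<double, double>[] ops)
    {
        if (alpha.Length != ops.Length)
            throw new ArgumentException("Need exactly one alpha per operation.");

        double max = alpha.Max();
        double[] exps = alpha.Select(a => Math.Exp(a - max)).ToArray();
        double sum = exps.Sum();
        double[] frozen = exps.Select(e => e / sum).ToArray(); // independent copy
        var frozenOps = (Func<double, double>[])ops.Clone();

        return x =>
        {
            double y = 0.0;
            for (int k = 0; k < frozenOps.Length; k++)
                y += frozen[k] * frozenOps[k](x);
            return y;
        };
    }
}
```

Changing alpha after the call does not affect the returned function, which is why a fresh export is needed to capture later search progress.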

Exceptions

InvalidOperationException

Thrown if called before any forward pass has initialized the weights.

Forward(Tensor<T>)

Performs a forward pass through the model (required by IJitCompilable).

public Tensor<T> Forward(Tensor<T> input)

Parameters

input Tensor<T>

The input tensor.

Returns

Tensor<T>

The output tensor.

GenerateTextExplanationAsync(Tensor<T>, Tensor<T>)

Generates a text explanation for a prediction. Provides a description of which operations are most important in the SuperNet.

public virtual Task<string> GenerateTextExplanationAsync(Tensor<T> input, Tensor<T> prediction)

Parameters

input Tensor<T>
prediction Tensor<T>

Returns

Task<string>

GetActiveFeatureIndices()

Gets the indices of features that are actively used by this model.

public IEnumerable<int> GetActiveFeatureIndices()

Returns

IEnumerable<int>

GetAnchorExplanationAsync(Tensor<T>, T)

Gets anchor explanation for a given input. Not supported for SuperNet architecture search models.

public virtual Task<AnchorExplanation<T>> GetAnchorExplanationAsync(Tensor<T> input, T threshold)

Parameters

input Tensor<T>
threshold T

Returns

Task<AnchorExplanation<T>>

GetArchitectureGradients()

Gets the architecture gradients.

public List<Matrix<T>> GetArchitectureGradients()

Returns

List<Matrix<T>>

GetArchitectureParameters()

Gets the architecture parameters for optimization.

public List<Matrix<T>> GetArchitectureParameters()

Returns

List<Matrix<T>>

GetCounterfactualAsync(Tensor<T>, Tensor<T>, int)

Gets counterfactual explanation for a given input and desired output. Not supported for SuperNet architecture search models.

public virtual Task<CounterfactualExplanation<T>> GetCounterfactualAsync(Tensor<T> input, Tensor<T> desiredOutput, int maxChanges = 5)

Parameters

input Tensor<T>
desiredOutput Tensor<T>
maxChanges int

Returns

Task<CounterfactualExplanation<T>>

GetFeatureImportance()

Gets the feature importance scores.

public Dictionary<string, T> GetFeatureImportance()

Returns

Dictionary<string, T>

GetFeatureInteractionAsync(int, int)

Gets feature interaction effects between two features. Analyzes interactions between operations based on architecture parameter correlations.

public virtual Task<T> GetFeatureInteractionAsync(int feature1Index, int feature2Index)

Parameters

feature1Index int
feature2Index int

Returns

Task<T>

GetGlobalFeatureImportanceAsync(Tensor<T>)

Gets the operation importance for SuperNet architecture search. Returns importance scores for architectural operations rather than input features.

public virtual Task<Dictionary<int, T>> GetGlobalFeatureImportanceAsync(Tensor<T> inputs)

Parameters

inputs Tensor<T>

Input tensor (required for interface compliance; not used in this implementation)

Returns

Task<Dictionary<int, T>>

Dictionary mapping operation indices to their importance scores

Remarks

Note: SuperNet reinterprets "feature importance" as "operation importance" in the context of Neural Architecture Search (NAS). The returned dictionary maps operation indices (0=identity, 1=conv3x3, 2=conv5x5, etc.) to their importance scores, calculated by aggregating the absolute values of architecture parameters across all nodes.

The 'inputs' parameter is required for IInterpretableModel interface compliance but is not used. SuperNet analyzes operation importance based on learned architecture parameters rather than input data.
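
The aggregation described in the remarks can be sketched as below. The source only says the absolute values are "aggregated" across nodes; summing them is an assumption made for this example (a mean would be equally plausible). The sketch uses jagged double arrays, and OperationImportanceSketch is a hypothetical name.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper, not part of AiDotNet: per-operation importance from alpha.
public static class OperationImportanceSketch
{
    // importance[k] = sum over nodes of |alpha[node][k]|.
    // Assumption: "aggregating" is taken to mean summation.
    public static Dictionary<int, double> Compute(double[][] alpha)
    {
        var importance = new Dictionary<int, double>();
        foreach (double[] row in alpha)
        {
            for (int k = 0; k < row.Length; k++)
            {
                importance.TryGetValue(k, out double current); // 0.0 when absent
                importance[k] = current + Math.Abs(row[k]);
            }
        }
        return importance;
    }
}
```

The result depends only on the alpha values, which matches the note above that the input tensor is not used.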

GetLimeExplanationAsync(Tensor<T>, int)

Gets LIME explanation for a specific input. Not supported for SuperNet architecture search models.

public virtual Task<LimeExplanation<T>> GetLimeExplanationAsync(Tensor<T> input, int numFeatures = 10)

Parameters

input Tensor<T>
numFeatures int

Returns

Task<LimeExplanation<T>>

GetLocalFeatureImportanceAsync(Tensor<T>)

Gets the local feature importance for a specific input. Provides importance based on softmax weights, analyzing which operations are most active.

public virtual Task<Dictionary<int, T>> GetLocalFeatureImportanceAsync(Tensor<T> input)

Parameters

input Tensor<T>

Returns

Task<Dictionary<int, T>>

GetModelMetadata()

Retrieves metadata and performance metrics about the trained model.

public ModelMetadata<T> GetModelMetadata()

Returns

ModelMetadata<T>

An object containing metadata and performance metrics about the trained model.

Remarks

This method provides information about the model's structure, parameters, and performance metrics.

For Beginners: Model metadata is like a report card for your machine learning model.

Just as a report card shows how well a student is performing in different subjects, model metadata shows how well your model is performing and provides details about its structure.

This information typically includes:

  • Accuracy measures: How well do the model's predictions match actual values?
  • Error metrics: How far off are the model's predictions on average?
  • Model parameters: What patterns did the model learn from the data?
  • Training information: How long did training take? How many iterations were needed?

For example, in a house price prediction model, metadata might include:

  • Average prediction error (e.g., off by $15,000 on average)
  • How strongly each feature (bedrooms, location) influences the prediction
  • How well the model fits the training data

This information helps you understand your model's strengths and weaknesses, and decide if it's ready to use or needs more training.

GetModelSpecificInterpretabilityAsync()

Gets model-specific interpretability information for SuperNet. Returns architecture parameters and their importance.

public virtual Task<Dictionary<string, object>> GetModelSpecificInterpretabilityAsync()

Returns

Task<Dictionary<string, object>>

GetParameters()

Gets the parameters that can be optimized.

public Vector<T> GetParameters()

Returns

Vector<T>

GetPartialDependenceAsync(Vector<int>, int)

Gets partial dependence data for specified features. Not supported for SuperNet architecture search models.

public virtual Task<PartialDependenceData<T>> GetPartialDependenceAsync(Vector<int> featureIndices, int gridResolution = 20)

Parameters

featureIndices Vector<int>
gridResolution int

Returns

Task<PartialDependenceData<T>>

GetShapValuesAsync(Tensor<T>)

Gets SHAP values for the given inputs. Not supported for SuperNet architecture search models.

public virtual Task<Matrix<T>> GetShapValuesAsync(Tensor<T> inputs)

Parameters

inputs Tensor<T>

Returns

Task<Matrix<T>>

GetWeightGradients()

Gets the weight gradients.

public Dictionary<string, Vector<T>> GetWeightGradients()

Returns

Dictionary<string, Vector<T>>

GetWeightParameters()

Gets the weight parameters for optimization.

public Dictionary<string, Vector<T>> GetWeightParameters()

Returns

Dictionary<string, Vector<T>>

IsFeatureUsed(int)

Checks if a specific feature is used by this model.

public bool IsFeatureUsed(int featureIndex)

Parameters

featureIndex int

Returns

bool

LoadModel(string)

Loads the model from a file.

public void LoadModel(string filePath)

Parameters

filePath string

The path to the file containing the saved model.

Remarks

This method provides a convenient way to load a model directly from disk. It combines file I/O operations with deserialization.

For Beginners: This is like clicking "Open" in a document editor. Instead of manually reading from a file and then calling Deserialize(), this method does both steps for you.

Exceptions

FileNotFoundException

Thrown when the specified file does not exist.

IOException

Thrown when an I/O error occurs while reading from the file or when the file contains corrupted or invalid model data.

LoadState(Stream)

Loads the SuperNet's state (architecture parameters and weights) from a stream.

public virtual void LoadState(Stream stream)

Parameters

stream Stream

The stream to read the model state from.

Remarks

This method deserializes SuperNet state that was previously saved with SaveState, restoring all architecture parameters, operation weights, and configuration. It uses the existing Deserialize method after reading data from the stream.

For Beginners: This is like loading a saved snapshot of your neural architecture search model.

When you call LoadState:

  • All architecture parameters (alpha values) are read from the stream
  • All operation weights are restored
  • The model is configured to match the saved state

After loading, the model can:

  • Continue architecture search from where it left off
  • Make predictions using the restored architecture
  • Be used for further optimization or deployment

This is essential for:

  • Resuming interrupted architecture search
  • Loading the best architecture found during search
  • Deploying searched architectures to production
  • Knowledge distillation workflows

Exceptions

ArgumentNullException

Thrown when stream is null.

IOException

Thrown when there's an error reading from the stream.

InvalidOperationException

Thrown when the stream contains invalid or incompatible data.

Predict(Tensor<T>)

Performs a forward pass through the SuperNet with mixed operations.

public Tensor<T> Predict(Tensor<T> input)

Parameters

input Tensor<T>

Returns

Tensor<T>

SaveModel(string)

Saves the model to a file.

public void SaveModel(string filePath)

Parameters

filePath string

The path where the model should be saved.

Remarks

This method provides a convenient way to save the model directly to disk. It combines serialization with file I/O operations.

For Beginners: This is like clicking "Save As" in a document editor. Instead of manually calling Serialize() and then writing to a file, this method does both steps for you.

Exceptions

IOException

Thrown when an I/O error occurs while writing to the file.

UnauthorizedAccessException

Thrown when the caller does not have the required permission to write to the specified file path.

SaveState(Stream)

Saves the SuperNet's current state (architecture parameters and weights) to a stream.

public virtual void SaveState(Stream stream)

Parameters

stream Stream

The stream to write the model state to.

Remarks

This method serializes all the information needed to recreate the SuperNet's current state, including architecture parameters, operation weights, and model configuration. It uses the existing Serialize method and writes the data to the provided stream.

For Beginners: This is like creating a snapshot of your neural architecture search model.

When you call SaveState:

  • All architecture parameters (alpha values) are written to the stream
  • All operation weights are saved
  • The model's configuration and structure are preserved

This is particularly useful for:

  • Checkpointing during neural architecture search
  • Saving the best architecture found during search
  • Knowledge distillation from SuperNet to final architecture
  • Resuming interrupted architecture search

You can later use LoadState to restore the model to this exact state.
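
The save-then-restore pattern can be sketched with a stream round trip. The binary layout here (a count followed by doubles) is invented for the example and is not the format SuperNet writes; StateStreamSketch is a hypothetical name. The streams are left open so the caller keeps ownership of them.

```csharp
using System;
using System.IO;
using System.Text;

// Hypothetical helper, not part of AiDotNet: writes and reads a flat parameter list.
public static class StateStreamSketch
{
    // Writes the element count followed by each value.
    public static void Save(Stream stream, double[] values)
    {
        if (stream is null) throw new ArgumentNullException(nameof(stream));
        if (values is null) throw new ArgumentNullException(nameof(values));

        using var writer = new BinaryWriter(stream, Encoding.UTF8, leaveOpen: true);
        writer.Write(values.Length);
        foreach (double v in values)
            writer.Write(v);
    }

    // Reads back what Save wrote; a truncated stream surfaces as an IOException
    // (EndOfStreamException) from the reader.
    public static double[] Load(Stream stream)
    {
        if (stream is null) throw new ArgumentNullException(nameof(stream));

        using var reader = new BinaryReader(stream, Encoding.UTF8, leaveOpen: true);
        int count = reader.ReadInt32();
        if (count < 0)
            throw new InvalidOperationException("Stream does not contain a valid parameter count.");

        var values = new double[count];
        for (int i = 0; i < count; i++)
            values[i] = reader.ReadDouble();
        return values;
    }
}
```

When using a MemoryStream for checkpointing, remember to reset Position to 0 before reading the data back.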

Exceptions

ArgumentNullException

Thrown when stream is null.

IOException

Thrown when there's an error writing to the stream.

Serialize()

Converts the current state of a machine learning model into a binary format.

public byte[] Serialize()

Returns

byte[]

A byte array containing the serialized model data.

Remarks

This method captures all the essential information about a trained model and converts it into a sequence of bytes that can be stored or transmitted.

For Beginners: This is like exporting your work to a file.

When you call this method:

  • The model's current state (all its learned patterns and parameters) is captured
  • This information is converted into a compact binary format (bytes)
  • You can then save these bytes to a file, database, or send them over a network

For example:

  • After training a model to recognize cats vs. dogs in images
  • You can serialize the model to save all its learned knowledge
  • Later, you can use this saved data to recreate the model exactly as it was
  • The recreated model will make the same predictions as the original

Think of it like taking a snapshot of your model's brain at a specific moment in time.

SetActiveFeatureIndices(IEnumerable<int>)

Sets the active feature indices for this model.

public void SetActiveFeatureIndices(IEnumerable<int> featureIndices)

Parameters

featureIndices IEnumerable<int>

SetBaseModel(IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>)

Sets the base model for interpretability analysis.

public virtual void SetBaseModel(IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>> model)

Parameters

model IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>

SetParameters(Vector<T>)

Sets the model parameters.

public void SetParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

The parameter vector to set.

Remarks

This method allows direct modification of the model's internal parameters. This is useful for optimization algorithms that need to update parameters iteratively. If the length of parameters does not match ParameterCount, an ArgumentException is thrown.

Exceptions

ArgumentException

Thrown when the length of parameters does not match ParameterCount.

Train(Tensor<T>, Tensor<T>)

Training is handled externally by alternating architecture and weight updates.

public void Train(Tensor<T> input, Tensor<T> expectedOutput)

Parameters

input Tensor<T>
expectedOutput Tensor<T>
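
The alternating scheme can be sketched on a toy problem: each step first updates the architecture parameters on validation data, then updates the weight on training data. This is a first-order illustration under strong assumptions: a single scalar weight, two candidate operations (zero and identity), and central finite differences standing in for the backward passes. It does not call the SuperNet API, and AlternatingSearchSketch is a hypothetical name.

```csharp
using System;

// Hypothetical helper, not part of AiDotNet: toy alternating architecture search.
public static class AlternatingSearchSketch
{
    // Toy super-network with two candidate operations, zero(x) = 0 and identity(x) = x:
    // y = w * (s0 * 0 + s1 * x), where (s0, s1) = softmax(alpha).
    public static double Predict(double x, double w, double[] alpha)
    {
        double max = Math.Max(alpha[0], alpha[1]);
        double e0 = Math.Exp(alpha[0] - max);
        double e1 = Math.Exp(alpha[1] - max);
        double s0 = e0 / (e0 + e1);
        double s1 = e1 / (e0 + e1);
        return w * (s0 * 0.0 + s1 * x);
    }

    // Mean squared error of the toy network on one data split.
    public static double Loss(double[] xs, double[] ys, double w, double[] alpha)
    {
        double sum = 0.0;
        for (int i = 0; i < xs.Length; i++)
        {
            double diff = Predict(xs[i], w, alpha) - ys[i];
            sum += diff * diff;
        }
        return sum / xs.Length;
    }

    // Alternates an alpha step on validation loss with a weight step on training loss.
    public static (double W, double[] Alpha) Search(
        double[] trainX, double[] trainY, double[] valX, double[] valY,
        int steps, double weightLr, double archLr)
    {
        const double h = 1e-5; // finite-difference step
        double w = 1.0;
        double[] alpha = { 0.0, 0.0 };

        for (int step = 0; step < steps; step++)
        {
            // Architecture step: gradient of the validation loss w.r.t. alpha.
            var alphaGrad = new double[alpha.Length];
            for (int k = 0; k < alpha.Length; k++)
            {
                double saved = alpha[k];
                alpha[k] = saved + h;
                double up = Loss(valX, valY, w, alpha);
                alpha[k] = saved - h;
                double down = Loss(valX, valY, w, alpha);
                alpha[k] = saved;
                alphaGrad[k] = (up - down) / (2 * h);
            }
            for (int k = 0; k < alpha.Length; k++)
                alpha[k] -= archLr * alphaGrad[k];

            // Weight step: gradient of the training loss w.r.t. w.
            double wGrad = (Loss(trainX, trainY, w + h, alpha)
                          - Loss(trainX, trainY, w - h, alpha)) / (2 * h);
            w -= weightLr * wGrad;
        }
        return (w, alpha);
    }
}
```

With targets generated by y = 2x, the identity operation is the useful one, so the search is expected to raise its alpha above the zero operation's. In a real workflow the two finite-difference blocks would correspond to the architecture and weight backward passes, with their gradients applied at separate learning rates.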

ValidateFairnessAsync(Tensor<T>, int)

Validates fairness metrics for the given inputs. Not supported for SuperNet architecture search models.

public virtual Task<FairnessMetrics<T>> ValidateFairnessAsync(Tensor<T> inputs, int sensitiveFeatureIndex)

Parameters

inputs Tensor<T>
sensitiveFeatureIndex int

Returns

Task<FairnessMetrics<T>>

WithParameters(Vector<T>)

Creates a new instance with the specified parameters.

public IFullModel<T, Tensor<T>, Tensor<T>> WithParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

Returns

IFullModel<T, Tensor<T>, Tensor<T>>