Class BigNAS<T>
BigNAS: Scaling Up Neural Architecture Search with Big Single-Stage Models. Combines sandwich sampling with in-place knowledge distillation to train a single large super-network whose sub-networks can be deployed directly under different resource constraints, without extra retraining or post-processing.
Reference: "BigNAS: Scaling Up Neural Architecture Search with Big Single-Stage Models"
public class BigNAS<T> : NasAutoMLModelBase<T>, IAutoMLModel<T, Tensor<T>, Tensor<T>>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>
Type Parameters
T: The numeric type used for calculations.
Inheritance
- NasAutoMLModelBase<T>
- BigNAS<T>
Inherited Members
- AutoMLModelBase<T, Tensor<T>, Tensor<T>>.SetModelEvaluator(IModelEvaluator<T, Tensor<T>, Tensor<T>>)
Constructors
BigNAS(SearchSpaceBase<T>, List<int>?, List<double>?, List<int>?, List<int>?, List<int>?, bool, double)
public BigNAS(SearchSpaceBase<T> searchSpace, List<int>? elasticDepths = null, List<double>? elasticWidthMultipliers = null, List<int>? elasticKernelSizes = null, List<int>? elasticExpansionRatios = null, List<int>? elasticResolutions = null, bool useSandwichSampling = true, double distillationWeight = 0.5)
Parameters
- searchSpace: SearchSpaceBase<T>
- elasticDepths: List<int>
- elasticWidthMultipliers: List<double>
- elasticKernelSizes: List<int>
- elasticExpansionRatios: List<int>
- elasticResolutions: List<int>
- useSandwichSampling: bool
- distillationWeight: double
Properties
NasNumNodes
Gets the number of nodes to search over.
protected override int NasNumNodes { get; }
Property Value
- int
NasSearchSpace
Gets the NAS search space.
protected override SearchSpaceBase<T> NasSearchSpace { get; }
Property Value
- SearchSpaceBase<T>
NumOps
Gets the numeric operations provider for T.
protected override INumericOperations<T> NumOps { get; }
Property Value
- INumericOperations<T>
Methods
ComputeDistillationLoss(Vector<T>, Vector<T>, T)
Computes the knowledge distillation loss between the teacher and student networks.
public T ComputeDistillationLoss(Vector<T> teacherLogits, Vector<T> studentLogits, T temperature)
Parameters
- teacherLogits: Vector<T>
- studentLogits: Vector<T>
- temperature: T
Returns
- T
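The exact loss computed by `ComputeDistillationLoss` is not documented here. The sketch below assumes the common temperature-scaled formulation from Hinton et al. (2015): the KL divergence between the teacher's and the student's softmax distributions at temperature T, multiplied by T² so gradient magnitudes stay comparable across temperatures. It uses plain `double[]` arrays rather than the library's `Vector<T>`, and the class name `DistillationSketch` is hypothetical.

```csharp
using System;
using System.Linq;

public static class DistillationSketch
{
    // Log-softmax of logits / temperature, computed stably via log-sum-exp.
    public static double[] LogSoftmax(double[] logits, double temperature)
    {
        double[] scaled = logits.Select(z => z / temperature).ToArray();
        double max = scaled.Max();
        double logSumExp = max + Math.Log(scaled.Sum(z => Math.Exp(z - max)));
        return scaled.Select(z => z - logSumExp).ToArray();
    }

    // Assumed formulation: T^2 * KL(softmax(teacher / T) || softmax(student / T)).
    public static double DistillationLoss(double[] teacherLogits, double[] studentLogits, double temperature)
    {
        if (teacherLogits.Length != studentLogits.Length)
            throw new ArgumentException("Teacher and student logits must have the same length.");
        if (temperature <= 0)
            throw new ArgumentOutOfRangeException(nameof(temperature), "Temperature must be positive.");

        double[] logP = LogSoftmax(teacherLogits, temperature); // teacher: soft targets
        double[] logQ = LogSoftmax(studentLogits, temperature); // student: predictions
        double kl = 0.0;
        for (int i = 0; i < logP.Length; i++)
            kl += Math.Exp(logP[i]) * (logP[i] - logQ[i]);
        return kl * temperature * temperature;
    }
}
```

Dividing the logits by a temperature above 1 flattens both distributions, so the student also learns from the teacher's relative scores for non-target classes.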
CreateInstanceForCopy()
Factory method for creating a new instance for deep copy. Derived classes must implement this to return a new instance of themselves. This ensures each copy has its own collections and lock object.
protected override AutoMLModelBase<T, Tensor<T>, Tensor<T>> CreateInstanceForCopy()
Returns
- AutoMLModelBase<T, Tensor<T>, Tensor<T>>
A fresh instance of the derived class with default parameters
Remarks
When implementing this method, derived classes should create a fresh instance with default parameters, and should not attempt to preserve runtime or initialization state from the original instance. The deep copy logic will transfer relevant state (trial history, search space, etc.) after construction.
MultiObjectiveSearch(List<(string name, HardwareConstraints<T> constraints)>, int, int, int, int)
Searches for optimal sub-networks for multiple hardware constraints simultaneously.
public Dictionary<string, BigNASConfig> MultiObjectiveSearch(List<(string name, HardwareConstraints<T> constraints)> targetDevices, int inputChannels, int spatialSize, int populationSize = 100, int generations = 50)
Parameters
- targetDevices: List<(string name, HardwareConstraints<T> constraints)>
- inputChannels: int
- spatialSize: int
- populationSize: int
- generations: int
Returns
- Dictionary<string, BigNASConfig>
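The search procedure behind `MultiObjectiveSearch` is not described beyond its parameters. As a rough illustration, assuming a simple evolutionary loop run once per target device, the sketch below keeps the fitter half of a population each generation, refills it with mutated survivors, and returns one configuration per device name. `SubNetConfig`, the cost proxy `Cost`, the accuracy proxy `Score`, and the scalar budget per device are hypothetical stand-ins for `BigNASConfig`, `HardwareConstraints<T>`, and real accuracy and latency estimates.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical sub-network description; not the library's BigNASConfig.
public record SubNetConfig(int Depth, double Width, int Kernel);

public static class MultiDeviceSearchSketch
{
    static readonly int[] Depths = { 2, 3, 4 };
    static readonly double[] Widths = { 0.5, 0.75, 1.0 };
    static readonly int[] Kernels = { 3, 5, 7 };

    // Hypothetical cost proxy standing in for a latency or FLOPs estimate.
    public static double Cost(SubNetConfig c) => c.Depth * c.Width * c.Width * c.Kernel * c.Kernel;

    // Hypothetical accuracy proxy: larger sub-networks score higher.
    static double Score(SubNetConfig c) => Math.Log(1 + c.Depth * c.Width * c.Kernel);

    static SubNetConfig RandomConfig(Random rng) => new SubNetConfig(
        Depths[rng.Next(Depths.Length)], Widths[rng.Next(Widths.Length)], Kernels[rng.Next(Kernels.Length)]);

    // Mutation: re-sample one randomly chosen dimension.
    static SubNetConfig Mutate(SubNetConfig c, Random rng) => rng.Next(3) switch
    {
        0 => c with { Depth = Depths[rng.Next(Depths.Length)] },
        1 => c with { Width = Widths[rng.Next(Widths.Length)] },
        _ => c with { Kernel = Kernels[rng.Next(Kernels.Length)] },
    };

    // Configurations over budget are rejected with fitness -infinity.
    static double Fitness(SubNetConfig c, double budget) =>
        Cost(c) <= budget ? Score(c) : double.NegativeInfinity;

    public static Dictionary<string, SubNetConfig> Search(
        List<(string name, double maxCost)> devices, int populationSize = 20, int generations = 10, int seed = 0)
    {
        var rng = new Random(seed);
        var result = new Dictionary<string, SubNetConfig>();
        foreach (var (name, maxCost) in devices)
        {
            var population = Enumerable.Range(0, populationSize).Select(_ => RandomConfig(rng)).ToList();
            for (int g = 0; g < generations; g++)
            {
                // Keep the fitter half, then refill with mutated survivors.
                var parents = population.OrderByDescending(c => Fitness(c, maxCost))
                    .Take(Math.Max(1, populationSize / 2)).ToList();
                population = parents.Concat(Enumerable.Range(0, populationSize - parents.Count)
                    .Select(i => Mutate(parents[i % parents.Count], rng))).ToList();
            }
            var best = population.OrderByDescending(c => Fitness(c, maxCost)).First();
            // Fall back to the smallest configuration if no candidate fits the budget.
            result[name] = Fitness(best, maxCost) > double.NegativeInfinity
                ? best
                : new SubNetConfig(Depths.Min(), Widths.Min(), Kernels.Min());
        }
        return result;
    }
}
```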
SandwichSample()
Sandwich sampling: samples the smallest, the largest, and random sub-networks together, which improves training efficiency and the performance of all sub-networks.
public List<BigNASConfig> SandwichSample()
Returns
- List<BigNASConfig>
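A minimal sketch of sandwich sampling over the five elastic dimensions accepted by the constructor. It assumes each sample takes a single value per dimension (the real `BigNASConfig` may vary depth and width per stage instead), and it returns the largest sub-network first, followed by the smallest and `numRandom` random ones. `ElasticConfig`, `SandwichSketch`, and `numRandom` are hypothetical names.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical stand-in for BigNASConfig: one value per elastic dimension.
public record ElasticConfig(int Depth, double WidthMultiplier, int KernelSize, int ExpansionRatio, int Resolution);

public static class SandwichSketch
{
    public static List<ElasticConfig> SandwichSample(
        List<int> depths, List<double> widths, List<int> kernels,
        List<int> expansions, List<int> resolutions,
        int numRandom, Random rng)
    {
        // Uniformly pick one option from a list of elastic choices.
        T Pick<T>(List<T> options) => options[rng.Next(options.Count)];

        var samples = new List<ElasticConfig>
        {
            // Largest sub-network: the maximum of every dimension.
            new ElasticConfig(depths.Max(), widths.Max(), kernels.Max(), expansions.Max(), resolutions.Max()),
            // Smallest sub-network: the minimum of every dimension.
            new ElasticConfig(depths.Min(), widths.Min(), kernels.Min(), expansions.Min(), resolutions.Min()),
        };

        // Random sub-networks, each dimension drawn independently.
        for (int i = 0; i < numRandom; i++)
            samples.Add(new ElasticConfig(Pick(depths), Pick(widths), Pick(kernels), Pick(expansions), Pick(resolutions)));

        return samples;
    }
}
```

Placing the largest sub-network first suits in-place distillation, where its predictions serve as soft targets for the smaller samples drawn in the same step.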
SearchArchitecture(Tensor<T>, Tensor<T>, Tensor<T>, Tensor<T>, TimeSpan, CancellationToken)
Performs algorithm-specific architecture search.
protected override Architecture<T> SearchArchitecture(Tensor<T> inputs, Tensor<T> targets, Tensor<T> validationInputs, Tensor<T> validationTargets, TimeSpan timeLimit, CancellationToken cancellationToken)
Parameters
- inputs: Tensor<T>
- targets: Tensor<T>
- validationInputs: Tensor<T>
- validationTargets: Tensor<T>
- timeLimit: TimeSpan
- cancellationToken: CancellationToken
Returns
- Architecture<T>