Class ProxylessNAS<T>
ProxylessNAS: Direct Neural Architecture Search on Target Task and Hardware. Uses path binarization and a latency-aware loss to search directly on the target task and target device. It needs no proxy task (such as a smaller dataset or a shortened training schedule) and no separate hardware lookup table.
Reference: "ProxylessNAS: Direct Neural Architecture Search on Target Task and Hardware" (ICLR 2019)
public class ProxylessNAS<T> : NasAutoMLModelBase<T>, IAutoMLModel<T, Tensor<T>, Tensor<T>>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>
Type Parameters
- T: The numeric type used for calculations.
- Inheritance
  - NasAutoMLModelBase<T>
  - ProxylessNAS<T>
- Inherited Members
  - AutoMLModelBase<T, Tensor<T>, Tensor<T>>.SetModelEvaluator(IModelEvaluator<T, Tensor<T>, Tensor<T>>)
Constructors
ProxylessNAS(SearchSpaceBase<T>, int, HardwarePlatform, double, bool)
public ProxylessNAS(SearchSpaceBase<T> searchSpace, int numNodes = 4, HardwarePlatform targetPlatform = HardwarePlatform.Mobile, double latencyWeight = 0.1, bool useBinarization = true)
Parameters
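- searchSpace (SearchSpaceBase<T>): the search space of candidate operations.
- numNodes (int, default 4): the number of nodes to search over.
- targetPlatform (HardwarePlatform, default HardwarePlatform.Mobile): the hardware platform whose latency the search optimizes for.
- latencyWeight (double, default 0.1): the weight of the latency term in the total loss.
- useBinarization (bool, default true): whether to use path binarization during the search.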
Properties
NasNumNodes
Gets the number of nodes to search over.
protected override int NasNumNodes { get; }
Property Value
- int
NasSearchSpace
Gets the NAS search space.
protected override SearchSpaceBase<T> NasSearchSpace { get; }
Property Value
- SearchSpaceBase<T>
NumOps
Gets the numeric operations provider for T.
protected override INumericOperations<T> NumOps { get; }
Property Value
- INumericOperations<T>
Methods
BinarizePaths(Matrix<T>)
Applies path binarization for memory-efficient single-path sampling. During a forward pass only one candidate operation is active at a time, so only that path's activations need to be kept in memory.
public Matrix<T> BinarizePaths(Matrix<T> alpha)
Parameters
- alpha (Matrix<T>): the architecture parameters (path weights) to binarize.
Returns
- Matrix<T>
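Here is a minimal standalone sketch in C# of ProxylessNAS-style path binarization as the paper describes it. A softmax turns each edge's weights into probabilities, and exactly one operation is then sampled as the active path. This is not this library's implementation. It uses double instead of T and plain arrays instead of Matrix<T>. It also assumes, hypothetically, that alpha stores one row per edge and one column per candidate operation. The class and method names are illustrative.

```csharp
using System;

// Hypothetical sketch of ProxylessNAS path binarization; not the library's BinarizePaths.
// Assumes alpha[i, j] is the weight of candidate operation j on edge i.
public static class PathBinarizationSketch
{
    // Softmax over row `row` of alpha, turning weights into path probabilities.
    public static double[] RowSoftmax(double[,] alpha, int row)
    {
        int numOps = alpha.GetLength(1);
        double max = double.NegativeInfinity;
        for (int j = 0; j < numOps; j++) max = Math.Max(max, alpha[row, j]);

        var probs = new double[numOps];
        double sum = 0.0;
        for (int j = 0; j < numOps; j++)
        {
            probs[j] = Math.Exp(alpha[row, j] - max); // subtracting max avoids overflow
            sum += probs[j];
        }
        for (int j = 0; j < numOps; j++) probs[j] /= sum;
        return probs;
    }

    // Samples exactly one active operation per edge; returns 0/1 gates with one-hot rows.
    public static double[,] BinarizePaths(double[,] alpha, Random rng)
    {
        int numEdges = alpha.GetLength(0), numOps = alpha.GetLength(1);
        var gates = new double[numEdges, numOps];
        for (int i = 0; i < numEdges; i++)
        {
            double[] probs = RowSoftmax(alpha, i);
            double u = rng.NextDouble(), cumulative = 0.0;
            int chosen = numOps - 1; // fallback in case rounding leaves the cumulative sum just below 1
            for (int j = 0; j < numOps; j++)
            {
                cumulative += probs[j];
                if (u < cumulative) { chosen = j; break; }
            }
            gates[i, chosen] = 1.0;
        }
        return gates;
    }
}
```

Each row of the result is one-hot, so only the sampled operation on each edge runs in the forward pass. That is where the memory savings come from.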
ComputeExpectedLatency(int, int)
Computes the expected latency of the architecture under the current architecture parameters, for a given input channel count and spatial size.
public T ComputeExpectedLatency(int inputChannels, int spatialSize)
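Parameters
- inputChannels (int): the number of input channels.
- spatialSize (int): the spatial size of the input feature map.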
Returns
- T
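The paper defines the expected latency of a block as a probability-weighted sum over its candidate operations. Summed over edges, this gives E[LAT] = sum over edges i and operations j of p_ij * F(o_ij), where p_ij is the probability of operation j on edge i and F(o_ij) is that operation's latency. Below is a minimal C# sketch of that formula. It is not this library's ComputeExpectedLatency. The per-operation latencies are an assumption here and are passed in by the caller, and all names are illustrative.

```csharp
using System;

// Hypothetical sketch of the paper's expected-latency formula; not the library's ComputeExpectedLatency.
public static class ExpectedLatencySketch
{
    // probs[i, j]: probability of operation j on edge i (each row sums to 1).
    // latency[i, j]: assumed latency of operation j on edge i, supplied by the caller
    // (e.g. from a profiler or a cost model; the units are up to the caller).
    public static double ExpectedLatency(double[,] probs, double[,] latency)
    {
        int numEdges = probs.GetLength(0), numOps = probs.GetLength(1);
        if (latency.GetLength(0) != numEdges || latency.GetLength(1) != numOps)
            throw new ArgumentException("probs and latency must have the same shape.");

        double total = 0.0;
        for (int i = 0; i < numEdges; i++)
            for (int j = 0; j < numOps; j++)
                total += probs[i, j] * latency[i, j]; // E[LAT_i] = sum_j p_ij * F(o_ij), summed over edges
        return total;
    }
}
```

The expectation is a weighted sum of the probabilities, so it is differentiable with respect to the architecture parameters. This is what allows latency to enter the loss as a regularizer.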
ComputeTotalLoss(T, int, int)
Computes the total loss by combining the task loss with a latency regularization term weighted by latencyWeight.
public T ComputeTotalLoss(T taskLoss, int inputChannels, int spatialSize)
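Parameters
- taskLoss (T): the task loss (for example, cross-entropy on the current batch).
- inputChannels (int): the number of input channels.
- spatialSize (int): the spatial size of the input feature map.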
Returns
- T
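The paper's additive form of the latency-regularized loss is Loss = Loss_task + λ1·||w||² + λ2·E[LAT]. Below is a minimal C# sketch that keeps only the task and latency terms. It assumes latencyWeight plays the role of λ2, but the exact formula ComputeTotalLoss uses is not documented on this page, so treat this as an illustration rather than the library's behavior. The class name is hypothetical.

```csharp
using System;

// Hypothetical sketch of the additive latency-regularized loss from the ProxylessNAS paper
// (weight decay omitted). Not the library's ComputeTotalLoss.
public static class LatencyLossSketch
{
    // taskLoss: e.g. cross-entropy; expectedLatency: E[LAT] for the current architecture weights;
    // latencyWeight: trade-off coefficient (assumed to correspond to the constructor's latencyWeight).
    public static double TotalLoss(double taskLoss, double expectedLatency, double latencyWeight)
    {
        if (latencyWeight < 0.0) throw new ArgumentOutOfRangeException(nameof(latencyWeight));
        return taskLoss + latencyWeight * expectedLatency;
    }
}
```

A larger weight pushes the search toward faster operations at some cost in accuracy. With a weight of 0, the search optimizes the task loss alone.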
CreateInstanceForCopy()
Factory method for creating a new instance for deep copy. Derived classes must implement this to return a new instance of themselves. This ensures each copy has its own collections and lock object.
protected override AutoMLModelBase<T, Tensor<T>, Tensor<T>> CreateInstanceForCopy()
Returns
- AutoMLModelBase<T, Tensor<T>, Tensor<T>>
A fresh instance of the derived class with default parameters
Remarks
When implementing this method, derived classes should create a fresh instance with default parameters, and should not attempt to preserve runtime or initialization state from the original instance. The deep copy logic will transfer relevant state (trial history, search space, etc.) after construction.
DeriveArchitecture()
Derives the final discrete architecture by keeping, for each choice point, the operation with the highest architecture weight.
public Architecture<T> DeriveArchitecture()
Returns
- Architecture<T>
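Softmax is monotonic, so picking the operation with the largest raw weight is the same as picking the one with the highest path probability. A minimal C# sketch of that selection step follows. It assumes, hypothetically, one row of alpha per edge. It only returns operation indices and does not build this library's Architecture<T>.

```csharp
using System;

// Hypothetical sketch: assumes alpha has one row per edge and one column per candidate operation.
public static class DeriveArchitectureSketch
{
    // Returns, for each edge, the index of the operation with the largest weight.
    public static int[] SelectOperations(double[,] alpha)
    {
        int numEdges = alpha.GetLength(0), numOps = alpha.GetLength(1);
        var chosen = new int[numEdges];
        for (int i = 0; i < numEdges; i++)
        {
            int best = 0;
            for (int j = 1; j < numOps; j++)
                if (alpha[i, j] > alpha[i, best]) best = j; // ties keep the lower index
            chosen[i] = best;
        }
        return chosen;
    }
}
```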
EstimateArchitectureCost(int, int)
Estimates the hardware cost of the final (derived) architecture.
public HardwareCost<T> EstimateArchitectureCost(int inputChannels, int spatialSize)
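Parameters
- inputChannels (int): the number of input channels.
- spatialSize (int): the spatial size of the input feature map.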
Returns
- HardwareCost<T>
GetArchitectureGradients()
Gets the gradients of the loss with respect to the architecture parameters.
public List<Matrix<T>> GetArchitectureGradients()
Returns
- List<Matrix<T>>
GetArchitectureParameters()
Gets the architecture parameters so that an optimizer can update them.
public List<Matrix<T>> GetArchitectureParameters()
Returns
- List<Matrix<T>>
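The paper estimates architecture gradients through the binary gates, in the style of BinaryConnect: dL/dα_i ≈ sum over j of dL/dg_j · p_j · (δ_ij − p_i), where g_j are the gates and p_j the softmax probabilities. Below is a minimal C# sketch of that estimate for one edge, followed by a plain gradient-descent step. The step stands in for whichever optimizer consumes GetArchitectureParameters and GetArchitectureGradients. It is not this library's code, and the names are hypothetical.

```csharp
using System;

// Hypothetical sketch of the ProxylessNAS gradient estimate for one edge, plus an SGD step.
public static class ArchitectureGradientSketch
{
    static double[] Softmax(double[] alpha)
    {
        double max = double.NegativeInfinity;
        foreach (double a in alpha) max = Math.Max(max, a);
        var p = new double[alpha.Length];
        double sum = 0.0;
        for (int j = 0; j < alpha.Length; j++) { p[j] = Math.Exp(alpha[j] - max); sum += p[j]; }
        for (int j = 0; j < alpha.Length; j++) p[j] /= sum;
        return p;
    }

    // gateGrad[j] = dL/dg_j, the loss gradient with respect to binary gate j.
    // Returns dL/dalpha_i ~= sum_j dL/dg_j * p_j * (delta_ij - p_i).
    public static double[] AlphaGradient(double[] alpha, double[] gateGrad)
    {
        double[] p = Softmax(alpha);
        var grad = new double[alpha.Length];
        for (int i = 0; i < alpha.Length; i++)
            for (int j = 0; j < alpha.Length; j++)
                grad[i] += gateGrad[j] * p[j] * ((i == j ? 1.0 : 0.0) - p[i]);
        return grad;
    }

    // One gradient-descent step on the architecture parameters of one edge (in place).
    public static void Step(double[] alpha, double[] gateGrad, double learningRate)
    {
        double[] grad = AlphaGradient(alpha, gateGrad);
        for (int i = 0; i < alpha.Length; i++) alpha[i] -= learningRate * grad[i];
    }
}
```

If a gate's gradient is positive (turning that path on raises the loss), the step lowers that operation's weight relative to the others.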
SearchArchitecture(Tensor<T>, Tensor<T>, Tensor<T>, Tensor<T>, TimeSpan, CancellationToken)
Performs the ProxylessNAS architecture search on the given training and validation data, subject to the time limit and cancellation token.
protected override Architecture<T> SearchArchitecture(Tensor<T> inputs, Tensor<T> targets, Tensor<T> validationInputs, Tensor<T> validationTargets, TimeSpan timeLimit, CancellationToken cancellationToken)
Parameters
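- inputs (Tensor<T>): the training inputs.
- targets (Tensor<T>): the training targets.
- validationInputs (Tensor<T>): the validation inputs.
- validationTargets (Tensor<T>): the validation targets.
- timeLimit (TimeSpan): the maximum time allowed for the search.
- cancellationToken (CancellationToken): a token that can be used to cancel the search.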
Returns
- Architecture<T>
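The paper alternates two phases. First, the network weights are trained on training data with the architecture parameters frozen. Then the architecture parameters are updated on validation data with the weights frozen. Below is a minimal C# sketch of such a loop that also respects a time limit and a CancellationToken. The weightStep and archStep delegates and the maxIterations cap are hypothetical. This is not this library's SearchArchitecture.

```csharp
using System;
using System.Diagnostics;
using System.Threading;

// Hypothetical sketch of a time-limited, cancellable search loop; not the library's SearchArchitecture.
public static class SearchLoopSketch
{
    // weightStep: one update of the network weights on training data (hypothetical delegate).
    // archStep: one update of the architecture parameters on validation data (hypothetical delegate).
    // Returns the number of completed iterations.
    public static int Run(Action weightStep, Action archStep, TimeSpan timeLimit,
                          CancellationToken cancellationToken, int maxIterations)
    {
        var stopwatch = Stopwatch.StartNew();
        int iterations = 0;
        while (iterations < maxIterations
               && stopwatch.Elapsed < timeLimit
               && !cancellationToken.IsCancellationRequested)
        {
            weightStep(); // architecture parameters frozen; train weights on sampled paths
            archStep();   // weights frozen; update architecture parameters on validation data
            iterations++;
        }
        return iterations;
    }
}
```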
SetBinarizationTemperature(double)
Sets the temperature used for path binarization.
public void SetBinarizationTemperature(double temperature)
Parameters
- temperature (double): the new binarization temperature.
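This page does not say how the temperature is applied. One common convention is to divide the architecture weights by the temperature before the softmax that produces path probabilities. Below is a minimal C# sketch under that assumption. It is not necessarily what this library does, and the class name is hypothetical.

```csharp
using System;

// Hypothetical sketch: assumes the temperature divides the architecture weights
// before the softmax that yields path probabilities.
public static class TemperatureSketch
{
    public static double[] Softmax(double[] alpha, double temperature)
    {
        if (temperature <= 0.0) throw new ArgumentOutOfRangeException(nameof(temperature));
        double max = double.NegativeInfinity;
        foreach (double a in alpha) max = Math.Max(max, a / temperature);
        var p = new double[alpha.Length];
        double sum = 0.0;
        for (int j = 0; j < alpha.Length; j++)
        {
            p[j] = Math.Exp(alpha[j] / temperature - max); // scaled, shifted for numerical stability
            sum += p[j];
        }
        for (int j = 0; j < alpha.Length; j++) p[j] /= sum;
        return p;
    }
}
```

Under this assumption, a lower temperature concentrates probability on the strongest operation, and a higher temperature spreads sampling more evenly across the candidates.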