Class ENAS<T>
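Implements Efficient Neural Architecture Search via Parameter Sharing (ENAS). A controller RNN samples candidate architectures, and all child models share a single set of weights, so candidates do not have to be trained from scratch. The ENAS paper reports a search cost roughly 1000x lower, measured in GPU-hours, than standard NAS.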
Reference: "Efficient Neural Architecture Search via Parameter Sharing" (ICML 2018)
public class ENAS<T> : NasAutoMLModelBase<T>, IAutoMLModel<T, Tensor<T>, Tensor<T>>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>
Type Parameters
T: The numeric type used for calculations.
Inheritance
- ENAS<T>
Inherited Members
- AutoMLModelBase<T, Tensor<T>, Tensor<T>>.SetModelEvaluator(IModelEvaluator<T, Tensor<T>, Tensor<T>>)
Constructors
ENAS(SearchSpaceBase<T>, int, int, double, double)
public ENAS(SearchSpaceBase<T> searchSpace, int numNodes = 4, int controllerHiddenSize = 100, double baselineDecay = 0.95, double entropyWeight = 0.01)
Parameters
- searchSpace: SearchSpaceBase<T>
- numNodes: int
- controllerHiddenSize: int
- baselineDecay: double
- entropyWeight: double
Properties
NasNumNodes
Gets the number of nodes to search over.
protected override int NasNumNodes { get; }
Property Value
- int
NasSearchSpace
Gets the NAS search space.
protected override SearchSpaceBase<T> NasSearchSpace { get; }
Property Value
- SearchSpaceBase<T>
NumOps
Gets the numeric operations provider for T.
protected override INumericOperations<T> NumOps { get; }
Property Value
- INumericOperations<T>
Methods
CreateInstanceForCopy()
Factory method for creating a new instance for deep copy. Derived classes must implement this to return a new instance of themselves. This ensures each copy has its own collections and lock object.
protected override AutoMLModelBase<T, Tensor<T>, Tensor<T>> CreateInstanceForCopy()
Returns
- AutoMLModelBase<T, Tensor<T>, Tensor<T>>
A fresh instance of the derived class with default parameters
Remarks
When implementing this method, derived classes should create a fresh instance with default parameters, and should not attempt to preserve runtime or initialization state from the original instance. The deep copy logic will transfer relevant state (trial history, search space, etc.) after construction.
GetBaseline()
Gets the current value of the reward baseline.
public T GetBaseline()
Returns
- T
GetControllerGradients()
Gets the gradients of the controller parameters.
public List<Vector<T>> GetControllerGradients()
Returns
- List<Vector<T>>
GetControllerParameters()
Gets the controller parameters for optimization.
public List<Vector<T>> GetControllerParameters()
Returns
- List<Vector<T>>
GetSharedGradients()
Gets the gradients of the shared weights.
public Dictionary<string, Vector<T>> GetSharedGradients()
Returns
- Dictionary<string, Vector<T>>
GetSharedWeights()
Gets the shared weights for all operations.
public Dictionary<string, Vector<T>> GetSharedWeights()
Returns
- Dictionary<string, Vector<T>>
GetSharedWeights(string)
Gets the shared weights for a specific operation.
public Vector<T> GetSharedWeights(string operationKey)
Parameters
- operationKey: string
Returns
- Vector<T>
SampleArchitecture()
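Samples an architecture using the controller policy. Returns the sampled architecture together with its log-probability and the entropy of the sampling distribution.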
public (Architecture<T> architecture, T logProb, T entropy) SampleArchitecture()
Returns
- (Architecture<T> architecture, T logProb, T entropy)
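To illustrate what the logProb and entropy values in the returned tuple typically mean, the following minimal sketch samples one categorical decision from a vector of logits. It is not the library's implementation: CategoricalSampling, Softmax, and Sample are hypothetical names, and it assumes the controller picks each operation from a softmax distribution. In a standard autoregressive controller, the log-probability and entropy of a whole architecture are the sums of these per-decision values.

```csharp
using System;
using System.Linq;

// Hypothetical helper for illustration only; not part of the ENAS<T> API.
public static class CategoricalSampling
{
    // Numerically stable softmax over raw logits.
    public static double[] Softmax(double[] logits)
    {
        double max = logits.Max();
        double[] exps = logits.Select(l => Math.Exp(l - max)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }

    // Samples one decision; returns its index, its log-probability,
    // and the entropy of the full distribution.
    public static (int index, double logProb, double entropy) Sample(double[] logits, Random rng)
    {
        double[] probs = Softmax(logits);

        // Inverse-CDF sampling: pick the first index whose cumulative
        // probability exceeds a uniform draw in [0, 1).
        double u = rng.NextDouble();
        double cumulative = 0.0;
        int index = probs.Length - 1; // fallback guards against rounding error
        for (int i = 0; i < probs.Length; i++)
        {
            cumulative += probs[i];
            if (u < cumulative)
            {
                index = i;
                break;
            }
        }

        double logProb = Math.Log(probs[index]);
        // Entropy H = -sum(p * ln p); zero-probability terms contribute nothing.
        double entropy = -probs.Where(p => p > 0.0).Sum(p => p * Math.Log(p));
        return (index, logProb, entropy);
    }
}
```

For example, `CategoricalSampling.Sample(new[] { 0.0, 0.0, 0.0, 0.0 }, new Random(0))` draws uniformly among four choices, so logProb equals ln(0.25) and entropy equals ln(4) regardless of which index is drawn.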
SearchArchitecture(Tensor<T>, Tensor<T>, Tensor<T>, Tensor<T>, TimeSpan, CancellationToken)
Performs algorithm-specific architecture search.
protected override Architecture<T> SearchArchitecture(Tensor<T> inputs, Tensor<T> targets, Tensor<T> validationInputs, Tensor<T> validationTargets, TimeSpan timeLimit, CancellationToken cancellationToken)
Parameters
- inputs: Tensor<T>
- targets: Tensor<T>
- validationInputs: Tensor<T>
- validationTargets: Tensor<T>
- timeLimit: TimeSpan
- cancellationToken: CancellationToken
Returns
- Architecture<T>
UpdateController(T, T, T)
Updates the controller using the REINFORCE policy gradient.
public void UpdateController(T reward, T logProb, T entropy)
Parameters
- reward: T
- logProb: T
- entropy: T
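The sketch below shows how a REINFORCE update with a moving-average baseline and an entropy bonus is commonly computed from these three inputs. It is an illustration under stated assumptions, not the code inside UpdateController: ReinforceBaseline and Step are hypothetical names, the baseline is assumed to start at zero and to be an exponential moving average of rewards, and the decay and entropyWeight arguments mirror the constructor's baselineDecay and entropyWeight defaults. Whether the class updates its baseline before or after computing the advantage is not stated in this reference.

```csharp
using System;

// Hypothetical helper for illustration only; not part of the ENAS<T> API.
public sealed class ReinforceBaseline
{
    private readonly double _decay;         // plays the role of baselineDecay
    private readonly double _entropyWeight; // plays the role of entropyWeight

    // Exponential moving average of past rewards (assumed to start at 0).
    public double Baseline { get; private set; }

    public ReinforceBaseline(double decay = 0.95, double entropyWeight = 0.01)
    {
        _decay = decay;
        _entropyWeight = entropyWeight;
    }

    // Returns a surrogate loss whose gradient with respect to the controller
    // parameters is the REINFORCE estimate, then updates the baseline.
    public double Step(double reward, double logProb, double entropy)
    {
        // Subtracting the baseline reduces the variance of the gradient estimate.
        double advantage = reward - Baseline;

        // Minimizing this loss raises the log-probability of architectures that
        // beat the baseline and rewards higher entropy (more exploration).
        double loss = -(advantage * logProb) - _entropyWeight * entropy;

        // Move the baseline toward the latest reward.
        Baseline = _decay * Baseline + (1.0 - _decay) * reward;
        return loss;
    }
}
```

With the defaults, a first call `Step(1.0, Math.Log(0.25), Math.Log(4.0))` uses an advantage of 1.0 and moves the baseline from 0 to 0.05; a second call with the same reward would use an advantage of 0.95.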