
Class ENAS<T>

Namespace
AiDotNet.AutoML.NAS
Assembly
AiDotNet.dll

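Efficient Neural Architecture Search via Parameter Sharing. ENAS uses a controller RNN to sample architectures. All sampled child models share one set of weights, so a child does not have to be trained from scratch. The original paper reports a search cost roughly 1000x lower in GPU hours than standard NAS.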

Reference: H. Pham, M. Y. Guan, B. Zoph, Q. V. Le, and J. Dean, "Efficient Neural Architecture Search via Parameter Sharing," ICML 2018.

public class ENAS<T> : NasAutoMLModelBase<T>, IAutoMLModel<T, Tensor<T>, Tensor<T>>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>

Type Parameters

T

The numeric type used for calculations.

Inheritance
AutoMLModelBase<T, Tensor<T>, Tensor<T>>
NasAutoMLModelBase<T>
ENAS<T>
Implements
IAutoMLModel<T, Tensor<T>, Tensor<T>>
IFullModel<T, Tensor<T>, Tensor<T>>
IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>
IModelSerializer
ICheckpointableModel
IParameterizable<T, Tensor<T>, Tensor<T>>
IFeatureAware
IFeatureImportance<T>
ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>
IGradientComputable<T, Tensor<T>, Tensor<T>>
IJitCompilable<T>

Constructors

ENAS(SearchSpaceBase<T>, int, int, double, double)

public ENAS(SearchSpaceBase<T> searchSpace, int numNodes = 4, int controllerHiddenSize = 100, double baselineDecay = 0.95, double entropyWeight = 0.01)

Parameters

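searchSpace SearchSpaceBase<T>

The search space that defines the candidate operations.

numNodes int

The number of nodes to search over. Defaults to 4.

controllerHiddenSize int

The hidden size of the controller RNN. Defaults to 100.

baselineDecay double

The decay factor of the moving-average reward baseline used in the REINFORCE controller update. Defaults to 0.95.

entropyWeight double

The weight of the entropy term in the controller update, which encourages exploration. Defaults to 0.01.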

Properties

NasNumNodes

Gets the number of nodes to search over.

protected override int NasNumNodes { get; }

Property Value

int

NasSearchSpace

Gets the NAS search space.

protected override SearchSpaceBase<T> NasSearchSpace { get; }

Property Value

SearchSpaceBase<T>

NumOps

Gets the numeric operations provider for T.

protected override INumericOperations<T> NumOps { get; }

Property Value

INumericOperations<T>

Methods

CreateInstanceForCopy()

Factory method for creating a new instance for deep copy. Derived classes must implement this to return a new instance of themselves. This ensures each copy has its own collections and lock object.

protected override AutoMLModelBase<T, Tensor<T>, Tensor<T>> CreateInstanceForCopy()

Returns

AutoMLModelBase<T, Tensor<T>, Tensor<T>>

A fresh instance of the derived class with default parameters

Remarks

When implementing this method, derived classes should create a fresh instance with default parameters, and should not attempt to preserve runtime or initialization state from the original instance. The deep copy logic will transfer relevant state (trial history, search space, etc.) after construction.

GetBaseline()

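Gets the current reward baseline. The baseline is subtracted from each reward in the REINFORCE controller update to reduce the variance of the policy gradient.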

public T GetBaseline()

Returns

T
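GetBaseline() exposes the baseline but not how it is maintained. A common choice, and a reasonable reading of the baselineDecay constructor parameter, is an exponential moving average of past rewards. The C# sketch below assumes that rule and a starting baseline of 0. UpdateBaseline is a hypothetical helper, plain double stands in for T, and the rewards are arbitrary example values, not output of this class.

```csharp
using System;

// Hypothetical helper: exponential moving average (EMA) of rewards.
// 'decay' plays the role assumed for the constructor's baselineDecay.
double UpdateBaseline(double current, double reward, double decay) =>
    decay * current + (1.0 - decay) * reward;

double baseline = 0.0;                     // assumed starting value
double[] rewards = { 0.50, 0.60, 0.70 };   // arbitrary example rewards
foreach (double r in rewards)
{
    baseline = UpdateBaseline(baseline, r, decay: 0.95);
    Console.WriteLine($"reward={r:F2} baseline={baseline:F4}");
}
```

With a decay of 0.95 the baseline moves only 5% of the way toward each new reward, so a baseline that starts at 0 lags well behind the rewards for the first several updates.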

GetControllerGradients()

Gets the gradients of the controller parameters.

public List<Vector<T>> GetControllerGradients()

Returns

List<Vector<T>>

GetControllerParameters()

Gets the controller parameters so that an optimizer can update them.

public List<Vector<T>> GetControllerParameters()

Returns

List<Vector<T>>

GetSharedGradients()

Gets the gradients of the shared weights, keyed by operation.

public Dictionary<string, Vector<T>> GetSharedGradients()

Returns

Dictionary<string, Vector<T>>

GetSharedWeights()

Gets the shared weights for all operations, keyed by operation.

public Dictionary<string, Vector<T>> GetSharedWeights()

Returns

Dictionary<string, Vector<T>>

GetSharedWeights(string)

Gets the shared weights for the operation identified by operationKey.

public Vector<T> GetSharedWeights(string operationKey)

Parameters

operationKey string

Returns

Vector<T>
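Parameter sharing means that every child architecture that selects the same operation reads and writes the same weight vector. The sketch below models the shared-weight store as a Dictionary<string, double[]>, mirroring the dictionary returned by GetSharedWeights(). The key format "node0/conv3x3", the lazy GetOrCreate helper, and the random initialization are hypothetical and are not taken from AiDotNet.

```csharp
using System;
using System.Collections.Generic;

// Shared-weight store: one vector per operation key (hypothetical key format).
var sharedWeights = new Dictionary<string, double[]>();
var rng = new Random(0);

// Returns the existing vector for a key, creating it on first use.
double[] GetOrCreate(string operationKey, int size)
{
    if (!sharedWeights.TryGetValue(operationKey, out var weights))
    {
        weights = new double[size];
        for (int i = 0; i < size; i++)
            weights[i] = rng.NextDouble() * 0.1 - 0.05; // small uniform init in [-0.05, 0.05)
        sharedWeights[operationKey] = weights;
    }
    return weights;
}

// Two sampled children that both use conv3x3 at node 0, and one that uses maxpool.
double[] childA = GetOrCreate("node0/conv3x3", 4);
double[] childB = GetOrCreate("node0/conv3x3", 4);
double[] childC = GetOrCreate("node0/maxpool", 4);

childA[0] += 1.0;                                    // an update made while training child A...
Console.WriteLine(childB[0] == childA[0]);           // ...is visible to child B: True
Console.WriteLine(ReferenceEquals(childA, childC));  // different operation, separate vector: False
Console.WriteLine(sharedWeights.Count);              // 2
```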

SampleArchitecture()

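Samples an architecture from the controller policy. Returns the sampled architecture, the log-probability of the sampled decisions, and the entropy of the policy.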

public (Architecture<T> architecture, T logProb, T entropy) SampleArchitecture()

Returns

(Architecture<T> architecture, T logProb, T entropy)
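SampleArchitecture() returns a log-probability and an entropy alongside the architecture. The sketch below shows one standard way to produce all three for a sequence of independent categorical choices: sample one operation per node from a softmax, sum the log-probabilities of the chosen operations, and sum the per-node entropies. As simplifying assumptions, fixed logit arrays stand in for the controller RNN outputs, the operation names are hypothetical, an int array stands in for Architecture<T>, and double replaces T.

```csharp
using System;
using System.Linq;

string[] ops = { "conv3x3", "conv5x5", "maxpool", "identity" }; // hypothetical operation set
// Fixed stand-ins for the controller RNN outputs: one logit vector per node.
double[][] logits =
{
    new[] { 2.0, 0.5, 0.1, -1.0 },
    new[] { 0.0, 0.0, 0.0, 0.0 },
};
var rng = new Random(42);

// Numerically stable softmax.
double[] Softmax(double[] z)
{
    double max = z.Max();
    double[] e = z.Select(v => Math.Exp(v - max)).ToArray();
    double sum = e.Sum();
    return e.Select(v => v / sum).ToArray();
}

(int[] choices, double logProb, double entropy) Sample(double[][] nodeLogits)
{
    var choices = new int[nodeLogits.Length];
    double logProb = 0.0, entropy = 0.0;
    for (int n = 0; n < nodeLogits.Length; n++)
    {
        double[] p = Softmax(nodeLogits[n]);
        // Inverse-CDF sampling from the categorical distribution p.
        double u = rng.NextDouble(), cum = 0.0;
        int k = p.Length - 1;
        for (int i = 0; i < p.Length; i++) { cum += p[i]; if (u < cum) { k = i; break; } }
        choices[n] = k;
        logProb += Math.Log(p[k]);                  // add log p(chosen op)
        entropy -= p.Sum(pi => pi * Math.Log(pi));  // add H(p) = -sum p log p
    }
    return (choices, logProb, entropy);
}

var (arch, lp, h) = Sample(logits);
Console.WriteLine(string.Join(", ", arch.Select(i => ops[i])));
Console.WriteLine($"logProb={lp:F4} entropy={h:F4}");
```

Summing over nodes treats the per-node choices as one joint decision, which is the quantity a REINFORCE update needs.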

SearchArchitecture(Tensor<T>, Tensor<T>, Tensor<T>, Tensor<T>, TimeSpan, CancellationToken)

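Performs the ENAS architecture search on the training data (inputs, targets) and the validation data (validationInputs, validationTargets), within the given time limit.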

protected override Architecture<T> SearchArchitecture(Tensor<T> inputs, Tensor<T> targets, Tensor<T> validationInputs, Tensor<T> validationTargets, TimeSpan timeLimit, CancellationToken cancellationToken)

Parameters

inputs Tensor<T>
targets Tensor<T>
validationInputs Tensor<T>
validationTargets Tensor<T>
timeLimit TimeSpan
cancellationToken CancellationToken

Returns

Architecture<T>

UpdateController(T, T, T)

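Updates the controller with the REINFORCE policy gradient, given the reward of a sampled architecture, its log-probability, and the policy entropy returned by SampleArchitecture().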

public void UpdateController(T reward, T logProb, T entropy)

Parameters

reward T
logProb T
entropy T
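UpdateController(reward, logProb, entropy) is documented as a REINFORCE update. A common form of that update, consistent with the baselineDecay and entropyWeight constructor parameters, minimizes L = -(reward - baseline) * logProb - entropyWeight * entropy. The sketch below applies one gradient step of that loss directly to the logits of a single categorical decision, using the analytic softmax derivatives. ReinforceStep, the learning rate, and the single-decision setup are assumptions for illustration, not this class's implementation.

```csharp
using System;
using System.Linq;

// Numerically stable softmax.
double[] Softmax(double[] logits)
{
    double max = logits.Max();
    double[] e = logits.Select(v => Math.Exp(v - max)).ToArray();
    double sum = e.Sum();
    return e.Select(v => v / sum).ToArray();
}

// One REINFORCE step with an entropy bonus on a single categorical decision.
// Loss: L = -(reward - baseline) * log p[action] - entropyWeight * H(p)
void ReinforceStep(double[] logits, int action, double reward, double baseline,
                   double entropyWeight, double learningRate)
{
    double[] p = Softmax(logits);
    double entropy = -p.Sum(pk => pk * Math.Log(pk));
    double advantage = reward - baseline;
    for (int k = 0; k < logits.Length; k++)
    {
        double dLogProb = (k == action ? 1.0 : 0.0) - p[k];   // d log p[action] / d logit_k
        double dEntropy = -p[k] * (Math.Log(p[k]) + entropy); // d H / d logit_k
        double grad = -advantage * dLogProb - entropyWeight * dEntropy;
        logits[k] -= learningRate * grad;                      // gradient descent on L
    }
}

double[] policyLogits = { 0.0, 0.0, 0.0 };
double before = Softmax(policyLogits)[1];
ReinforceStep(policyLogits, action: 1, reward: 0.8, baseline: 0.6,
              entropyWeight: 0.01, learningRate: 0.1);
double after = Softmax(policyLogits)[1];
Console.WriteLine($"p(action) {before:F4} -> {after:F4}");
```

With a reward above the baseline the probability of the chosen action rises; with a reward below it, the probability falls. At a uniform policy the entropy gradient is zero, so the entropy bonus only pushes back once the policy has become peaked.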