Class AttentiveNAS<T>
AttentiveNAS: Improving Neural Architecture Search via Attentive Sampling. Uses an attention-based meta-network to guide the sampling of sub-networks, focusing search on promising regions of the architecture space.
Reference: "AttentiveNAS: Improving Neural Architecture Search via Attentive Sampling" (CVPR 2021)
public class AttentiveNAS<T> : NasAutoMLModelBase<T>, IAutoMLModel<T, Tensor<T>, Tensor<T>>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>
Type Parameters
- T: The numeric type used for calculations.
Inheritance
- NasAutoMLModelBase<T>
- AttentiveNAS<T>
Inherited Members
- AutoMLModelBase<T, Tensor<T>, Tensor<T>>.SetModelEvaluator(IModelEvaluator<T, Tensor<T>, Tensor<T>>)
Constructors
AttentiveNAS(SearchSpaceBase<T>, List<int>?, List<double>?, List<int>?, int)
public AttentiveNAS(SearchSpaceBase<T> searchSpace, List<int>? elasticDepths = null, List<double>? elasticWidthMultipliers = null, List<int>? elasticKernelSizes = null, int attentionHiddenSize = 128)
Parameters
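- searchSpace: SearchSpaceBase<T>
- elasticDepths: List<int>? (optional, default null)
- elasticWidthMultipliers: List<double>? (optional, default null)
- elasticKernelSizes: List<int>? (optional, default null)
- attentionHiddenSize: int (optional, default 128)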
Properties
NasNumNodes
Gets the number of nodes to search over.
protected override int NasNumNodes { get; }
Property Value
- int
NasSearchSpace
Gets the NAS search space.
protected override SearchSpaceBase<T> NasSearchSpace { get; }
Property Value
- SearchSpaceBase<T>
NumOps
Gets the numeric operations provider for T.
protected override INumericOperations<T> NumOps { get; }
Property Value
- INumericOperations<T>
Methods
AttentiveSample(Vector<T>)
Samples an architecture using the attention-based sampling strategy. The attention module learns to focus on high-performing regions of the architecture space.
public AttentiveNASConfig<T> AttentiveSample(Vector<T> contextVector)
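Parameters
- contextVector: Vector<T>. The context to condition sampling on, such as a vector returned by CreateContextVector().
Returns
- AttentiveNASConfig<T>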
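The exact form of the attention module is not documented here. As a minimal sketch, assuming a linear attention layer whose weight matrix scores each candidate option against the context vector, followed by a softmax, attentive sampling of one architectural decision could look like this. The class and method names are hypothetical, double stands in for T, and the library's internals may differ.

```csharp
using System;
using System.Linq;

// Hypothetical sketch: one weight row per candidate option for a single
// architectural decision (e.g. depth, width multiplier, or kernel size).
// This is not the library's API.
public static class AttentiveSamplingSketch
{
    // Scores each option by a dot product with the context vector:
    // logits[i] = sum_j weights[i, j] * context[j].
    public static double[] ComputeLogits(double[,] weights, double[] context)
    {
        int options = weights.GetLength(0);
        if (weights.GetLength(1) != context.Length)
            throw new ArgumentException("Weight columns must match context length.");
        var logits = new double[options];
        for (int i = 0; i < options; i++)
            for (int j = 0; j < context.Length; j++)
                logits[i] += weights[i, j] * context[j];
        return logits;
    }

    // Numerically stable softmax: subtract the max logit before exponentiating.
    public static double[] Softmax(double[] logits)
    {
        double max = logits.Max();
        double[] exps = logits.Select(l => Math.Exp(l - max)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }

    // Draws an option index from the categorical distribution softmax(logits).
    public static int SampleIndex(double[] logits, Random rng)
    {
        double[] probs = Softmax(logits);
        double u = rng.NextDouble();
        double cumulative = 0.0;
        for (int i = 0; i < probs.Length; i++)
        {
            cumulative += probs[i];
            if (u < cumulative) return i;
        }
        return probs.Length - 1; // guard against floating-point round-off
    }
}
```

Options whose weight rows align with the current context receive larger logits and are therefore sampled more often, which is how attention concentrates the search on promising regions.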
CreateContextVector()
Creates a context vector from the recent architecture performance history.
public Vector<T> CreateContextVector()
Returns
- Vector<T>
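How the library encodes its performance history is not specified. The sketch below assumes one simple encoding: the last `window` scores, right-aligned and zero-padded when the history is short, plus their mean as a final feature. The window size, padding, and mean feature are all illustrative assumptions, and double stands in for T.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical sketch of building a fixed-length context vector from recent
// performance scores. Not the library's implementation.
public static class ContextVectorSketch
{
    public static double[] Build(IReadOnlyList<double> history, int window)
    {
        if (window <= 0) throw new ArgumentOutOfRangeException(nameof(window));
        var context = new double[window + 1];

        // Use only the most recent `window` scores, oldest first.
        int start = Math.Max(0, history.Count - window);
        int count = history.Count - start;

        // Right-align the recent scores; leading slots stay zero when history is short.
        for (int i = 0; i < count; i++)
            context[window - count + i] = history[start + i];

        // Final slot: mean of the scores used (0 when there is no history).
        context[window] = count > 0 ? history.Skip(start).Average() : 0.0;
        return context;
    }
}
```

A fixed length matters because the attention weights expect a context of constant dimension regardless of how many architectures have been evaluated so far.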
CreateInstanceForCopy()
Factory method for creating a new instance for deep copy. Derived classes must implement this to return a new instance of themselves. This ensures each copy has its own collections and lock object.
protected override AutoMLModelBase<T, Tensor<T>, Tensor<T>> CreateInstanceForCopy()
Returns
- AutoMLModelBase<T, Tensor<T>, Tensor<T>>
A fresh instance of the derived class with default parameters.
Remarks
When implementing this method, derived classes should create a fresh instance with default parameters, and should not attempt to preserve runtime or initialization state from the original instance. The deep copy logic will transfer relevant state (trial history, search space, etc.) after construction.
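The factory-method pattern described above can be sketched with a pair of hypothetical stand-in classes (ModelBaseSketch and AttentiveNasSketch, not the library's types). The derived class only constructs a fresh instance, and the base class transfers state afterwards, so each copy ends up with its own collection and lock object.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical types illustrating the CreateInstanceForCopy pattern.
public abstract class ModelBaseSketch
{
    // Each instance owns its own collection and lock object.
    protected readonly List<double> TrialHistory = new List<double>();
    protected readonly object SyncRoot = new object();

    // Derived classes return a fresh, default-constructed instance of themselves.
    protected abstract ModelBaseSketch CreateInstanceForCopy();

    public void RecordTrial(double score)
    {
        lock (SyncRoot) TrialHistory.Add(score);
    }

    public int TrialCount
    {
        get { lock (SyncRoot) return TrialHistory.Count; }
    }

    // The base class creates the copy through the factory, then transfers state.
    public ModelBaseSketch DeepCopy()
    {
        ModelBaseSketch copy = CreateInstanceForCopy();
        lock (SyncRoot) copy.TrialHistory.AddRange(TrialHistory);
        return copy;
    }
}

public sealed class AttentiveNasSketch : ModelBaseSketch
{
    // No runtime state is preserved here; DeepCopy handles the transfer.
    protected override ModelBaseSketch CreateInstanceForCopy() => new AttentiveNasSketch();
}
```

Because the copy's history is a separate list, recording a trial on the copy leaves the original unchanged.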
GetAttentionWeights()
Gets the attention weights.
public Matrix<T> GetAttentionWeights()
Returns
- Matrix<T>
GetPerformanceMemory()
Gets the performance memory: recorded performance values, keyed by string.
public Dictionary<string, T> GetPerformanceMemory()
Returns
- Dictionary<string, T>
Search(HardwareConstraints<T>, int, int, int)
Searches for an optimal architecture using attentive sampling.
public AttentiveNASConfig<T> Search(HardwareConstraints<T> constraints, int inputChannels, int spatialSize, int numIterations = 100)
Parameters
- constraints: HardwareConstraints<T>
- inputChannels: int
- spatialSize: int
- numIterations: int (optional, default 100)
Returns
- AttentiveNASConfig<T>
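The overall shape of an attentive search loop (sample, check constraints, evaluate, update attention, keep the best) can be sketched over a single decision with a few options. Here HardwareConstraints<T> is modeled as one cost budget, and evaluation and attention updates use simple stand-ins: a caller-supplied scoring function and a logit nudge relative to an exponential moving average of scores. All names are hypothetical and double stands in for T.

```csharp
using System;
using System.Linq;

// Hypothetical sketch of an attentive search loop over one decision.
// The cost model, evaluator, and update rule are illustrative stand-ins.
public static class AttentiveSearchSketch
{
    public static int Search(
        int numOptions,
        Func<int, double> cost,      // assumed: estimated hardware cost of option i
        double budget,               // assumed: maximum allowed cost
        Func<int, double> evaluate,  // assumed: validation score of option i
        int numIterations,
        double learningRate,
        Random rng)
    {
        var logits = new double[numOptions];
        int best = -1;
        double bestScore = double.NegativeInfinity;
        double baseline = 0.0; // exponential moving average of observed scores

        for (int t = 0; t < numIterations; t++)
        {
            int choice = Sample(logits, rng);
            if (cost(choice) > budget)
            {
                // Infeasible samples are discouraged and not evaluated.
                logits[choice] -= learningRate;
                continue;
            }
            double score = evaluate(choice);
            // Raise attention on options that beat the baseline, lower it otherwise.
            logits[choice] += learningRate * (score - baseline);
            baseline = 0.9 * baseline + 0.1 * score;
            if (score > bestScore) { bestScore = score; best = choice; }
        }
        return best; // -1 if no feasible option was ever sampled
    }

    // Samples an index with probability proportional to exp(logit).
    private static int Sample(double[] logits, Random rng)
    {
        double max = logits.Max();
        double[] w = logits.Select(l => Math.Exp(l - max)).ToArray();
        double u = rng.NextDouble() * w.Sum();
        for (int i = 0; i < w.Length; i++)
        {
            u -= w[i];
            if (u < 0) return i;
        }
        return w.Length - 1;
    }
}
```

Filtering by the budget before evaluation means compute is spent only on feasible candidates, while the logit penalty steers later samples away from infeasible regions.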
SearchArchitecture(Tensor<T>, Tensor<T>, Tensor<T>, Tensor<T>, TimeSpan, CancellationToken)
Performs algorithm-specific architecture search.
protected override Architecture<T> SearchArchitecture(Tensor<T> inputs, Tensor<T> targets, Tensor<T> validationInputs, Tensor<T> validationTargets, TimeSpan timeLimit, CancellationToken cancellationToken)
Parameters
- inputs: Tensor<T>
- targets: Tensor<T>
- validationInputs: Tensor<T>
- validationTargets: Tensor<T>
- timeLimit: TimeSpan
- cancellationToken: CancellationToken
Returns
- Architecture<T>
UpdateAttention(AttentiveNASConfig<T>, T, T)
Updates the attention module based on architecture performance. High-performing architectures increase the attention paid to similar regions of the search space.
public void UpdateAttention(AttentiveNASConfig<T> config, T performance, T learningRate)
Parameters
- config: AttentiveNASConfig<T>
- performance: T
- learningRate: T
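The update rule is not documented here. One standard way to make "high-performing architectures increase attention" concrete is a REINFORCE-style policy-gradient step, shown below as a swapped-in illustration rather than the library's method. It assumes a linear attention module (logits = weights · context), a scalar baseline, and double in place of T. The gradient of log p(chosen) with respect to logit k is [k = chosen] − p_k.

```csharp
using System;
using System.Linq;

// Hypothetical sketch: a REINFORCE-style update for a linear attention module.
// Names, the baseline, and the linear form are illustrative assumptions.
public static class AttentionUpdateSketch
{
    public static void Update(
        double[,] weights, double[] context, int chosen,
        double performance, double baseline, double learningRate)
    {
        int options = weights.GetLength(0);
        int dims = weights.GetLength(1);

        // Current logits and softmax normalizer.
        var logits = new double[options];
        for (int k = 0; k < options; k++)
            for (int j = 0; j < dims; j++)
                logits[k] += weights[k, j] * context[j];
        double max = logits.Max();
        double[] exps = logits.Select(l => Math.Exp(l - max)).ToArray();
        double sum = exps.Sum();

        // Positive advantage pulls probability toward the chosen option;
        // negative advantage pushes it away.
        double advantage = performance - baseline;
        for (int k = 0; k < options; k++)
        {
            // d log p(chosen) / d logit_k = [k == chosen] - p_k
            double grad = (k == chosen ? 1.0 : 0.0) - exps[k] / sum;
            for (int j = 0; j < dims; j++)
                weights[k, j] += learningRate * advantage * grad * context[j];
        }
    }
}
```

For example, with two options, zero weights, context {1}, and an advantage of 1, each option starts at probability 0.5, so the chosen row moves by +0.5 and the other by −0.5.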