
Class AttentiveNAS<T>

Namespace
AiDotNet.AutoML.NAS
Assembly
AiDotNet.dll

Implements AttentiveNAS ("Improving Neural Architecture Search via Attentive Sampling"). The class uses an attention-based meta-network to guide the sampling of sub-networks, concentrating the search on promising regions of the architecture space.

Reference: "AttentiveNAS: Improving Neural Architecture Search via Attentive Sampling" (CVPR 2021)

public class AttentiveNAS<T> : NasAutoMLModelBase<T>, IAutoMLModel<T, Tensor<T>, Tensor<T>>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>

Type Parameters

T

The numeric type used for calculations.

Inheritance
AutoMLModelBase<T, Tensor<T>, Tensor<T>>
NasAutoMLModelBase<T>
AttentiveNAS<T>
Implements
IAutoMLModel<T, Tensor<T>, Tensor<T>>
IFullModel<T, Tensor<T>, Tensor<T>>
IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>
IModelSerializer
ICheckpointableModel
IParameterizable<T, Tensor<T>, Tensor<T>>
IFeatureAware
IFeatureImportance<T>
ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>
IGradientComputable<T, Tensor<T>, Tensor<T>>
IJitCompilable<T>

Constructors

AttentiveNAS(SearchSpaceBase<T>, List<int>?, List<double>?, List<int>?, int)

Initializes a new instance of the AttentiveNAS<T> class.

public AttentiveNAS(SearchSpaceBase<T> searchSpace, List<int>? elasticDepths = null, List<double>? elasticWidthMultipliers = null, List<int>? elasticKernelSizes = null, int attentionHiddenSize = 128)

Parameters

searchSpace SearchSpaceBase<T>
elasticDepths List<int>
elasticWidthMultipliers List<double>
elasticKernelSizes List<int>
attentionHiddenSize int
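The elasticDepths, elasticWidthMultipliers, and elasticKernelSizes parameters suggest a search space built from elastic choices. As a rough illustration only, the following standalone C# sketch enumerates every combination when each dimension takes a single global value. The example values and the EnumerateSpace helper are hypothetical; they are not AiDotNet's defaults or internals.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical example values; these are not AiDotNet's defaults.
var depths = new List<int> { 2, 3, 4 };
var widths = new List<double> { 0.5, 0.75, 1.0 };
var kernels = new List<int> { 3, 5, 7 };

var space = EnumerateSpace(depths, widths, kernels);
Console.WriteLine($"Candidate configurations: {space.Count}"); // 3 * 3 * 3

// Cartesian product: one depth, one width multiplier, and one kernel size per candidate.
static List<(int Depth, double Width, int Kernel)> EnumerateSpace(
    List<int> ds, List<double> ws, List<int> ks) =>
    (from d in ds
     from w in ws
     from k in ks
     select (Depth: d, Width: w, Kernel: k)).ToList();
```

With three options per dimension this yields 27 candidates. If each of n stages chose its own depth, width, and kernel independently, the count would grow to 27^n, which is why sampling rather than exhaustive enumeration is needed.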

Properties

NasNumNodes

Gets the number of nodes to search over.

protected override int NasNumNodes { get; }

Property Value

int

NasSearchSpace

Gets the NAS search space.

protected override SearchSpaceBase<T> NasSearchSpace { get; }

Property Value

SearchSpaceBase<T>

NumOps

Gets the numeric operations provider for T.

protected override INumericOperations<T> NumOps { get; }

Property Value

INumericOperations<T>

Methods

AttentiveSample(Vector<T>)

Samples an architecture configuration using the attention-based sampling strategy. The attention module learns to focus sampling on high-performing regions of the architecture space.

public AttentiveNASConfig<T> AttentiveSample(Vector<T> contextVector)

Parameters

contextVector Vector<T>

Returns

AttentiveNASConfig<T>
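A minimal sketch of one plausible attention-based sampling step, assuming the attention weights form a matrix that maps the context vector to one score per candidate choice. That shape is suggested by GetAttentionWeights returning a Matrix<T>, but it is not confirmed by this page. The scores are converted to probabilities with a softmax, and one choice is drawn. The helper names and numbers are hypothetical; this is not AiDotNet's implementation.

```csharp
using System;
using System.Linq;

// Hypothetical: 3 candidate choices (e.g., elastic depths), context vector of length 2.
double[,] weights = { { 0.2, -0.1 }, { 0.5, 0.3 }, { -0.4, 0.1 } };
double[] context = { 1.0, 0.5 };
var rng = new Random(42);

double[] probs = Softmax(Scores(weights, context));
int choice = SampleIndex(probs, rng);
Console.WriteLine($"P = [{string.Join(", ", probs.Select(p => p.ToString("F3")))}], chose {choice}");

// Linear attention scores: s_i = sum_j W[i, j] * c[j].
static double[] Scores(double[,] w, double[] c)
{
    var s = new double[w.GetLength(0)];
    for (int i = 0; i < s.Length; i++)
        for (int j = 0; j < c.Length; j++)
            s[i] += w[i, j] * c[j];
    return s;
}

// Numerically stable softmax: subtract the max score before exponentiating.
static double[] Softmax(double[] s)
{
    double max = s.Max();
    double[] e = s.Select(x => Math.Exp(x - max)).ToArray();
    double sum = e.Sum();
    return e.Select(x => x / sum).ToArray();
}

// Inverse-CDF sampling: walk the cumulative distribution until it exceeds u.
static int SampleIndex(double[] p, Random random)
{
    double u = random.NextDouble(), cumulative = 0;
    for (int i = 0; i < p.Length; i++)
    {
        cumulative += p[i];
        if (u < cumulative) return i;
    }
    return p.Length - 1; // guard against floating-point rounding
}
```

Drawing from the distribution rather than always taking the highest-scoring choice preserves some exploration. A full configuration would repeat this draw once per elastic dimension.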

CreateContextVector()

Creates a context vector from the recent architecture performance history.

public Vector<T> CreateContextVector()

Returns

Vector<T>
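This page does not specify which statistics CreateContextVector extracts. One hypothetical choice, sketched below, summarizes a sliding window of recent validation scores as mean, standard deviation, best score, and trend. The window size, the feature set, and this local CreateContextVector helper are assumptions for illustration only.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical validation scores of recently sampled architectures, oldest first.
var history = new List<double> { 0.61, 0.64, 0.63, 0.68, 0.70 };
double[] context = CreateContextVector(history, window: 4);
Console.WriteLine(string.Join(", ", context.Select(x => x.ToString("F4"))));

// Summarizes the last `window` scores as [mean, population std dev, best, trend].
static double[] CreateContextVector(IReadOnlyList<double> scores, int window)
{
    if (scores.Count == 0) return new double[4]; // no history yet: all zeros
    double[] recent = scores.Skip(Math.Max(0, scores.Count - window)).ToArray();
    double mean = recent.Average();
    double std = Math.Sqrt(recent.Select(x => (x - mean) * (x - mean)).Average());
    double best = recent.Max();
    double trend = recent[^1] - recent[0]; // positive when scores are improving
    return new[] { mean, std, best, trend };
}
```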

CreateInstanceForCopy()

Factory method for creating a new instance for deep copy. Derived classes must implement this to return a new instance of themselves. This ensures each copy has its own collections and lock object.

protected override AutoMLModelBase<T, Tensor<T>, Tensor<T>> CreateInstanceForCopy()

Returns

AutoMLModelBase<T, Tensor<T>, Tensor<T>>

A fresh instance of the derived class with default parameters.

Remarks

When implementing this method, derived classes should create a fresh instance with default parameters, and should not attempt to preserve runtime or initialization state from the original instance. The deep copy logic will transfer relevant state (trial history, search space, etc.) after construction.

GetAttentionWeights()

Gets the current attention weights.

public Matrix<T> GetAttentionWeights()

Returns

Matrix<T>

GetPerformanceMemory()

Gets the performance memory: recorded performance values keyed by string.

public Dictionary<string, T> GetPerformanceMemory()

Returns

Dictionary<string, T>
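Assuming (this page does not confirm it) that each key in the performance memory encodes a sampled configuration, such a memory can be sketched as a dictionary with a canonical key format, keep-the-best updates, and top-k retrieval. The key format and the MakeKey, Record, and TopK helpers are hypothetical.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

var memory = new Dictionary<string, double>();
Record(memory, (Depth: 3, Width: 0.75, Kernel: 5), 0.71);
Record(memory, (Depth: 4, Width: 1.0, Kernel: 7), 0.74);
Record(memory, (Depth: 3, Width: 0.75, Kernel: 5), 0.69); // re-evaluation: the better 0.71 is kept

foreach (var (key, score) in TopK(memory, 2))
    Console.WriteLine($"{key} -> {score}");

// Canonical, culture-independent key such as "d=3;w=0.75;k=5".
static string MakeKey((int Depth, double Width, int Kernel) c) =>
    FormattableString.Invariant($"d={c.Depth};w={c.Width};k={c.Kernel}");

// Stores a score, keeping the best value seen for a configuration.
static void Record(Dictionary<string, double> mem, (int Depth, double Width, int Kernel) c, double newScore)
{
    string id = MakeKey(c);
    if (!mem.TryGetValue(id, out double previous) || newScore > previous)
        mem[id] = newScore;
}

// Returns the k best-scoring entries, highest first.
static IEnumerable<(string Key, double Score)> TopK(Dictionary<string, double> mem, int k) =>
    mem.OrderByDescending(entry => entry.Value).Take(k).Select(entry => (Key: entry.Key, Score: entry.Value));
```

Formatting with the invariant culture keeps keys such as "0.75" identical on machines whose locale uses a decimal comma.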

Search(HardwareConstraints<T>, int, int, int)

Searches for an optimal architecture under the given hardware constraints using attentive sampling.

public AttentiveNASConfig<T> Search(HardwareConstraints<T> constraints, int inputChannels, int spatialSize, int numIterations = 100)

Parameters

constraints HardwareConstraints<T>
inputChannels int
spatialSize int
numIterations int

Returns

AttentiveNASConfig<T>
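The following self-contained sketch illustrates the overall shape of an attentive search loop. It samples a configuration from per-dimension softmax preferences, discards the configuration if a toy latency proxy exceeds the budget, and scores it with a toy accuracy proxy. It then reinforces the chosen options and keeps the best feasible result. The latency and accuracy formulas, the budget, the learning rate, and the baseline rule are all hypothetical stand-ins for HardwareConstraints<T> and real training. The code does not reproduce AiDotNet's Search method.

```csharp
using System;
using System.Linq;

int[] depths = { 2, 3, 4 };
double[] widths = { 0.5, 0.75, 1.0 };
int[] kernels = { 3, 5, 7 };
double latencyBudget = 9.0; // hypothetical constraint, arbitrary units

var rng = new Random(0);
double[][] logits = { new double[3], new double[3], new double[3] }; // one row per dimension
double baseline = 0, lr = 0.5;
(int D, double W, int K)? best = null;
double bestScore = double.NegativeInfinity;

for (int iter = 0; iter < 100; iter++) // plays the role of numIterations
{
    int[] pick = logits.Select(row => Sample(Softmax(row), rng)).ToArray();
    var cfg = (D: depths[pick[0]], W: widths[pick[1]], K: kernels[pick[2]]);
    if (Latency(cfg) > latencyBudget) continue; // infeasible: skip without updating
    double score = Accuracy(cfg);
    // Raise the preference of each chosen option by how much it beat the running baseline.
    for (int d = 0; d < 3; d++) logits[d][pick[d]] += lr * (score - baseline);
    baseline = 0.9 * baseline + 0.1 * score; // exponential moving average
    if (score > bestScore) { bestScore = score; best = cfg; }
}
Console.WriteLine($"best = {best}, score = {bestScore:F3}");

// Toy proxies (hypothetical): larger networks are slower and more accurate.
static double Latency((int D, double W, int K) c) => c.D * c.W * c.K / 2.0;
static double Accuracy((int D, double W, int K) c) => 1 - 1.0 / (c.D * c.W * c.K);

static double[] Softmax(double[] s)
{
    double max = s.Max();
    double[] e = s.Select(x => Math.Exp(x - max)).ToArray();
    double sum = e.Sum();
    return e.Select(x => x / sum).ToArray();
}

static int Sample(double[] p, Random random)
{
    double u = random.NextDouble(), c = 0;
    for (int i = 0; i < p.Length; i++) { c += p[i]; if (u < c) return i; }
    return p.Length - 1;
}
```

Skipping infeasible samples is the simplest way to enforce a hard constraint. Adding a penalty term to the score is a common alternative when the constraint is soft.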

SearchArchitecture(Tensor<T>, Tensor<T>, Tensor<T>, Tensor<T>, TimeSpan, CancellationToken)

Performs algorithm-specific architecture search.

protected override Architecture<T> SearchArchitecture(Tensor<T> inputs, Tensor<T> targets, Tensor<T> validationInputs, Tensor<T> validationTargets, TimeSpan timeLimit, CancellationToken cancellationToken)

Parameters

inputs Tensor<T>
targets Tensor<T>
validationInputs Tensor<T>
validationTargets Tensor<T>
timeLimit TimeSpan
cancellationToken CancellationToken

Returns

Architecture<T>
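SearchArchitecture receives both a TimeSpan budget and a CancellationToken. A common .NET pattern honors both: on each iteration, check a Stopwatch against the budget and call ThrowIfCancellationRequested, as in this sketch. RunSearch and EvaluateCandidate are hypothetical names. This page does not say whether AiDotNet throws on cancellation or returns the best architecture found so far.

```csharp
using System;
using System.Diagnostics;
using System.Threading;

using var cts = new CancellationTokenSource();
int evaluated = RunSearch(TimeSpan.FromMilliseconds(50), cts.Token, maxIterations: 1_000_000);
Console.WriteLine($"Evaluated {evaluated} candidates before stopping.");

// Loops until the time budget is spent, the iteration cap is reached, or cancellation is requested.
static int RunSearch(TimeSpan timeLimit, CancellationToken token, int maxIterations)
{
    var clock = Stopwatch.StartNew();
    int count = 0;
    while (count < maxIterations && clock.Elapsed < timeLimit)
    {
        token.ThrowIfCancellationRequested(); // throws OperationCanceledException
        EvaluateCandidate(count);             // stand-in for sample + train + validate
        count++;
    }
    return count;
}

// Hypothetical placeholder for the expensive per-candidate work.
static void EvaluateCandidate(int candidateIndex) => Thread.SpinWait(1000);
```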

UpdateAttention(AttentiveNASConfig<T>, T, T)

Updates the attention module based on architecture performance. High-performing architectures increase the attention paid to similar regions of the search space.

public void UpdateAttention(AttentiveNASConfig<T> config, T performance, T learningRate)

Parameters

config AttentiveNASConfig<T>
performance T
learningRate T
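The exact update rule behind UpdateAttention is not documented here. To illustrate the stated behavior (high-performing architectures attract more attention), the sketch below swaps in a standard REINFORCE-style update for softmax preferences. Each logit moves by learningRate * (performance - baseline) * (1[k = chosen] - p_k). The baseline value and the UpdateLogits helper are assumptions, not AiDotNet's API.

```csharp
using System;
using System.Linq;

double[] logits = { 0.0, 0.0, 0.0 }; // preferences for 3 hypothetical choices
double baseline = 0.60;               // hypothetical running-average performance

UpdateLogits(logits, chosen: 2, performance: 0.75, baselineValue: baseline, learningRate: 0.1);
Console.WriteLine(string.Join(", ", logits.Select(x => x.ToString("F4"))));

// REINFORCE-style update for a softmax policy:
// theta_k += lr * (performance - baseline) * (1[k == chosen] - p_k).
// Above-baseline performance raises the chosen option's probability; below-baseline lowers it.
static void UpdateLogits(double[] theta, int chosen, double performance, double baselineValue, double learningRate)
{
    double max = theta.Max();
    double[] e = theta.Select(x => Math.Exp(x - max)).ToArray(); // probabilities from the old logits
    double sum = e.Sum();
    double advantage = performance - baselineValue;
    for (int k = 0; k < theta.Length; k++)
    {
        double p = e[k] / sum;
        double indicator = k == chosen ? 1.0 : 0.0;
        theta[k] += learningRate * advantage * (indicator - p);
    }
}
```

Subtracting a baseline means that only better-than-average results increase the chosen option's probability. Because the terms (1[k = chosen] - p_k) sum to zero across options, the update redistributes preference instead of inflating every logit.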