Class FBNet<T>
FBNet: Hardware-Aware Efficient ConvNet Design via Differentiable Neural Architecture Search. Uses Gumbel-Softmax with hardware latency constraints to find efficient architectures optimized for specific target devices.
Reference: "FBNet: Hardware-Aware Efficient ConvNet Design via Differentiable Neural Architecture Search" (CVPR 2019)
public class FBNet<T> : NasAutoMLModelBase<T>, IAutoMLModel<T, Tensor<T>, Tensor<T>>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>
Type Parameters
T: The numeric type used for calculations.
- Inheritance
  NasAutoMLModelBase<T> → FBNet<T>
- Inherited Members
  AutoMLModelBase<T, Tensor<T>, Tensor<T>>.SetModelEvaluator(IModelEvaluator<T, Tensor<T>, Tensor<T>>)
Constructors
FBNet(SearchSpaceBase<T>, int, HardwarePlatform, double, double, int, int)
public FBNet(SearchSpaceBase<T> searchSpace, int numLayers = 20, HardwarePlatform targetPlatform = HardwarePlatform.Mobile, double latencyWeight = 0.2, double initialTemperature = 5, int inputChannels = 16, int spatialSize = 224)
Parameters
- searchSpace: SearchSpaceBase<T>
- numLayers: int
- targetPlatform: HardwarePlatform
- latencyWeight: double
- initialTemperature: double
- inputChannels: int
- spatialSize: int
Properties
NasNumNodes
Gets the number of nodes to search over.
protected override int NasNumNodes { get; }
Property Value
- int
NasSearchSpace
Gets the NAS search space.
protected override SearchSpaceBase<T> NasSearchSpace { get; }
Property Value
- SearchSpaceBase<T>
NumOps
Gets the numeric operations provider for T.
protected override INumericOperations<T> NumOps { get; }
Property Value
- INumericOperations<T>
Methods
AnnealTemperature(int, int)
Anneals the Gumbel-Softmax temperature during training.
public void AnnealTemperature(int currentEpoch, int maxEpochs)
Parameters
- currentEpoch: int
- maxEpochs: int
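AnnealTemperature takes both the current epoch and the total epoch count, which suggests a schedule that moves from initialTemperature toward a smaller value over the search; the exact schedule is not documented on this page. The FBNet paper decays the temperature exponentially from epoch to epoch. The sketch below shows one hypothetical schedule that uses both arguments: a geometric interpolation toward an assumed finalTemperature, which is not a constructor parameter of FBNet<T>.

```csharp
using System;

// Hypothetical helper illustrating one possible annealing schedule;
// it is not the library's AnnealTemperature implementation.
public static class TemperatureSchedule
{
    // Geometric interpolation: tau(e) = tau0 * (tauMin / tau0)^(e / E).
    // finalTemperature (tauMin) is an assumed parameter, not part of the FBNet<T> API.
    public static double Anneal(double initialTemperature, double finalTemperature,
                                int currentEpoch, int maxEpochs)
    {
        if (maxEpochs <= 0)
            throw new ArgumentOutOfRangeException(nameof(maxEpochs));
        if (initialTemperature <= 0 || finalTemperature <= 0)
            throw new ArgumentOutOfRangeException(nameof(initialTemperature), "Temperatures must be positive.");

        // Clamp progress to [0, 1] so epochs past maxEpochs stay at finalTemperature.
        double progress = Math.Clamp((double)currentEpoch / maxEpochs, 0.0, 1.0);
        return initialTemperature * Math.Pow(finalTemperature / initialTemperature, progress);
    }
}
```

Lower temperatures push Gumbel-Softmax samples toward one-hot vectors, so annealing gradually turns the soft mixture of operations into a near-discrete choice.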
ComputeExpectedLatency()
Computes the expected latency cost of the entire architecture under the current architecture distribution.
public T ComputeExpectedLatency()
Returns
- T
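In the FBNet paper, the expected latency of the architecture is the sum, over layers, of each candidate operation's latency (taken from a per-device lookup table) weighted by that operation's Gumbel-Softmax sampling weight, which keeps the latency term differentiable with respect to θ. The sketch below illustrates that computation with plain arrays; the array layout, the class name, and the use of double are assumptions, not the internals of ComputeExpectedLatency.

```csharp
using System;

// Hypothetical illustration of an expected-latency computation; array layouts
// and names are assumptions, not the FBNet<T> internals.
public static class ExpectedLatency
{
    // opWeights[l][i]:    sampling weight of candidate op i in layer l (each row sums to 1).
    // latencyTable[l][i]: measured or estimated latency of op i in layer l (e.g., milliseconds).
    // Returns the sum over layers of each layer's weighted-average candidate latency.
    public static double Compute(double[][] opWeights, double[][] latencyTable)
    {
        if (opWeights.Length != latencyTable.Length)
            throw new ArgumentException("Layer counts differ.");

        double total = 0.0;
        for (int l = 0; l < opWeights.Length; l++)
        {
            if (opWeights[l].Length != latencyTable[l].Length)
                throw new ArgumentException($"Candidate counts differ in layer {l}.");
            for (int i = 0; i < opWeights[l].Length; i++)
                total += opWeights[l][i] * latencyTable[l][i];
        }
        return total;
    }
}
```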
ComputeTotalLoss(T)
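Computes the total loss with latency regularization: Loss = Cross-Entropy + λ · log(Latency). Using log(Latency) rather than raw latency makes the loss more sensitive to latency changes when latency is small, because the derivative of λ · log(Latency) with respect to Latency is λ / Latency.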
public T ComputeTotalLoss(T taskLoss)
Parameters
- taskLoss: T, the task loss (the cross-entropy term in the formula above)
Returns
- T
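The following sketch evaluates the loss formula above and its latency derivative. It assumes λ corresponds to the constructor's latencyWeight argument (default 0.2) and that the expected latency is positive; the class and method names are hypothetical.

```csharp
using System;

// Hypothetical illustration of Loss = CE + lambda * log(Latency) as described above.
public static class LatencyAwareLoss
{
    public static double Total(double taskLoss, double expectedLatency, double latencyWeight)
    {
        if (expectedLatency <= 0)
            throw new ArgumentOutOfRangeException(nameof(expectedLatency), "log requires a positive latency.");
        return taskLoss + latencyWeight * Math.Log(expectedLatency);
    }

    // d/dLatency of lambda * log(Latency) = lambda / Latency:
    // the same absolute latency change moves the loss more when latency is small.
    public static double LatencyGradient(double expectedLatency, double latencyWeight)
        => latencyWeight / expectedLatency;
}
```

For example, with λ = 0.2 the penalty's gradient is 0.1 at a latency of 2 but only 0.01 at a latency of 20.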
CreateInstanceForCopy()
Factory method that creates a new instance for deep copying. Derived classes must implement this to return a new instance of themselves. This ensures each copy has its own collections and lock object.
protected override AutoMLModelBase<T, Tensor<T>, Tensor<T>> CreateInstanceForCopy()
Returns
- AutoMLModelBase<T, Tensor<T>, Tensor<T>>
A fresh instance of the derived class with default parameters.
Remarks
When implementing this method, derived classes should create a fresh instance with default parameters, and should not attempt to preserve runtime or initialization state from the original instance. The deep copy logic will transfer relevant state (trial history, search space, etc.) after construction.
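The contract above is a factory-method pattern: the derived class only constructs a fresh, default instance, and the base class copies state into it afterwards, so the copy never shares collections or locks with the original. A minimal sketch of that pattern, using hypothetical stand-in types rather than AutoMLModelBase<T, Tensor<T>, Tensor<T>>:

```csharp
using System;
using System.Collections.Generic;

// Hypothetical, simplified stand-ins illustrating the factory-based deep copy
// pattern described above; these are not the AutoMLModelBase<T, ...> types.
public abstract class ModelBaseSketch
{
    // Each instance owns its own collection and lock object.
    protected readonly List<string> TrialHistory = new List<string>();
    protected readonly object SyncRoot = new object();

    // Derived classes return a fresh instance with default parameters.
    protected abstract ModelBaseSketch CreateInstanceForCopy();

    public ModelBaseSketch DeepCopy()
    {
        ModelBaseSketch copy = CreateInstanceForCopy();
        lock (SyncRoot)
        {
            // The base class, not the factory, transfers state after construction.
            copy.TrialHistory.AddRange(TrialHistory);
        }
        return copy;
    }

    public void RecordTrial(string trial) { lock (SyncRoot) TrialHistory.Add(trial); }

    public int TrialCount { get { lock (SyncRoot) return TrialHistory.Count; } }
}

public sealed class FBNetSketch : ModelBaseSketch
{
    // Only construction lives here; no runtime state is carried over.
    protected override ModelBaseSketch CreateInstanceForCopy() => new FBNetSketch();
}
```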
DeriveArchitecture()
Derives the discrete architecture by selecting, for each layer, the operation with the highest probability.
public Architecture<T> DeriveArchitecture()
Returns
- Architecture<T>
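Because softmax is monotonic, the operation with the highest probability in a layer is also the one with the largest logit in θ, so deriving the architecture reduces to a per-layer argmax. The sketch below assumes θ is stored as one logit array per layer and returns operation names instead of an Architecture<T>; the operation names and the ArchitectureDerivation class are hypothetical.

```csharp
using System;

// Hypothetical illustration: pick, per layer, the candidate op with the largest
// architecture logit. Op names and the string[] result are assumptions;
// the real method returns Architecture<T>.
public static class ArchitectureDerivation
{
    public static string[] Derive(double[][] theta, string[] candidateOps)
    {
        var chosen = new string[theta.Length];
        for (int l = 0; l < theta.Length; l++)
        {
            if (theta[l].Length != candidateOps.Length)
                throw new ArgumentException($"Layer {l} has the wrong number of logits.");

            // Argmax over logits; ties resolve to the lowest index.
            int best = 0;
            for (int i = 1; i < theta[l].Length; i++)
                if (theta[l][i] > theta[l][best]) best = i;
            chosen[l] = candidateOps[best];
        }
        return chosen;
    }
}
```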
GetArchitectureCost()
Gets the hardware cost breakdown of the final architecture.
public HardwareCost<T> GetArchitectureCost()
Returns
- HardwareCost<T>
GetArchitectureGradients()
Gets the gradients with respect to the architecture parameters.
public List<Vector<T>> GetArchitectureGradients()
Returns
- List<Vector<T>>
GetArchitectureParameters()
Gets the layer-wise architecture parameters (θ) for optimization.
public List<Vector<T>> GetArchitectureParameters()
Returns
- List<Vector<T>>
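GetArchitectureParameters and GetArchitectureGradients both return List<Vector<T>>, which suggests they can be paired for an optimizer step. Assuming the two lists are aligned element by element (not confirmed by this page), one plain gradient-descent update looks like the sketch below; double[] stands in for Vector<T>, and the learning rate is a hypothetical hyperparameter. The search itself may use a different optimizer.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical illustration of one gradient-descent step on the architecture
// parameters, theta <- theta - learningRate * grad, applied layer by layer.
// double[] stands in for Vector<T>; learningRate is an assumed hyperparameter.
public static class ArchitectureUpdate
{
    public static void SgdStep(List<double[]> theta, List<double[]> gradients, double learningRate)
    {
        if (theta.Count != gradients.Count)
            throw new ArgumentException("Parameter and gradient lists must have the same length.");

        for (int l = 0; l < theta.Count; l++)
        {
            if (theta[l].Length != gradients[l].Length)
                throw new ArgumentException($"Shape mismatch in layer {l}.");
            for (int i = 0; i < theta[l].Length; i++)
                theta[l][i] -= learningRate * gradients[l][i];   // in-place update
        }
    }
}
```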
GetTemperature()
Gets the current Gumbel-Softmax temperature.
public T GetTemperature()
Returns
- T
GumbelSoftmax(Vector<T>, bool)
Applies Gumbel-Softmax to one layer's architecture parameters θ.
public Vector<T> GumbelSoftmax(Vector<T> theta, bool hard = false)
Parameters
- theta: Vector<T>
- hard: bool
Returns
- Vector<T>
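Gumbel-Softmax draws a differentiable, approximately one-hot sample from the categorical distribution defined by θ: add independent Gumbel(0, 1) noise to each logit, divide by the temperature τ, and apply softmax. The sketch below shows that forward computation. Treating hard = true as returning a one-hot vector at the argmax is an assumption based on the common straight-through variant, which in a training framework would also route gradients through the soft sample (not shown here). The class name and the temperature and rng parameters are hypothetical.

```csharp
using System;

// Hypothetical illustration of Gumbel-Softmax sampling over one layer's logits theta.
// Not the library's GumbelSoftmax implementation; Random and double[] are stand-ins.
public static class GumbelSoftmaxSketch
{
    public static double[] Sample(double[] theta, double temperature, Random rng, bool hard = false)
    {
        if (theta.Length == 0) throw new ArgumentException("theta must not be empty.");
        if (temperature <= 0) throw new ArgumentOutOfRangeException(nameof(temperature));
        int n = theta.Length;
        var y = new double[n];

        // Perturb each logit with Gumbel(0, 1) noise g = -log(-log(u)), u ~ Uniform[0, 1).
        for (int i = 0; i < n; i++)
        {
            double u = Math.Max(rng.NextDouble(), 1e-12);   // avoid log(0)
            y[i] = (theta[i] - Math.Log(-Math.Log(u))) / temperature;
        }

        // Numerically stable softmax: subtract the max before exponentiating.
        double max = double.NegativeInfinity;
        foreach (double v in y) max = Math.Max(max, v);
        double sum = 0.0;
        for (int i = 0; i < n; i++) { y[i] = Math.Exp(y[i] - max); sum += y[i]; }
        for (int i = 0; i < n; i++) y[i] /= sum;

        if (!hard) return y;

        // Hard variant (assumed semantics): one-hot vector at the argmax of the soft sample.
        int best = 0;
        for (int i = 1; i < n; i++) if (y[i] > y[best]) best = i;
        var oneHot = new double[n];
        oneHot[best] = 1.0;
        return oneHot;
    }
}
```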
MeetsConstraints()
Checks whether the derived architecture meets the hardware constraints.
public bool MeetsConstraints()
Returns
- bool
SearchArchitecture(Tensor<T>, Tensor<T>, Tensor<T>, Tensor<T>, TimeSpan, CancellationToken)
Performs algorithm-specific architecture search.
protected override Architecture<T> SearchArchitecture(Tensor<T> inputs, Tensor<T> targets, Tensor<T> validationInputs, Tensor<T> validationTargets, TimeSpan timeLimit, CancellationToken cancellationToken)
Parameters
- inputs: Tensor<T>
- targets: Tensor<T>
- validationInputs: Tensor<T>
- validationTargets: Tensor<T>
- timeLimit: TimeSpan
- cancellationToken: CancellationToken
Returns
- Architecture<T>
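SearchArchitecture receives a time limit and a cancellation token in addition to the training and validation data. A minimal sketch of how a search loop can honor both, with a hypothetical step delegate standing in for one epoch of search (for example, updating operation weights on the training split and θ on the validation split, as differentiable NAS methods typically do). Whether the real method throws on cancellation or returns the best architecture found so far is not documented here; this sketch throws OperationCanceledException.

```csharp
using System;
using System.Diagnostics;
using System.Threading;

// Hypothetical skeleton of a time-bounded, cancellable search loop, in the spirit of
// SearchArchitecture(..., timeLimit, cancellationToken). The step delegate is an assumption.
public static class SearchLoopSketch
{
    // Runs step(epoch) until maxEpochs is reached or the time limit expires;
    // throws OperationCanceledException if cancellation is requested.
    // Returns the number of completed epochs.
    public static int Run(Action<int> step, int maxEpochs, TimeSpan timeLimit, CancellationToken cancellationToken)
    {
        var clock = Stopwatch.StartNew();
        int epoch = 0;
        while (epoch < maxEpochs && clock.Elapsed < timeLimit)
        {
            cancellationToken.ThrowIfCancellationRequested();
            step(epoch);   // e.g., update weights, update theta, anneal temperature
            epoch++;
        }
        return epoch;
    }
}
```

Checking the clock before each epoch means an epoch that starts just before the limit still runs to completion; a finer-grained check inside the step would bound the overrun more tightly.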
SetConstraints(HardwareConstraints<T>)
Sets the hardware constraints for the search.
public void SetConstraints(HardwareConstraints<T> constraints)
Parameters
- constraints: HardwareConstraints<T>