
Class PCDARTS<T>

Namespace
AiDotNet.AutoML.NAS
Assembly
AiDotNet.dll

Implements Partial Channel Connections for Memory-Efficient Differentiable Architecture Search (PC-DARTS). PC-DARTS reduces memory consumption by passing only a sampled subset of channels through the mixed operations during the search while the remaining channels bypass them, which makes it more scalable to larger search spaces and datasets.

Reference: "PC-DARTS: Partial Channel Connections for Memory-Efficient Architecture Search" (ICLR 2020)
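The partial channel connection can be illustrated with a minimal sketch. It assumes a single feature vector with one scalar per channel, and it uses scalar functions as stand-ins for the candidate operations (real PC-DARTS operations are convolutions and pooling over feature maps). The class `PartialChannelSketch` and its methods are hypothetical and are not part of AiDotNet. Only the first `ratio × C` channels pass through the softmax-weighted mixture of operations; the remaining channels bypass it unchanged.

```csharp
using System;
using System.Linq;

// Hypothetical illustration of a partial-channel mixed operation.
// Not the AiDotNet implementation.
public static class PartialChannelSketch
{
    // Numerically stable softmax over the operation logits (alpha).
    public static double[] Softmax(double[] logits)
    {
        double max = logits.Max();
        double[] exps = logits.Select(l => Math.Exp(l - max)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }

    // x: one value per channel; ratio: fraction of channels sent through the
    // mixed operation; alpha: one logit per candidate operation;
    // ops: scalar stand-ins for the candidate operations (an assumption).
    public static double[] MixedOp(double[] x, double ratio, double[] alpha, Func<double, double>[] ops)
    {
        int sampled = Math.Max(1, (int)Math.Round(x.Length * ratio));
        double[] w = Softmax(alpha);
        double[] y = new double[x.Length];

        // Sampled channels: weighted sum of all candidate operations.
        for (int c = 0; c < sampled; c++)
            for (int o = 0; o < ops.Length; o++)
                y[c] += w[o] * ops[o](x[c]);

        // Remaining channels bypass the mixed operation unchanged.
        for (int c = sampled; c < x.Length; c++)
            y[c] = x[c];

        return y;
    }
}
```

Taking the first channels keeps the sketch simple. The PC-DARTS paper additionally shuffles channels after concatenating the two parts, so that different channels pass through the mixed operation over the course of the search.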

public class PCDARTS<T> : NasAutoMLModelBase<T>, IAutoMLModel<T, Tensor<T>, Tensor<T>>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>

Type Parameters

T

The numeric type used for calculations.

Inheritance
AutoMLModelBase<T, Tensor<T>, Tensor<T>>
NasAutoMLModelBase<T>
PCDARTS<T>
Implements
IAutoMLModel<T, Tensor<T>, Tensor<T>>
IFullModel<T, Tensor<T>, Tensor<T>>
IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>
IModelSerializer
ICheckpointableModel
IParameterizable<T, Tensor<T>, Tensor<T>>
IFeatureAware
IFeatureImportance<T>
ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>
IGradientComputable<T, Tensor<T>, Tensor<T>>
IJitCompilable<T>

Constructors

PCDARTS(SearchSpaceBase<T>, int, double, bool)

public PCDARTS(SearchSpaceBase<T> searchSpace, int numNodes = 4, double channelSamplingRatio = 0.25, bool useEdgeNormalization = true)

Parameters

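searchSpace SearchSpaceBase<T>

The NAS search space to search over.

numNodes int

The number of nodes to search over. Defaults to 4.

channelSamplingRatio double

The fraction of channels sampled for partial channel connections. Defaults to 0.25.

useEdgeNormalization bool

Whether to apply edge normalization to prevent operation collapse. Defaults to true.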

Properties

NasNumNodes

Gets the number of nodes to search over.

protected override int NasNumNodes { get; }

Property Value

int

NasSearchSpace

Gets the NAS search space.

protected override SearchSpaceBase<T> NasSearchSpace { get; }

Property Value

SearchSpaceBase<T>

NumOps

Gets the numeric operations provider for T.

protected override INumericOperations<T> NumOps { get; }

Property Value

INumericOperations<T>

Methods

ApplyEdgeNormalization(Matrix<T>)

Applies edge normalization to the architecture parameters to prevent operation collapse.

public Matrix<T> ApplyEdgeNormalization(Matrix<T> alpha)

Parameters

alpha Matrix<T>

Returns

Matrix<T>
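In the PC-DARTS paper, edge normalization gives each edge (i, j) a learnable weight β, and the input to node j is the softmax(β)-weighted sum of its incoming edges, which reduces the fluctuation that channel sampling introduces into edge selection. The sketch below uses a hypothetical `EdgeNormalizationSketch` class and plain arrays instead of `Matrix<T>`. It computes the combined weight of every operation on every incoming edge of one node. The layout of the `alpha` matrix that `ApplyEdgeNormalization` accepts is not documented here, so the layout used below (one row per edge, one column per operation) is an assumption.

```csharp
using System;
using System.Linq;

// Hypothetical illustration of PC-DARTS-style edge normalization for one node.
public static class EdgeNormalizationSketch
{
    // Numerically stable softmax.
    static double[] Softmax(double[] v)
    {
        double max = v.Max();
        double[] e = v.Select(x => Math.Exp(x - max)).ToArray();
        double s = e.Sum();
        return e.Select(x => x / s).ToArray();
    }

    // alpha[e][o]: logit of operation o on incoming edge e (assumed layout).
    // beta[e]:     logit of incoming edge e.
    // Returns softmax(beta)[e] * softmax(alpha[e])[o] for every edge and operation.
    public static double[][] CombinedWeights(double[][] alpha, double[] beta)
    {
        double[] edgeWeights = Softmax(beta);
        return alpha
            .Select((row, e) => Softmax(row).Select(w => edgeWeights[e] * w).ToArray())
            .ToArray();
    }
}
```

Because both factors are softmax distributions, the combined weights for one node sum to 1. With equal operation logits and edge logits of 0 and ln 3, the edges receive weights 0.25 and 0.75, split evenly across their operations.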

CreateInstanceForCopy()

Factory method for creating a new instance for deep copy. Derived classes must implement this to return a new instance of themselves. This ensures each copy has its own collections and lock object.

protected override AutoMLModelBase<T, Tensor<T>, Tensor<T>> CreateInstanceForCopy()

Returns

AutoMLModelBase<T, Tensor<T>, Tensor<T>>

A fresh instance of the derived class with default parameters

Remarks

When implementing this method, derived classes should create a fresh instance with default parameters, and should not attempt to preserve runtime or initialization state from the original instance. The deep copy logic will transfer relevant state (trial history, search space, etc.) after construction.
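The factory-method pattern described above can be sketched with hypothetical types (`CopyableModelBase`, `SketchModel`, `TrialHistory`, `DeepCopy`). These are not the AiDotNet classes. The derived class only constructs a fresh default instance. The base class then copies state such as trial history into it, so the copy never shares collections or its lock object with the original.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical base class illustrating the factory-method deep-copy pattern.
public abstract class CopyableModelBase
{
    // Each instance gets its own lock object at construction time.
    protected readonly object SyncRoot = new object();

    public List<string> TrialHistory { get; private set; } = new List<string>();

    // Derived classes return a fresh, default-constructed instance of themselves.
    protected abstract CopyableModelBase CreateInstanceForCopy();

    public CopyableModelBase DeepCopy()
    {
        // Fresh instance: new collections and a new lock object.
        var copy = CreateInstanceForCopy();
        lock (SyncRoot)
        {
            // State is transferred after construction, into a new list.
            copy.TrialHistory = new List<string>(TrialHistory);
        }
        return copy;
    }
}

public sealed class SketchModel : CopyableModelBase
{
    protected override CopyableModelBase CreateInstanceForCopy() => new SketchModel();
}
```

Keeping the constructor call in the derived class and the state transfer in the base class means each new model type only has to implement a one-line factory.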

DeriveArchitecture()

Derives the discrete architecture from the learned architecture parameters.

public Architecture<T> DeriveArchitecture()

Returns

Architecture<T>
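A common way to derive a discrete cell, following the DARTS convention, is to pick the strongest non-zero operation on each incoming edge and then keep the two strongest edges per node. This is a sketch under stated assumptions. The hypothetical `DeriveSketch.Derive` takes already-normalized weights indexed as `[node][edge][op]` and treats operation index 0 as the "none" operation. The actual selection rule in `DeriveArchitecture` may differ.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical DARTS-style discretization; not the AiDotNet implementation.
public static class DeriveSketch
{
    // weights[node][edge][op]: normalized operation weights (assumed layout).
    // Operation index 0 is assumed to be "none" and is never selected,
    // so every edge needs at least two operations.
    public static List<(int Node, int Edge, int Op)> Derive(double[][][] weights, int edgesPerNode = 2)
    {
        var result = new List<(int Node, int Edge, int Op)>();
        for (int n = 0; n < weights.Length; n++)
        {
            var best = weights[n]
                .Select((ops, e) =>
                {
                    // Strongest operation on this edge, excluding "none".
                    int op = Enumerable.Range(1, ops.Length - 1).OrderByDescending(o => ops[o]).First();
                    return (Edge: e, Op: op, Score: ops[op]);
                })
                // Keep the edges whose best operation is strongest.
                .OrderByDescending(t => t.Score)
                .Take(edgesPerNode);

            foreach (var b in best)
                result.Add((n, b.Edge, b.Op));
        }
        return result;
    }
}
```

Excluding the "none" operation matters because its weight can dominate during the search even though a zero operation contributes nothing to the final cell.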

GetArchitectureGradients()

Gets the gradients of the architecture parameters.

public List<Matrix<T>> GetArchitectureGradients()

Returns

List<Matrix<T>>

GetArchitectureParameters()

Gets the architecture parameters for optimization.

public List<Matrix<T>> GetArchitectureParameters()

Returns

List<Matrix<T>>
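The parameters and gradients are both returned as lists of matrices, which suggests a simple update loop over matching entries. The sketch below applies one plain gradient-descent step element-wise. It assumes that each gradient matrix has the same shape as the parameter matrix at the same index, and it uses `double[,]` in place of `Matrix<T>`. The `ArchStepSketch.Step` name is hypothetical. The original DARTS work uses Adam rather than plain gradient descent for the architecture parameters.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical update step for architecture parameters.
public static class ArchStepSketch
{
    // alpha <- alpha - learningRate * grad, element-wise, in place.
    public static void Step(List<double[,]> parameters, List<double[,]> gradients, double learningRate)
    {
        if (parameters.Count != gradients.Count)
            throw new ArgumentException("Parameter and gradient lists must have the same length.");

        for (int k = 0; k < parameters.Count; k++)
        {
            double[,] p = parameters[k];
            double[,] g = gradients[k];
            for (int i = 0; i < p.GetLength(0); i++)
                for (int j = 0; j < p.GetLength(1); j++)
                    p[i, j] -= learningRate * g[i, j];
        }
    }
}
```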

GetChannelSamplingRatio()

Gets the channel sampling ratio.

public double GetChannelSamplingRatio()

Returns

double

GetMemorySavingsRatio()

Gets the memory savings ratio compared to standard DARTS.

public double GetMemorySavingsRatio()

Returns

double
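With a sampling ratio of 0.25 (the constructor default, corresponding to K = 4 in the paper), only a quarter of the channels on each edge go through the mixed operation. The estimate below assumes that mixed-operation activation memory scales linearly with the number of sampled channels and that the ratio lies in (0, 1]. It ignores bypassed channels, network weights, and other buffers, so it is not necessarily how `GetMemorySavingsRatio` computes its value. `MemoryEstimateSketch` is a hypothetical name.

```csharp
using System;

// Hypothetical back-of-the-envelope memory estimate for partial channel connections.
public static class MemoryEstimateSketch
{
    // Number of channels routed through the mixed operation (at least one).
    public static int SampledChannels(int totalChannels, double ratio) =>
        Math.Max(1, (int)Math.Round(totalChannels * ratio));

    // Fraction of mixed-operation activation memory saved relative to sending
    // every channel through every candidate operation, as standard DARTS does.
    public static double EstimatedSavings(int totalChannels, double ratio) =>
        1.0 - (double)SampledChannels(totalChannels, ratio) / totalChannels;
}
```

For 16 channels at ratio 0.25, 4 channels are sampled and the estimated savings are 0.75.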

SampleChannels(int)

Samples a subset of channels for partial channel connections.

public List<int> SampleChannels(int totalChannels)

Parameters

totalChannels int

Returns

List<int>
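A minimal sketch of uniform channel sampling follows. It assumes the number of sampled channels is the sampling ratio times `totalChannels`, rounded and clamped to [1, totalChannels]. It uses a partial Fisher-Yates shuffle, so every subset of that size is equally likely, and it returns the indices in ascending order. The explicit `ratio` and `Random` parameters are additions for illustration. The documented `SampleChannels(int)` takes only the channel count and presumably reads the ratio from the instance.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical uniform channel sampler; not the AiDotNet implementation.
public static class ChannelSamplerSketch
{
    public static List<int> SampleChannels(int totalChannels, double ratio, Random rng)
    {
        if (totalChannels <= 0)
            throw new ArgumentOutOfRangeException(nameof(totalChannels));

        // Number of channels to keep, clamped to [1, totalChannels].
        int k = Math.Min(totalChannels, Math.Max(1, (int)Math.Round(totalChannels * ratio)));
        int[] idx = Enumerable.Range(0, totalChannels).ToArray();

        // Partial Fisher-Yates: after k swaps, idx[0..k) is a uniform random subset.
        for (int i = 0; i < k; i++)
        {
            int j = rng.Next(i, totalChannels); // j in [i, totalChannels)
            (idx[i], idx[j]) = (idx[j], idx[i]);
        }

        return idx.Take(k).OrderBy(c => c).ToList();
    }
}
```

Passing a seeded `Random` makes the sampled channels reproducible across runs, which helps when debugging a search.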

SearchArchitecture(Tensor<T>, Tensor<T>, Tensor<T>, Tensor<T>, TimeSpan, CancellationToken)

Performs algorithm-specific architecture search.

protected override Architecture<T> SearchArchitecture(Tensor<T> inputs, Tensor<T> targets, Tensor<T> validationInputs, Tensor<T> validationTargets, TimeSpan timeLimit, CancellationToken cancellationToken)

Parameters

inputs Tensor<T>
targets Tensor<T>
validationInputs Tensor<T>
validationTargets Tensor<T>
timeLimit TimeSpan
cancellationToken CancellationToken

Returns

Architecture<T>
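The control flow implied by the `timeLimit` and `cancellationToken` parameters can be sketched as an alternating loop. Each iteration takes one architecture-parameter step on the validation split and then one weight step on the training split, and the loop repeats until a step budget, the time limit, or cancellation stops it. This follows the first-order DARTS alternation. The delegates, the `maxSteps` budget, and the `SearchLoopSketch` name are hypothetical. The real method may throw `OperationCanceledException` on cancellation instead of stopping quietly.

```csharp
using System;
using System.Diagnostics;
using System.Threading;

// Hypothetical search loop; not the AiDotNet implementation.
public static class SearchLoopSketch
{
    // archStep:   one update of the architecture parameters on validation data.
    // weightStep: one update of the network weights on training data.
    // Returns the number of completed iterations.
    public static int Run(Action archStep, Action weightStep, int maxSteps, TimeSpan timeLimit, CancellationToken token)
    {
        var clock = Stopwatch.StartNew();
        int steps = 0;
        while (steps < maxSteps && clock.Elapsed < timeLimit && !token.IsCancellationRequested)
        {
            archStep();
            weightStep();
            steps++;
        }
        return steps;
    }
}
```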