Class ResNetNetwork<T>

Namespace
AiDotNet.NeuralNetworks
Assembly
AiDotNet.dll

Represents a ResNet (Residual Network) neural network architecture for image classification.

public class ResNetNetwork<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable

Type Parameters

T

The numeric type used for calculations (typically float or double).

Inheritance
ResNetNetwork<T>
Implements
IFullModel<T, Tensor<T>, Tensor<T>>
IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>
IParameterizable<T, Tensor<T>, Tensor<T>>
ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>
IGradientComputable<T, Tensor<T>, Tensor<T>>

Remarks

ResNet networks are deep convolutional neural networks that introduced skip connections (residual connections) to enable training of very deep networks. They learn residual functions with reference to the layer inputs, rather than learning unreferenced functions directly.

For Beginners: ResNet networks revolutionized deep learning by making very deep networks practical to train. Their skip connections mitigate the "vanishing gradient" and degradation problems that previously limited network depth. Key benefits include:

  • Can train networks with 100+ layers (compared to ~20 layers for earlier architectures)
  • Skip connections allow gradients to flow more easily during training
  • Each block learns the "residual" (difference) rather than the complete transformation
  • Won the ImageNet (ILSVRC) 2015 classification competition with a top-5 error of 3.57%
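
The residual idea in the list above can be sketched in a few lines. This is a minimal illustration on plain arrays, not AiDotNet code: `HalveAll` is a hypothetical stand-in for the convolutional layers F(x) inside a block, and `ResidualBlock` is a helper name introduced here only for the example.

```csharp
using System;
using System.Linq;

// Hypothetical stand-in for the stacked layers F(x) inside one block.
double[] HalveAll(double[] x) => x.Select(v => 0.5 * v).ToArray();

// Residual block: y = F(x) + x. The "+ x" term is the skip connection.
double[] ResidualBlock(double[] x, Func<double[], double[]> f)
{
    double[] fx = f(x);
    if (fx.Length != x.Length)
        throw new ArgumentException("F(x) must match the shape of x for an identity skip.");
    return x.Zip(fx, (xi, fi) => xi + fi).ToArray();
}

double[] input = { 1.0, -2.0, 3.0 };

// With a non-trivial F, the block outputs the input plus a learned correction.
double[] output = ResidualBlock(input, HalveAll);

// If F outputs zeros, the block reduces to the identity mapping.
double[] identity = ResidualBlock(input, x => new double[x.Length]);

Console.WriteLine(string.Join(", ", output));
Console.WriteLine(string.Join(", ", identity));
```

Because the block only has to learn the correction F(x), a block that has nothing useful to add can fall back to passing its input through unchanged, which is part of why adding more blocks does not have to make the network worse.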

Architecture Variants:

  • ResNet18/34: Use BasicBlock (2 conv layers per block)
  • ResNet50/101/152: Use BottleneckBlock (1x1-3x3-1x1 conv pattern) for efficiency
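
The variant names follow from the block type and the number of blocks per stage. The sketch below reproduces that arithmetic; the per-stage block counts are taken from the original ResNet paper and are an assumption about how AiDotNet configures each variant, and `WeightedLayerCount` is a hypothetical helper, not a library method.

```csharp
using System;
using System.Linq;

// Depth = (blocks across the four stages) x (conv layers per block)
//         + 1 initial convolution + 1 fully connected layer.
int WeightedLayerCount(int[] blocksPerStage, bool bottleneck)
{
    int convsPerBlock = bottleneck ? 3 : 2; // BottleneckBlock: 1x1-3x3-1x1, BasicBlock: 3x3-3x3
    return blocksPerStage.Sum() * convsPerBlock + 2;
}

// Block counts per stage as published in the original ResNet paper (assumed here).
var variants = new (string Name, int[] BlocksPerStage, bool Bottleneck)[]
{
    ("ResNet18",  new[] { 2, 2, 2, 2 },  false),
    ("ResNet34",  new[] { 3, 4, 6, 3 },  false),
    ("ResNet50",  new[] { 3, 4, 6, 3 },  true),
    ("ResNet101", new[] { 3, 4, 23, 3 }, true),
    ("ResNet152", new[] { 3, 8, 36, 3 }, true),
};

foreach (var v in variants)
    Console.WriteLine($"{v.Name}: {WeightedLayerCount(v.BlocksPerStage, v.Bottleneck)} weighted layers");
```

Note that ResNet34 and ResNet50 share the same block counts; the extra depth in ResNet50 comes entirely from the three-layer bottleneck block.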

Typical Usage:

// Create ResNet50 for 1000-class classification
var config = new ResNetConfiguration(ResNetVariant.ResNet50, numClasses: 1000);
var architecture = new NeuralNetworkArchitecture<float>(
    inputType: InputType.ThreeDimensional,
    inputHeight: 224,
    inputWidth: 224,
    inputDepth: 3,
    outputSize: 1000,
    taskType: NeuralNetworkTaskType.MultiClassClassification);
var network = new ResNetNetwork<float>(architecture, config);

Constructors

ResNetNetwork(NeuralNetworkArchitecture<T>, ResNetConfiguration, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, ILossFunction<T>?, double)

Initializes a new instance of the ResNetNetwork class.

public ResNetNetwork(NeuralNetworkArchitecture<T> architecture, ResNetConfiguration configuration, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? optimizer = null, ILossFunction<T>? lossFunction = null, double maxGradNorm = 1)

Parameters

architecture NeuralNetworkArchitecture<T>

The architecture defining the structure of the neural network.

configuration ResNetConfiguration

The ResNet-specific configuration.

optimizer IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>

Optional optimizer for training (default: Adam).

lossFunction ILossFunction<T>

Optional loss function (default: based on task type).

maxGradNorm double

Maximum gradient norm for gradient clipping (default: 1.0).
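
To illustrate what a maximum gradient norm means, the sketch below rescales a gradient vector whenever its L2 norm exceeds the limit. It is a generic illustration on plain arrays; whether ResNetNetwork clips by global L2 norm in exactly this way is an assumption, and `ClipByNorm` is a hypothetical helper.

```csharp
using System;
using System.Linq;

// Rescales the gradient so its L2 norm never exceeds maxNorm.
// Gradients already within the limit are returned unchanged (as a copy).
double[] ClipByNorm(double[] gradient, double maxNorm)
{
    double norm = Math.Sqrt(gradient.Sum(g => g * g));
    if (norm <= maxNorm || norm == 0.0)
        return (double[])gradient.Clone();

    double scale = maxNorm / norm;
    return gradient.Select(g => g * scale).ToArray();
}

double[] large = ClipByNorm(new[] { 3.0, 4.0 }, maxNorm: 1.0); // norm 5 is scaled down to norm 1
double[] small = ClipByNorm(new[] { 0.3, 0.4 }, maxNorm: 1.0); // norm 0.5 is left as is

Console.WriteLine(string.Join(", ", large));
Console.WriteLine(string.Join(", ", small));
```

Clipping keeps the direction of the update and only limits its length, which helps prevent a single unusually large gradient from destabilizing training.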

Remarks

ResNet networks require three-dimensional input data (channels, height, width).

For Beginners: When creating a ResNet network, you need to provide:

  • An architecture object that describes the input/output dimensions
  • A configuration that specifies which ResNet variant to use
  • Optionally, custom optimizer and loss function (good defaults are provided)

Exceptions

InvalidInputTypeException

Thrown when the input type is not three-dimensional.

ArgumentNullException

Thrown when configuration is null.

Properties

NumClasses

Gets the number of output classes for classification.

public int NumClasses { get; }

Property Value

int

UsesBottleneck

Gets whether this variant uses bottleneck blocks.

public bool UsesBottleneck { get; }

Property Value

bool

Remarks

For Beginners: ResNet50 and deeper variants use bottleneck blocks (1x1-3x3-1x1 convolutions), which are more parameter-efficient than the basic blocks used in ResNet18/34.
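
The efficiency difference can be checked with simple arithmetic. The sketch below counts convolution weights only (no biases, batch normalization, or shortcut projections) for a basic block and a bottleneck block that both produce 256 output channels; the 64-channel inner width follows the original ResNet design and is an assumption about this library. `ConvWeights` is a hypothetical helper.

```csharp
using System;

// Weights in one square convolution: kernel x kernel x inChannels x outChannels.
long ConvWeights(int kernelSize, int inChannels, int outChannels)
    => (long)kernelSize * kernelSize * inChannels * outChannels;

// BasicBlock: two 3x3 convolutions at full width.
long basicBlock = 2 * ConvWeights(3, 256, 256);

// BottleneckBlock: 1x1 reduces to 64 channels, 3x3 works at 64, 1x1 expands back to 256.
long bottleneckBlock = ConvWeights(1, 256, 64)
                     + ConvWeights(3, 64, 64)
                     + ConvWeights(1, 64, 256);

Console.WriteLine($"Basic block at 256 channels:      {basicBlock} weights");
Console.WriteLine($"Bottleneck block at 256 channels: {bottleneckBlock} weights");
```

Under these assumptions the basic block needs 1,179,648 weights and the bottleneck block 69,632, because the expensive 3x3 convolution runs on a quarter of the channels.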

Variant

Gets the ResNet variant being used.

public ResNetVariant Variant { get; }

Property Value

ResNetVariant

Remarks

For Beginners: The variant determines how deep the network is (ResNet18, 34, 50, 101, or 152) and which block type is used (BasicBlock for 18/34, BottleneckBlock for 50/101/152).

Methods

Backward(Tensor<T>)

Performs a backward pass through the network to calculate gradients.

public Tensor<T> Backward(Tensor<T> outputGradient)

Parameters

outputGradient Tensor<T>

The gradient of the loss with respect to the network's output.

Returns

Tensor<T>

The gradient of the loss with respect to the network's input.

CreateNewInstance()

Creates a new instance of the ResNet network model.

protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()

Returns

IFullModel<T, Tensor<T>, Tensor<T>>

DeserializeNetworkSpecificData(BinaryReader)

Deserializes ResNet network-specific data from a binary reader.

protected override void DeserializeNetworkSpecificData(BinaryReader reader)

Parameters

reader BinaryReader

ForTesting(int, int)

Creates a minimal ResNet network optimized for fast test execution.

public static ResNetNetwork<T> ForTesting(int numClasses = 10, int inputChannels = 3)

Parameters

numClasses int

The number of output classes.

inputChannels int

The number of input channels (default: 3 for RGB).

Returns

ResNetNetwork<T>

A minimal ResNet network for testing.

Remarks

Uses ResNet18 (the smallest variant) with a 32x32 input resolution, so the network has fewer layers than the deeper variants and much smaller feature maps than a 224x224 model. Construction time is typically under 50ms.
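
A unit test might use the factory as sketched below. The example relies only on members documented on this page (ForTesting, NumClasses, Variant, UsesBottleneck, GetParameterCount, and IDisposable) and assumes a project reference to AiDotNet; the printed values depend on the library version, so none are shown.

```csharp
using System;
using AiDotNet.NeuralNetworks;

// Build the minimal network; 'using var' disposes it at the end of the scope.
using var network = ResNetNetwork<float>.ForTesting(numClasses: 10, inputChannels: 3);

// Inspect the documented properties before running any tensors through it.
Console.WriteLine($"Variant:         {network.Variant}");
Console.WriteLine($"Uses bottleneck: {network.UsesBottleneck}");
Console.WriteLine($"Classes:         {network.NumClasses}");
Console.WriteLine($"Parameters:      {network.GetParameterCount()}");
```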

Forward(Tensor<T>)

Performs a forward pass through the ResNet network with the given input tensor.

public Tensor<T> Forward(Tensor<T> input)

Parameters

input Tensor<T>

The input tensor to process (shape: [channels, height, width] for a single example, or [batch, channels, height, width] for a batch).

Returns

Tensor<T>

The output tensor after processing through all layers.

Remarks

The forward pass sequentially processes the input through each layer of the network: initial conv, residual blocks, global pooling, and classification layer.

For Beginners: This is how the network makes predictions. You give it an image (as a tensor), and it processes it through all the ResNet layers to produce a prediction. The output contains probabilities for each class.

Exceptions

TensorShapeMismatchException

Thrown when the input shape doesn't match expected shape.

GetModelMetadata()

Retrieves metadata about the ResNet network model.

public override ModelMetadata<T> GetModelMetadata()

Returns

ModelMetadata<T>

A ModelMetadata<T> object containing information about the network.

GetParameterCount()

Gets the total number of trainable parameters in the network.

public int GetParameterCount()

Returns

int

The total parameter count.

Remarks

ResNet networks have fewer parameters than VGG despite being deeper (VGG16 has roughly 138 million). Approximate counts for the standard 1000-class configuration:

  • ResNet18: ~11.7 million parameters
  • ResNet34: ~21.8 million parameters
  • ResNet50: ~25.6 million parameters
  • ResNet101: ~44.5 million parameters
  • ResNet152: ~60.2 million parameters
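
The counts listed above correspond to a 1000-class output layer, so GetParameterCount() will return a different value when numClasses differs. The sketch below estimates the size of the classification layer alone; the feature widths (512 for ResNet18/34, 2048 for ResNet50/101/152) follow the original ResNet design and are an assumption about this implementation, and `ClassifierParameters` is a hypothetical helper.

```csharp
using System;

// Fully connected layer: one weight per (feature, class) pair plus one bias per class.
long ClassifierParameters(int featureWidth, int numClasses)
    => (long)featureWidth * numClasses + numClasses;

long basicHead1000 = ClassifierParameters(512, 1000);       // ResNet18/34 head, 1000 classes
long bottleneckHead1000 = ClassifierParameters(2048, 1000); // ResNet50/101/152 head, 1000 classes
long basicHead10 = ClassifierParameters(512, 10);           // ResNet18/34 head, 10 classes

Console.WriteLine($"ResNet18/34 head, 1000 classes:      {basicHead1000}");
Console.WriteLine($"ResNet50/101/152 head, 1000 classes: {bottleneckHead1000}");
Console.WriteLine($"ResNet18/34 head, 10 classes:        {basicHead10}");
```

Under these assumptions, switching a ResNet18 from 1000 to 10 classes removes about half a million parameters, while the convolutional backbone is unchanged.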

InitializeLayers()

Initializes the layers of the ResNet network based on the configuration.

protected override sealed void InitializeLayers()

Remarks

This method either uses custom layers provided in the architecture or creates the standard ResNet layers based on the configuration.

For Beginners: This method builds the ResNet network layer by layer:

  1. Initial 7x7 convolution with stride 2
  2. Max pooling 3x3 with stride 2
  3. Four stages of residual blocks (conv2_x through conv5_x)
  4. Global average pooling
  5. Fully connected classification layer
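
The effect of these steps on the spatial resolution can be traced with the standard convolution output-size formula. The sketch below is illustrative only: the padding values (3 for the 7x7 convolution, 1 for 3x3 operations) and the stride-2 downsampling at the start of conv3_x, conv4_x, and conv5_x follow the original ResNet design and are assumptions about this implementation. `OutputSize` and `TraceSpatialSizes` are hypothetical helpers.

```csharp
using System;
using System.Collections.Generic;

// Output size of a convolution or pooling window along one spatial dimension.
int OutputSize(int inputSize, int kernel, int stride, int padding)
    => (inputSize + 2 * padding - kernel) / stride + 1;

// Follows the five construction steps for a square input image.
List<(string Stage, int Size)> TraceSpatialSizes(int inputSize)
{
    var trace = new List<(string Stage, int Size)>();

    int current = OutputSize(inputSize, kernel: 7, stride: 2, padding: 3);
    trace.Add(("conv1 (7x7, stride 2)", current));

    current = OutputSize(current, kernel: 3, stride: 2, padding: 1);
    trace.Add(("max pool (3x3, stride 2)", current));

    trace.Add(("conv2_x", current)); // first stage keeps the resolution

    foreach (string stage in new[] { "conv3_x", "conv4_x", "conv5_x" })
    {
        current = OutputSize(current, kernel: 3, stride: 2, padding: 1);
        trace.Add((stage, current));
    }

    trace.Add(("global average pooling", 1));
    return trace;
}

var sizes = TraceSpatialSizes(224);
foreach (var (name, pixels) in sizes)
    Console.WriteLine($"{name}: {pixels}x{pixels}");
```

For a 224x224 input this gives 112, 56, 56, 28, 14, and 7 pixels per side before global average pooling collapses each feature map to a single value.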

Predict(Tensor<T>)

Makes a prediction using the ResNet network for the given input.

public override Tensor<T> Predict(Tensor<T> input)

Parameters

input Tensor<T>

The input tensor to make a prediction for.

Returns

Tensor<T>

The predicted output tensor containing class probabilities.
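
Turning the class probabilities into a single label is usually an arg-max over the output. The sketch below works on a plain float array and assumes the values have already been copied out of the returned tensor; how to read values from Tensor<T> is not shown here, and `ArgMax` is a hypothetical helper.

```csharp
using System;

// Returns the index of the largest value, i.e. the most likely class.
int ArgMax(float[] values)
{
    if (values.Length == 0)
        throw new ArgumentException("Need at least one class score.");

    int best = 0;
    for (int i = 1; i < values.Length; i++)
    {
        if (values[i] > values[best])
            best = i;
    }
    return best;
}

// Example probabilities for a 3-class problem (made-up values).
float[] probabilities = { 0.10f, 0.70f, 0.20f };
int predictedClass = ArgMax(probabilities);

Console.WriteLine($"Predicted class index: {predictedClass}"); // prints 1
```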

SerializeNetworkSpecificData(BinaryWriter)

Serializes ResNet network-specific data to a binary writer.

protected override void SerializeNetworkSpecificData(BinaryWriter writer)

Parameters

writer BinaryWriter

Train(Tensor<T>, Tensor<T>)

Trains the ResNet network using the provided input and expected output.

public override void Train(Tensor<T> input, Tensor<T> expectedOutput)

Parameters

input Tensor<T>

The input tensor for training.

expectedOutput Tensor<T>

The expected output tensor (one-hot encoded class labels).
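
A one-hot label is a vector with a single 1 at the index of the correct class and 0 everywhere else. The sketch below builds such a vector as a plain float array; wrapping it in a Tensor<T> for Train is left out because the tensor construction API is not documented on this page, and `OneHot` is a hypothetical helper.

```csharp
using System;

// Encodes a class index as a vector of length numClasses with a single 1.
float[] OneHot(int classIndex, int numClasses)
{
    if (classIndex < 0 || classIndex >= numClasses)
        throw new ArgumentOutOfRangeException(nameof(classIndex));

    var encoded = new float[numClasses]; // all zeros by default
    encoded[classIndex] = 1f;
    return encoded;
}

// Class 3 out of 10 classes.
float[] label = OneHot(classIndex: 3, numClasses: 10);
Console.WriteLine(string.Join(" ", label)); // 0 0 0 1 0 0 0 0 0 0
```

The length of the label vector should match NumClasses so that it lines up with the network output.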

UpdateParameters(Vector<T>)

Updates the parameters of all layers in the network.

public override void UpdateParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

A vector containing all parameters for the network.