Class ResNetNetwork<T>
Namespace: AiDotNet.NeuralNetworks
Assembly: AiDotNet.dll
Represents a ResNet (Residual Network) neural network architecture for image classification.
public class ResNetNetwork<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable
Type Parameters
T: The numeric type used for calculations (typically float or double).
Remarks
ResNet networks are deep convolutional neural networks that introduced skip connections (residual connections) to enable training of very deep networks. They learn residual functions with reference to the layer inputs, rather than learning unreferenced functions directly.
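For Beginners: ResNet networks revolutionized deep learning by making very deep networks trainable. Before ResNet, stacking more layers often made accuracy worse, partly because gradients shrank ("vanished") as they flowed back through many layers. Key benefits include: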
- Can train networks with 100+ layers (compared to ~20 layers for earlier architectures)
- Skip connections allow gradients to flow more easily during training
- Each block learns the "residual" (difference) rather than the complete transformation
- Won the ILSVRC 2015 (ImageNet) classification challenge with a top-5 error of 3.57% (using an ensemble of ResNets)
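The residual idea in the list above fits in a few lines. This is a toy, self-contained illustration rather than AiDotNet code: the residual branch F is a hypothetical elementwise affine map standing in for the convolution and batch-norm layers of a real block, and plain double[] arrays replace Tensor<T>. Because the addition computes F(x) + x, its derivative with respect to x is F'(x) + 1, so the gradient always has a direct path back through the identity term.

```csharp
using System;
using System.Linq;

// Toy residual block: output = ReLU(F(x) + x).
// F is a hypothetical elementwise affine map (weight * x + bias) standing in
// for the conv + batch-norm layers of a real ResNet block.
double[] ResidualBlock(double[] x, double weight, double bias)
{
    // Residual branch F(x): what the block "learns" on top of its input.
    double[] residual = x.Select(v => weight * v + bias).ToArray();

    // Skip connection: add the unchanged input back, then apply ReLU.
    return residual.Zip(x, (f, v) => Math.Max(0.0, f + v)).ToArray();
}

double[] input = { 1.0, 2.0, 3.0 };

// With a zero residual branch the block passes non-negative inputs through
// unchanged, so representing the identity mapping is trivial.
Console.WriteLine(string.Join(", ", ResidualBlock(input, 0.0, 0.0))); // 1, 2, 3

// With weight 1 the branch equals x, so F(x) + x = 2x.
Console.WriteLine(string.Join(", ", ResidualBlock(input, 1.0, 0.0))); // 2, 4, 6
```

If a layer is not needed, the optimizer only has to push F toward zero instead of learning an identity mapping from scratch.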
Architecture Variants:
- ResNet18/34: Use BasicBlock (2 conv layers per block)
- ResNet50/101/152: Use BottleneckBlock (1x1-3x3-1x1 conv pattern) for efficiency
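The variant names follow directly from the block counts. The sketch below uses the blocks-per-stage counts from the original ResNet paper and assumes, without this reference confirming it, that ResNetVariant uses the same counts. Counting the weighted layers in every block, plus the 7x7 stem convolution and the final fully connected layer, reproduces the number in each name.

```csharp
using System;
using System.Linq;

// Weighted-layer depth of a ResNet: each block contributes 2 (basic) or
// 3 (bottleneck) conv layers, plus the 7x7 stem conv and the final FC layer.
int Depth(int[] blocksPerStage, bool bottleneck)
{
    int layersPerBlock = bottleneck ? 3 : 2;
    return layersPerBlock * blocksPerStage.Sum() + 2;
}

// Blocks per stage (conv2_x..conv5_x) as in the original ResNet paper.
var variants = new (string Name, int[] Blocks, bool Bottleneck)[]
{
    ("ResNet18",  new[] { 2, 2, 2, 2 },  false),
    ("ResNet34",  new[] { 3, 4, 6, 3 },  false),
    ("ResNet50",  new[] { 3, 4, 6, 3 },  true),
    ("ResNet101", new[] { 3, 4, 23, 3 }, true),
    ("ResNet152", new[] { 3, 8, 36, 3 }, true),
};

// Prints 18, 34, 50, 101 and 152 weighted layers for the five variants.
foreach (var (name, blocks, isBottleneck) in variants)
{
    Console.WriteLine($"{name}: {Depth(blocks, isBottleneck)} weighted layers");
}
```

Note that ResNet34 and ResNet50 share the same block counts; only the switch from basic to bottleneck blocks adds the extra 16 layers.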
Typical Usage:
```csharp
// Create ResNet50 for 1000-class classification
var config = new ResNetConfiguration(ResNetVariant.ResNet50, numClasses: 1000);
var architecture = new NeuralNetworkArchitecture<float>(
    inputType: InputType.ThreeDimensional,
    inputHeight: 224,
    inputWidth: 224,
    inputDepth: 3,
    outputSize: 1000,
    taskType: NeuralNetworkTaskType.MultiClassClassification);
var network = new ResNetNetwork<float>(architecture, config);
```
Constructors
ResNetNetwork(NeuralNetworkArchitecture<T>, ResNetConfiguration, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, ILossFunction<T>?, double)
Initializes a new instance of the ResNetNetwork class.
public ResNetNetwork(NeuralNetworkArchitecture<T> architecture, ResNetConfiguration configuration, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? optimizer = null, ILossFunction<T>? lossFunction = null, double maxGradNorm = 1)
Parameters
architecture (NeuralNetworkArchitecture<T>): The architecture defining the structure of the neural network.
configuration (ResNetConfiguration): The ResNet-specific configuration.
optimizer (IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>): Optional optimizer for training (default: Adam).
lossFunction (ILossFunction<T>): Optional loss function (default: chosen based on the task type).
maxGradNorm (double): Maximum gradient norm for gradient clipping (default: 1.0).
Remarks
ResNet networks require three-dimensional input data (channels, height, width).
For Beginners: When creating a ResNet network, you need to provide:
- An architecture object that describes the input/output dimensions
- A configuration that specifies which ResNet variant to use
- Optionally, custom optimizer and loss function (good defaults are provided)
Exceptions
- InvalidInputTypeException: Thrown when the input type is not three-dimensional.
- ArgumentNullException: Thrown when configuration is null.
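The maxGradNorm parameter refers to gradient clipping by norm. The sketch below shows the common clip-by-norm rule on a plain double[] gradient: if the L2 norm exceeds the threshold, every component is scaled by the same factor so the norm equals the threshold and the direction is preserved. This reference does not say whether AiDotNet clips the global norm across all parameters or clips per layer, so treat this as an illustration of the technique, not the library's implementation.

```csharp
using System;
using System.Linq;

// Clip-by-norm: rescale the gradient so its L2 norm is at most maxGradNorm.
// The direction is unchanged; only the magnitude shrinks.
double[] ClipByNorm(double[] gradient, double maxGradNorm)
{
    double norm = Math.Sqrt(gradient.Sum(g => g * g));
    if (norm <= maxGradNorm)
    {
        return (double[])gradient.Clone(); // already small enough (covers norm == 0)
    }
    return gradient.Select(g => g * maxGradNorm / norm).ToArray();
}

// The norm of (3, 4) is 5, so with maxGradNorm = 1 (the constructor default)
// each component is scaled by 1/5: clipped is (0.6, 0.8), which has norm 1.
double[] clipped = ClipByNorm(new[] { 3.0, 4.0 }, 1.0);
Console.WriteLine(string.Join(", ", clipped));
```

Clipping caps the size of a single update when a batch produces an unusually large gradient, which helps keep early training of deep networks stable.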
Properties
NumClasses
Gets the number of output classes for classification.
public int NumClasses { get; }
Property Value: int
UsesBottleneck
Gets whether this variant uses bottleneck blocks.
public bool UsesBottleneck { get; }
Property Value: bool
Remarks
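For Beginners: ResNet50 and deeper variants use bottleneck blocks (1x1-3x3-1x1 convolutions), which are more parameter-efficient than the basic blocks used in ResNet18/34: the first 1x1 convolution reduces the channel count, the 3x3 convolution runs on the reduced channels, and the last 1x1 convolution restores the original width.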
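A quick weight count makes the efficiency argument concrete. The sketch counts convolution weights only (no biases or batch-norm parameters) for a block operating on 256 channels, the width of the first bottleneck stage in ResNet50, with the 64-channel inner width used in the original paper. The numbers are illustrative and are not read from AiDotNet.

```csharp
using System;

// Weights in a k x k convolution from inChannels to outChannels (bias ignored).
long ConvWeights(int inChannels, int outChannels, int kernel) =>
    (long)inChannels * outChannels * kernel * kernel;

// Basic block at 256 channels: two 3x3 convolutions, 256 -> 256 -> 256.
long basic = ConvWeights(256, 256, 3) + ConvWeights(256, 256, 3);

// Bottleneck block at 256 channels: 1x1 reduce to 64, 3x3 at 64, 1x1 expand back to 256.
long bottleneck = ConvWeights(256, 64, 1) + ConvWeights(64, 64, 3) + ConvWeights(64, 256, 1);

Console.WriteLine($"basic: {basic}, bottleneck: {bottleneck}"); // basic: 1179648, bottleneck: 69632
```

At this width the bottleneck block needs roughly 17 times fewer weights, which is what lets ResNet50 and deeper variants stack many more blocks at moderate cost.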
Variant
Gets the ResNet variant being used.
public ResNetVariant Variant { get; }
Property Value: ResNetVariant
Remarks
For Beginners: The variant determines how deep the network is (ResNet18, 34, 50, 101, or 152) and which block type is used (BasicBlock for 18/34, BottleneckBlock for 50/101/152).
Methods
Backward(Tensor<T>)
Performs a backward pass through the network to calculate gradients.
public Tensor<T> Backward(Tensor<T> outputGradient)
Parameters
outputGradient (Tensor<T>): The gradient of the loss with respect to the network's output.
Returns
- Tensor<T>
The gradient of the loss with respect to the network's input.
CreateNewInstance()
Creates a new instance of the ResNet network model.
protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
DeserializeNetworkSpecificData(BinaryReader)
Deserializes ResNet network-specific data from a binary reader.
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
Parameters
reader (BinaryReader): The binary reader to read the network-specific data from.
ForTesting(int, int)
Creates a minimal ResNet network optimized for fast test execution.
public static ResNetNetwork<T> ForTesting(int numClasses = 10, int inputChannels = 3)
Parameters
numClasses (int): The number of output classes (default: 10).
inputChannels (int): The number of input channels (default: 3 for RGB).
Returns
- ResNetNetwork<T>
A minimal ResNet network for testing.
Remarks
Uses ResNet18 (the smallest variant) with a 32x32 input resolution, giving significantly fewer layers and much smaller feature maps than the standard configurations. Construction time is typically under 50 ms.
Forward(Tensor<T>)
Performs a forward pass through the ResNet network with the given input tensor.
public Tensor<T> Forward(Tensor<T> input)
Parameters
input (Tensor<T>): The input tensor to process (shape: [channels, height, width] for a single example, or [batch, channels, height, width] for a batch).
Returns
- Tensor<T>
The output tensor after processing through all layers.
Remarks
The forward pass processes the input sequentially through each layer of the network: the initial convolution, the residual blocks, global pooling, and the classification layer.
For Beginners: This is how the network makes predictions. You give it an image (as a tensor), and it processes it through all the ResNet layers to produce a prediction. The output contains probabilities for each class.
Exceptions
- TensorShapeMismatchException
Thrown when the input shape doesn't match expected shape.
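The two accepted input layouts can be handled by promoting a single example to a batch of one. The sketch works on plain int[] shapes rather than Tensor<T> and throws ArgumentException where the library documents TensorShapeMismatchException. The helper name ToBatchShape is hypothetical, and whether Forward performs exactly this normalization internally is an assumption.

```csharp
using System;

// Hypothetical helper: normalize [C, H, W] or [N, C, H, W] to [N, C, H, W]
// and check the channel dimension. ArgumentException stands in for the
// library's TensorShapeMismatchException.
int[] ToBatchShape(int[] shape, int expectedChannels)
{
    int[] batched = shape.Length switch
    {
        3 => new[] { 1, shape[0], shape[1], shape[2] }, // single example -> batch of one
        4 => (int[])shape.Clone(),                       // already batched
        _ => throw new ArgumentException($"Expected rank 3 or 4, got rank {shape.Length}.")
    };

    if (batched[1] != expectedChannels)
    {
        throw new ArgumentException($"Expected {expectedChannels} channels, got {batched[1]}.");
    }
    return batched;
}

Console.WriteLine(string.Join(", ", ToBatchShape(new[] { 3, 224, 224 }, 3)));    // 1, 3, 224, 224
Console.WriteLine(string.Join(", ", ToBatchShape(new[] { 8, 3, 224, 224 }, 3))); // 8, 3, 224, 224
```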
GetModelMetadata()
Retrieves metadata about the ResNet network model.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
A ModelMetadata<T> object containing information about the network.
GetParameterCount()
Gets the total number of trainable parameters in the network.
public int GetParameterCount()
Returns
- int
The total parameter count.
Remarks
ResNet networks have far fewer parameters than VGG despite being much deeper (VGG-16 has about 138 million). Approximate counts for 1,000 output classes:
- ResNet18: ~11.7 million parameters
- ResNet34: ~21.8 million parameters
- ResNet50: ~25.6 million parameters
- ResNet101: ~44.5 million parameters
- ResNet152: ~60.2 million parameters
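The ResNet18 and ResNet34 figures above can be reproduced by counting layer by layer. This sketch assumes the common reference layout: bias-free convolutions, two trainable batch-norm parameters per channel, 1x1 projection shortcuts where the channel count changes, and a fully connected layer with biases on 512 features. AiDotNet's GetParameterCount may differ slightly if its layers are laid out differently.

```csharp
using System;

// Weights in a bias-free k x k convolution.
long Conv(int inC, int outC, int k) => (long)inC * outC * k * k;

// Trainable batch-norm parameters: one scale and one shift per channel.
long BatchNorm(int channels) => 2L * channels;

// Parameter count for the basic-block variants (ResNet18/34).
long BasicResNetParameters(int[] blocksPerStage, int numClasses)
{
    long total = Conv(3, 64, 7) + BatchNorm(64); // 7x7 stem convolution
    int[] widths = { 64, 128, 256, 512 };        // channels in conv2_x..conv5_x
    int inC = 64;

    for (int stage = 0; stage < widths.Length; stage++)
    {
        int outC = widths[stage];
        for (int block = 0; block < blocksPerStage[stage]; block++)
        {
            total += Conv(inC, outC, 3) + BatchNorm(outC);  // first 3x3 conv
            total += Conv(outC, outC, 3) + BatchNorm(outC); // second 3x3 conv
            if (inC != outC)
            {
                total += Conv(inC, outC, 1) + BatchNorm(outC); // 1x1 projection shortcut
            }
            inC = outC;
        }
    }
    return total + 512L * numClasses + numClasses; // FC weights + biases
}

Console.WriteLine(BasicResNetParameters(new[] { 2, 2, 2, 2 }, 1000)); // 11689512
Console.WriteLine(BasicResNetParameters(new[] { 3, 4, 6, 3 }, 1000)); // 21797672
```

Both results round to the ~11.7 million and ~21.8 million listed above. Pooling layers contribute nothing, and most of the weights sit in the last stage, where each 3x3 convolution maps 512 channels to 512.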
InitializeLayers()
Initializes the layers of the ResNet network based on the configuration.
protected override sealed void InitializeLayers()
Remarks
This method either uses custom layers provided in the architecture or creates the standard ResNet layers based on the configuration.
For Beginners: This method builds the ResNet network layer by layer:
- Initial 7x7 convolution with stride 2
- Max pooling 3x3 with stride 2
- Four stages of residual blocks (conv2_x through conv5_x)
- Global average pooling
- Fully connected classification layer
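Each stride-2 step in the list above halves the spatial resolution. The sketch applies the standard output-size formula, floor((input + 2 * padding - kernel) / stride) + 1, assuming the usual paddings (3 for the 7x7 stem convolution, 1 for the 3x3 max pool and the strided 3x3 convolutions). The function names are hypothetical. A 224x224 input reaches 7x7 before global average pooling.

```csharp
using System;

// Output size of a convolution or pooling window along one spatial axis.
int OutputSize(int input, int kernel, int stride, int padding) =>
    (input + 2 * padding - kernel) / stride + 1;

// Spatial size after the stem conv, the max pool (= conv2_x), conv3_x, conv4_x and conv5_x.
int[] FeatureMapSizes(int inputSize)
{
    int stem = OutputSize(inputSize, kernel: 7, stride: 2, padding: 3);
    int conv2 = OutputSize(stem, kernel: 3, stride: 2, padding: 1); // max pool; conv2_x keeps stride 1
    int conv3 = OutputSize(conv2, kernel: 3, stride: 2, padding: 1);
    int conv4 = OutputSize(conv3, kernel: 3, stride: 2, padding: 1);
    int conv5 = OutputSize(conv4, kernel: 3, stride: 2, padding: 1);
    return new[] { stem, conv2, conv3, conv4, conv5 };
}

Console.WriteLine(string.Join(", ", FeatureMapSizes(224))); // 112, 56, 28, 14, 7
Console.WriteLine(string.Join(", ", FeatureMapSizes(32)));  // 16, 8, 4, 2, 1
```

At the 32x32 resolution used by ForTesting, the same layout would shrink the feature maps to 1x1 by conv5_x, assuming ForTesting keeps the standard stem (this reference does not say).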
Predict(Tensor<T>)
Makes a prediction using the ResNet network for the given input.
public override Tensor<T> Predict(Tensor<T> input)
Parameters
input (Tensor<T>): The input tensor to make a prediction for.
Returns
- Tensor<T>
The predicted output tensor containing class probabilities.
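To turn the predicted probabilities into a class label, take the index of the largest value. The sketch also shows the numerically stable softmax that maps raw scores (logits) to probabilities; this reference does not say whether the network applies softmax itself, so that helper is only needed if your output turns out to be logits. Plain double[] arrays stand in for Tensor<T>.

```csharp
using System;
using System.Linq;

// Numerically stable softmax: subtracting the largest logit avoids overflow
// in Math.Exp without changing the result.
double[] Softmax(double[] logits)
{
    double max = logits.Max();
    double[] exps = logits.Select(z => Math.Exp(z - max)).ToArray();
    double sum = exps.Sum();
    return exps.Select(e => e / sum).ToArray();
}

// Index of the largest value = predicted class.
int ArgMax(double[] values)
{
    int best = 0;
    for (int i = 1; i < values.Length; i++)
    {
        if (values[i] > values[best]) best = i;
    }
    return best;
}

double[] probabilities = Softmax(new[] { 1.0, 3.0, 2.0 });
Console.WriteLine($"Predicted class: {ArgMax(probabilities)}"); // Predicted class: 1
```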
SerializeNetworkSpecificData(BinaryWriter)
Serializes ResNet network-specific data to a binary writer.
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
Parameters
writer (BinaryWriter): The binary writer to write the network-specific data to.
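SerializeNetworkSpecificData and DeserializeNetworkSpecificData must read fields back in exactly the order and types they were written. The fields actually serialized are not documented here, so the sketch writes two hypothetical ones (a variant index and the class count) to show the symmetric BinaryWriter/BinaryReader pattern with a MemoryStream round trip.

```csharp
using System;
using System.IO;
using System.Text;

// Hypothetical network-specific fields; the real set is not documented here.
void WriteSpecific(BinaryWriter writer, int variantIndex, int numClasses)
{
    writer.Write(variantIndex);
    writer.Write(numClasses);
}

(int VariantIndex, int NumClasses) ReadSpecific(BinaryReader reader)
{
    // Read in the same order and with the same types as WriteSpecific.
    int variantIndex = reader.ReadInt32();
    int numClasses = reader.ReadInt32();
    return (variantIndex, numClasses);
}

using var stream = new MemoryStream();
using (var binaryWriter = new BinaryWriter(stream, Encoding.UTF8, leaveOpen: true))
{
    WriteSpecific(binaryWriter, variantIndex: 2, numClasses: 1000);
}

stream.Position = 0; // rewind before reading
using var binaryReader = new BinaryReader(stream);
var restored = ReadSpecific(binaryReader);
Console.WriteLine($"{restored.VariantIndex}, {restored.NumClasses}"); // 2, 1000
```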
Train(Tensor<T>, Tensor<T>)
Trains the ResNet network using the provided input and expected output.
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)
Parameters
input (Tensor<T>): The input tensor for training.
expectedOutput (Tensor<T>): The expected output tensor (one-hot encoded class labels).
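Train expects one-hot encoded labels. The sketch converts integer class labels into a flattened [batch, numClasses] array in row-major order; the helper name OneHot and the flat layout are assumptions, and building the actual Tensor<T> from the array is left to the library.

```csharp
using System;

// Hypothetical helper: one-hot encode integer labels into a row-major
// [labels.Length, numClasses] array (1 at the label's index, 0 elsewhere).
double[] OneHot(int[] labels, int numClasses)
{
    var encoded = new double[labels.Length * numClasses];
    for (int row = 0; row < labels.Length; row++)
    {
        if (labels[row] < 0 || labels[row] >= numClasses)
        {
            throw new ArgumentOutOfRangeException(nameof(labels), $"Label {labels[row]} is outside [0, {numClasses}).");
        }
        encoded[row * numClasses + labels[row]] = 1.0;
    }
    return encoded;
}

// Two examples with classes 2 and 0 out of 3 classes.
Console.WriteLine(string.Join(", ", OneHot(new[] { 2, 0 }, 3))); // 0, 0, 1, 1, 0, 0
```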
UpdateParameters(Vector<T>)
Updates the parameters of all layers in the network.
public override void UpdateParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): A vector containing all parameters for the network.
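UpdateParameters takes one flat vector for the whole network, so each layer must know which slice belongs to it. A common convention, assumed here rather than confirmed by this reference, is that layers consume consecutive slices in layer order. The sketch computes those slices from hypothetical per-layer parameter counts and rejects a vector of the wrong length.

```csharp
using System;

// Hypothetical helper: given each layer's parameter count, return the
// (start, length) slice of the flat vector that the layer owns.
(int Start, int Length)[] LayerSlices(int[] parameterCounts, int vectorLength)
{
    var slices = new (int Start, int Length)[parameterCounts.Length];
    int offset = 0;
    for (int i = 0; i < parameterCounts.Length; i++)
    {
        slices[i] = (offset, parameterCounts[i]);
        offset += parameterCounts[i];
    }
    if (offset != vectorLength)
    {
        throw new ArgumentException($"Expected {offset} parameters, got {vectorLength}.");
    }
    return slices;
}

// Example: three layers with 9408, 128 and 36864 parameters in one vector of 46400.
foreach (var (start, length) in LayerSlices(new[] { 9408, 128, 36864 }, 46400))
{
    Console.WriteLine($"start {start}, length {length}");
}
```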