Class VisionTransformer<T>

Namespace
AiDotNet.NeuralNetworks
Assembly
AiDotNet.dll

Implements the Vision Transformer (ViT) architecture for image classification tasks.

public class VisionTransformer<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable

Type Parameters

T

The numeric type used for computations (typically float or double).

Inheritance
NeuralNetworkBase<T>
VisionTransformer<T>
Implements
IFullModel<T, Tensor<T>, Tensor<T>>
IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>
IParameterizable<T, Tensor<T>, Tensor<T>>
ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>
IGradientComputable<T, Tensor<T>, Tensor<T>>

Remarks

The Vision Transformer applies transformer architecture, originally designed for natural language processing, to computer vision tasks. It divides images into fixed-size patches, linearly embeds them, adds positional embeddings, and processes the sequence through transformer encoder layers.

For Beginners: The Vision Transformer (ViT) is a modern approach to understanding images using transformers.

Unlike traditional neural networks that process images pixel by pixel or with sliding windows (convolutions), ViT treats an image like a sentence of words:

  • First, it cuts the image into small square patches (like breaking a sentence into words)
  • Each patch gets converted to a numerical representation (like word embeddings)
  • Position information is added so the model knows where each patch came from
  • A special classification token is added to gather information about the whole image
  • Transformer layers process all patches together, learning relationships between them
  • Finally, the classification token's output is used to predict the image class

This approach has been very successful and often outperforms traditional convolutional neural networks, especially when trained on large datasets.
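
To make the patch-and-sequence idea concrete, the arithmetic below shows how many patches and sequence positions a typical input produces. This is an illustrative sketch only; the variable names are not part of the API.

// Illustrative values for a typical 224x224 RGB image with 16x16 patches.
int imageHeight = 224, imageWidth = 224, patchSize = 16;

// The image is split into non-overlapping patchSize x patchSize squares.
int patchesPerRow = imageWidth / patchSize;          // 14
int patchesPerColumn = imageHeight / patchSize;      // 14
int numPatches = patchesPerRow * patchesPerColumn;   // 196

// One extra position is reserved for the classification (CLS) token.
int sequenceLength = numPatches + 1;                 // 197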

Constructors

VisionTransformer(NeuralNetworkArchitecture<T>, int, int, int, int, int, int, int, int, int, ILossFunction<T>?)

Creates a new Vision Transformer with the specified configuration.

public VisionTransformer(NeuralNetworkArchitecture<T> architecture, int imageHeight, int imageWidth, int channels, int patchSize, int numClasses, int hiddenDim = 768, int numLayers = 12, int numHeads = 12, int mlpDim = 3072, ILossFunction<T>? lossFunction = null)

Parameters

architecture NeuralNetworkArchitecture<T>

The architecture defining the network structure.

imageHeight int

The height of input images.

imageWidth int

The width of input images.

channels int

The number of color channels (e.g., 3 for RGB).

patchSize int

The size of each square patch.

numClasses int

The number of output classes.

hiddenDim int

The dimension of embeddings (default: 768).

numLayers int

The number of transformer encoder layers (default: 12).

numHeads int

The number of attention heads (default: 12).

mlpDim int

The dimension of the feed-forward network (default: 3072).

lossFunction ILossFunction<T>

The loss function to use (defaults to categorical cross-entropy if null).

Remarks

For Beginners: This constructor creates a Vision Transformer with your specifications.

The parameters let you customize:

  • Image dimensions: Size and channels of input images
  • Patch size: How to divide images (common values: 16 or 32)
  • Architecture depth: More layers = more capacity, but training is slower and needs more data
  • Hidden dimensions: Larger = more expressive but more parameters
  • Attention heads: More heads = more diverse attention patterns

Common configurations:

  • ViT-Base: hiddenDim=768, numLayers=12, numHeads=12, mlpDim=3072
  • ViT-Large: hiddenDim=1024, numLayers=24, numHeads=16, mlpDim=4096
  • ViT-Huge: hiddenDim=1280, numLayers=32, numHeads=16, mlpDim=5120
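
As a usage sketch, the call below builds a ViT-Base model for 224x224 RGB images with 16x16 patches and 1000 classes. How the NeuralNetworkArchitecture<T> instance is created is not covered here, so it appears only as a placeholder.

// Placeholder: build or obtain an architecture describing the network structure.
NeuralNetworkArchitecture<float> architecture = CreateArchitecture();   // hypothetical helper

// ViT-Base configuration: hiddenDim=768, numLayers=12, numHeads=12, mlpDim=3072.
var vit = new VisionTransformer<float>(
    architecture,
    imageHeight: 224,
    imageWidth: 224,
    channels: 3,
    patchSize: 16,
    numClasses: 1000,
    hiddenDim: 768,
    numLayers: 12,
    numHeads: 12,
    mlpDim: 3072);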

Properties

ParameterCount

Gets the total number of parameters in the model.

public override int ParameterCount { get; }

Property Value

int

SupportsTraining

Indicates whether this network supports training.

public override bool SupportsTraining { get; }

Property Value

bool

Methods

CreateNewInstance()

Creates a new instance of the Vision Transformer.

protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()

Returns

IFullModel<T, Tensor<T>, Tensor<T>>

A new Vision Transformer instance with the same configuration.

DeserializeNetworkSpecificData(BinaryReader)

Deserializes Vision Transformer-specific data.

protected override void DeserializeNetworkSpecificData(BinaryReader reader)

Parameters

reader BinaryReader

The binary reader to read data from.

GetModelMetadata()

Gets the model metadata.

public override ModelMetadata<T> GetModelMetadata()

Returns

ModelMetadata<T>

Metadata describing the Vision Transformer model.

Remarks

For Beginners: This method returns information about the model like its type, parameter count, and configuration. This is useful for documentation and debugging.
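
A minimal call is shown below; the specific members exposed by ModelMetadata<T> are not listed in this section.

// Retrieve descriptive information (model type, parameter count, configuration) for logging or debugging.
ModelMetadata<float> metadata = vit.GetModelMetadata();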

GetParameters()

Gets all model parameters in a single vector.

public override Vector<T> GetParameters()

Returns

Vector<T>

A vector containing CLS token, positional embeddings, and all layer parameters in sequence.

Remarks

This method returns parameters in the exact order expected by UpdateParameters:

  1. CLS token vector
  2. Positional embeddings (flattened row-major)
  3. Parameters from each transformer layer in sequence
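
A round-trip sketch with UpdateParameters, assuming a constructed instance named vit:

// Read every parameter (CLS token, positional embeddings, then layer parameters) as one vector.
Vector<float> parameters = vit.GetParameters();

// An optimizer could adjust the values here; the ordering must be preserved.

// Apply the same ordering back to the network.
vit.UpdateParameters(parameters);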

InitializeLayers()

Initializes the layers of the Vision Transformer.

protected override void InitializeLayers()

Predict(Tensor<T>)

Makes a prediction using the Vision Transformer.

public override Tensor<T> Predict(Tensor<T> input)

Parameters

input Tensor<T>

The input image tensor with shape [batch, channels, height, width].

Returns

Tensor<T>

The predicted class probabilities with shape [batch, num_classes].

Remarks

For Beginners: This method processes an image to predict its class.

The prediction process:

  1. Convert the image into patches and embed them
  2. Add a classification token to the beginning of the sequence
  3. Add positional embeddings so the model knows where each patch came from
  4. Process through transformer encoder layers
  5. Extract the classification token's representation
  6. Pass through the classification head to get class probabilities

The output is a probability distribution over classes (values sum to 1), where higher values indicate more confidence in that class.
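
A usage sketch is shown below. How the input Tensor<T> is created is not documented in this section, so a hypothetical helper stands in for it.

// Hypothetical helper returning a Tensor<float> with shape [1, 3, 224, 224].
Tensor<float> image = LoadImageAsTensor("cat.png");

// probabilities has shape [1, numClasses]; each row sums to 1.
Tensor<float> probabilities = vit.Predict(image);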

SerializeNetworkSpecificData(BinaryWriter)

Serializes Vision Transformer-specific data.

protected override void SerializeNetworkSpecificData(BinaryWriter writer)

Parameters

writer BinaryWriter

The binary writer to write data to.

Train(Tensor<T>, Tensor<T>)

Trains the Vision Transformer on a single input-output pair.

public override void Train(Tensor<T> input, Tensor<T> expectedOutput)

Parameters

input Tensor<T>

The input image tensor.

expectedOutput Tensor<T>

The expected output (class labels or probabilities).

Remarks

For Beginners: This method trains the Vision Transformer on one example.

During training:

  1. Forward pass: Make a prediction using the current parameters
  2. Calculate loss: Measure how wrong the prediction is
  3. Backward pass: Calculate gradients for all parameters
  4. Update parameters: Adjust weights and biases to improve performance

This is typically called many times with different images to train the model.
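
A sketch of a minimal training loop is shown below; the data-loading helpers are hypothetical and only illustrate repeated calls to Train.

// Hypothetical helpers returning matching arrays of input images and expected outputs.
Tensor<float>[] images = LoadTrainingImages();
Tensor<float>[] labels = LoadTrainingLabels();

for (int i = 0; i < images.Length; i++)
{
    // Each call runs a forward pass, computes the loss, backpropagates, and updates parameters.
    vit.Train(images[i], labels[i]);
}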

UpdateParameters(Vector<T>)

Updates the network's parameters with new values.

public override void UpdateParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

The new parameter values.

Remarks

For Beginners: This method sets all the network's internal values at once. This is typically used when loading a saved model or when an optimizer computes improved parameter values.
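
For example, a vector saved from GetParameters on a model with the same configuration can be applied as shown below; LoadSavedParameters is a hypothetical helper.

// Hypothetical helper returning a vector previously produced by GetParameters.
Vector<float> saved = LoadSavedParameters();

// The vector must contain ParameterCount values in the GetParameters ordering
// (CLS token, positional embeddings, layer parameters).
vit.UpdateParameters(saved);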