Class VisionTransformer<T>
Namespace: AiDotNet.NeuralNetworks
Assembly: AiDotNet.dll
Implements the Vision Transformer (ViT) architecture for image classification tasks.
public class VisionTransformer<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable
Type Parameters
T: The numeric type used for computations (typically float or double).
- Inheritance
- NeuralNetworkBase<T>
- VisionTransformer<T>
Remarks
The Vision Transformer applies transformer architecture, originally designed for natural language processing, to computer vision tasks. It divides images into fixed-size patches, linearly embeds them, adds positional embeddings, and processes the sequence through transformer encoder layers.
For Beginners: The Vision Transformer (ViT) is a modern approach to understanding images using transformers.
Unlike traditional neural networks that process images pixel by pixel or with sliding windows (convolutions), ViT treats an image like a sentence of words:
- First, it cuts the image into small square patches (like breaking a sentence into words)
- Each patch gets converted to a numerical representation (like word embeddings)
- Position information is added so the model knows where each patch came from
- A special classification token is added to gather information about the whole image
- Transformer layers process all patches together, learning relationships between them
- Finally, the classification token's output is used to predict the image class
This approach has been very successful and often outperforms traditional convolutional neural networks, especially when trained on large datasets.
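For example, a 224×224 RGB image split into 16×16 patches yields 196 patches, and adding the classification token gives a sequence length of 197. The short sketch below only illustrates that arithmetic; the local variable names are illustrative and not part of this API.
// Illustrative arithmetic only; these locals are not part of the AiDotNet API.
int imageHeight = 224, imageWidth = 224, patchSize = 16;
int patchesPerRow = imageWidth / patchSize;         // 14
int patchesPerColumn = imageHeight / patchSize;     // 14
int numPatches = patchesPerRow * patchesPerColumn;  // 196 patches, like "words" in a sentence
int sequenceLength = numPatches + 1;                // +1 for the classification (CLS) token => 197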
Constructors
VisionTransformer(NeuralNetworkArchitecture<T>, int, int, int, int, int, int, int, int, int, ILossFunction<T>?)
Creates a new Vision Transformer with the specified configuration.
public VisionTransformer(NeuralNetworkArchitecture<T> architecture, int imageHeight, int imageWidth, int channels, int patchSize, int numClasses, int hiddenDim = 768, int numLayers = 12, int numHeads = 12, int mlpDim = 3072, ILossFunction<T>? lossFunction = null)
Parameters
architecture (NeuralNetworkArchitecture<T>): The architecture defining the network structure.
imageHeight (int): The height of input images.
imageWidth (int): The width of input images.
channels (int): The number of color channels (e.g., 3 for RGB).
patchSize (int): The size of each square patch.
numClasses (int): The number of output classes.
hiddenDim (int): The dimension of embeddings (default: 768).
numLayers (int): The number of transformer encoder layers (default: 12).
numHeads (int): The number of attention heads (default: 12).
mlpDim (int): The dimension of the feed-forward network (default: 3072).
lossFunction (ILossFunction<T>?): The loss function to use (defaults to categorical cross-entropy if null).
Remarks
For Beginners: This constructor creates a Vision Transformer with your specifications.
The parameters let you customize:
- Image dimensions: Size and channels of input images
- Patch size: How to divide images (common values: 16 or 32)
- Architecture depth: More layers = more capacity but slower and needs more data
- Hidden dimensions: Larger = more expressive but more parameters
- Attention heads: More heads = more diverse attention patterns
Common configurations:
- ViT-Base: hiddenDim=768, numLayers=12, numHeads=12, mlpDim=3072
- ViT-Large: hiddenDim=1024, numLayers=24, numHeads=16, mlpDim=4096
- ViT-Huge: hiddenDim=1280, numLayers=32, numHeads=16, mlpDim=5120
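As a sketch of how the ViT-Base configuration maps onto this constructor (it assumes you already have a NeuralNetworkArchitecture<float> instance named architecture, whose construction is not covered on this page):
// Sketch only: 'architecture' is assumed to be a pre-built NeuralNetworkArchitecture<float>.
var vit = new VisionTransformer<float>(
    architecture,
    imageHeight: 224,
    imageWidth: 224,
    channels: 3,          // RGB
    patchSize: 16,
    numClasses: 1000,
    hiddenDim: 768,       // ViT-Base defaults
    numLayers: 12,
    numHeads: 12,
    mlpDim: 3072);        // lossFunction omitted => categorical cross-entropy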
Properties
ParameterCount
Gets the total number of parameters in the model.
public override int ParameterCount { get; }
Property Value
- int
SupportsTraining
Indicates whether this network supports training.
public override bool SupportsTraining { get; }
Property Value
- bool
Methods
CreateNewInstance()
Creates a new instance of the Vision Transformer.
protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
A new Vision Transformer instance with the same configuration.
DeserializeNetworkSpecificData(BinaryReader)
Deserializes Vision Transformer-specific data.
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
Parameters
reader (BinaryReader): The binary reader to read data from.
GetModelMetadata()
Gets the model metadata.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
Metadata describing the Vision Transformer model.
Remarks
For Beginners: This method returns information about the model like its type, parameter count, and configuration. This is useful for documentation and debugging.
GetParameters()
Gets all model parameters in a single vector.
public override Vector<T> GetParameters()
Returns
- Vector<T>
A vector containing CLS token, positional embeddings, and all layer parameters in sequence.
Remarks
This method returns parameters in the exact order expected by UpdateParameters:
- CLS token vector
- Positional embeddings (flattened row-major)
- Parameters from each transformer layer in sequence
InitializeLayers()
Initializes the layers of the Vision Transformer.
protected override void InitializeLayers()
Predict(Tensor<T>)
Makes a prediction using the Vision Transformer.
public override Tensor<T> Predict(Tensor<T> input)
Parameters
input (Tensor<T>): The input image tensor with shape [batch, channels, height, width].
Returns
- Tensor<T>
The predicted class probabilities with shape [batch, num_classes].
Remarks
For Beginners: This method processes an image to predict its class.
The prediction process:
- Convert the image into patches and embed them
- Add a classification token to the beginning of the sequence
- Add positional embeddings so the model knows where each patch came from
- Process through transformer encoder layers
- Extract the classification token's representation
- Pass through the classification head to get class probabilities
The output is a probability distribution over classes (values sum to 1), where higher values indicate more confidence in that class.
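A minimal usage sketch; the Tensor<T> construction below is an assumption and should be adapted to the actual Tensor<T> API in AiDotNet:
// Sketch only: tensor creation and element access are assumed, not taken from this page.
var image = new Tensor<float>(new[] { 1, 3, 224, 224 });   // [batch, channels, height, width]
// ... fill 'image' with normalized pixel values ...
Tensor<float> probabilities = vit.Predict(image);           // shape [1, numClasses], rows sum to 1
// Pick the class with the highest probability, e.g. by scanning the output values.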
SerializeNetworkSpecificData(BinaryWriter)
Serializes Vision Transformer-specific data.
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
Parameters
writer (BinaryWriter): The binary writer to write data to.
Train(Tensor<T>, Tensor<T>)
Trains the Vision Transformer on a single input-output pair.
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)
Parameters
input (Tensor<T>): The input image tensor.
expectedOutput (Tensor<T>): The expected output (class labels or probabilities).
Remarks
For Beginners: This method trains the Vision Transformer on one example.
During training:
- Forward pass: Make a prediction using the current parameters
- Calculate loss: Measure how wrong the prediction is
- Backward pass: Calculate gradients for all parameters
- Update parameters: Adjust weights and biases to improve performance
This is typically called many times with different images to train the model.
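A training-loop sketch under the same assumptions about Tensor<T> as in the Predict example; the trainingSet variable is hypothetical:
// Sketch only: 'trainingSet' is a hypothetical collection of (image, label) tensor pairs.
for (int epoch = 0; epoch < 10; epoch++)
{
    foreach (var (image, label) in trainingSet)
    {
        // One forward pass, loss computation, backward pass, and parameter update.
        vit.Train(image, label);
    }
}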
UpdateParameters(Vector<T>)
Updates the network's parameters with new values.
public override void UpdateParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): The new parameter values.
Remarks
For Beginners: This method sets all the network's internal values at once. This is typically used when loading a saved model or when an optimizer computes improved parameter values.
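For instance, parameters can be read out with GetParameters, adjusted externally, and written back; the vector always follows the ordering described under GetParameters and its length matches ParameterCount. The element-level manipulation hinted at below is an assumption about the Vector<T> API:
// Sketch only: assumes Vector<float> values can be inspected or modified in place.
Vector<float> parameters = vit.GetParameters();
// The vector holds the CLS token, positional embeddings, then each layer's parameters,
// and its length equals vit.ParameterCount.
// ... let an optimizer or a model-loading routine produce new values ...
vit.UpdateParameters(parameters);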