Class DocumentNeuralNetworkBase<T>
Base class for document-focused neural networks that can operate in both ONNX inference and native training modes.
public abstract class DocumentNeuralNetworkBase<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable
Type Parameters
- T: The numeric type used for calculations.
- Inheritance: NeuralNetworkBase<T> → DocumentNeuralNetworkBase<T>
Remarks
This class extends NeuralNetworkBase<T> to provide document-specific functionality while maintaining full integration with the AiDotNet neural network infrastructure.
For Beginners: Document neural networks process images of documents (scanned pages, PDFs, photos). This base class provides:
- Support for pre-trained ONNX models (fast inference with existing models)
- Full training capability from scratch (like other neural networks)
- Document preprocessing utilities (normalization, resizing, etc.)
- Layout-aware feature extraction
- Integration with text encoding for layout-aware models
You can use this class in two ways:
- Load a pre-trained ONNX model for quick inference
- Build and train a new model from scratch
Constructors
DocumentNeuralNetworkBase(NeuralNetworkArchitecture<T>, ILossFunction<T>?, double)
Initializes a new instance of the DocumentNeuralNetworkBase class with the specified architecture.
protected DocumentNeuralNetworkBase(NeuralNetworkArchitecture<T> architecture, ILossFunction<T>? lossFunction = null, double maxGradNorm = 1)
Parameters
- architecture (NeuralNetworkArchitecture<T>): The neural network architecture.
- lossFunction (ILossFunction<T>): The loss function to use. If null (the default), CrossEntropyLoss is used.
- maxGradNorm (double): Maximum gradient norm for gradient clipping. Defaults to 1.
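This page does not say how maxGradNorm is applied. A common convention is clipping by the global L2 norm: when the norm of all gradients exceeds the limit, every gradient is multiplied by maxGradNorm / norm, so the update keeps its direction. The sketch below assumes that convention. It uses a flat double[] instead of the library's tensors, and ClipByGlobalNorm is a hypothetical helper, not an AiDotNet API.

```csharp
using System;

// Hypothetical helper (not an AiDotNet API). Assumes global-norm clipping, a common
// convention; the library's actual clipping strategy is not documented on this page.
static double[] ClipByGlobalNorm(double[] gradients, double maxNorm)
{
    if (maxNorm <= 0)
        throw new ArgumentOutOfRangeException(nameof(maxNorm), "maxNorm must be positive.");

    // L2 norm over every gradient entry.
    double sumSquares = 0;
    foreach (double g in gradients) sumSquares += g * g;
    double norm = Math.Sqrt(sumSquares);

    // Already within the limit: return an unchanged copy.
    if (norm <= maxNorm) return (double[])gradients.Clone();

    // Scale all entries by the same factor so the gradient direction is preserved.
    double scale = maxNorm / norm;
    var clipped = new double[gradients.Length];
    for (int i = 0; i < gradients.Length; i++) clipped[i] = gradients[i] * scale;
    return clipped;
}

// The gradient [3, 4] has norm 5; with maxNorm = 1 (the constructor default) it is
// rescaled to approximately [0.6, 0.8].
double[] result = ClipByGlobalNorm(new[] { 3.0, 4.0 }, 1.0);
Console.WriteLine(string.Join(", ", result));
```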
Properties
DefaultLossFunction
Gets the default loss function for this model.
public override ILossFunction<T> DefaultLossFunction { get; }
Property Value
- ILossFunction<T>
ImageSize
Gets the expected input image size for this model.
public int ImageSize { get; protected set; }
Property Value
- int
Remarks
Common values: 224 (ViT base), 384, 448, 512, 768, 1024. Document images should be resized to match this size.
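Resizing to ImageSize can be illustrated with the simplest interpolation scheme, nearest-neighbor sampling, where source index = floor(target index * source size / target size). The sketch assumes a flat, row-major [channels, height, width] array. ResizeNearest is a hypothetical helper, and the model-specific defaults may well use bilinear or bicubic interpolation instead.

```csharp
using System;

// Hypothetical helper (not an AiDotNet API): nearest-neighbor resize of a
// [channels, height, width] image stored row-major in a flat array.
static double[] ResizeNearest(double[] image, int channels, int height, int width,
                              int newHeight, int newWidth)
{
    if (image.Length != channels * height * width)
        throw new ArgumentException("Image length does not match the given shape.");

    var output = new double[channels * newHeight * newWidth];
    for (int c = 0; c < channels; c++)
    {
        for (int y = 0; y < newHeight; y++)
        {
            // Map each output row back to a source row: floor(y * height / newHeight).
            int srcY = y * height / newHeight;
            for (int x = 0; x < newWidth; x++)
            {
                int srcX = x * width / newWidth;
                output[(c * newHeight + y) * newWidth + x] = image[(c * height + srcY) * width + srcX];
            }
        }
    }
    return output;
}

// Upscale a 1-channel 2x2 image to 4x4, standing in for ImageSize = 4.
double[] resized = ResizeNearest(new double[] { 1, 2, 3, 4 }, 1, 2, 2, 4, 4);
Console.WriteLine(string.Join(" ", resized));
```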
IsOnnxMode
Gets whether this model is running in ONNX inference mode.
public bool IsOnnxMode { get; }
Property Value
- bool
Remarks
When true, the model uses pre-trained ONNX weights for inference. When false, the model uses native layers and can be trained.
MaxSequenceLength
Gets the maximum text sequence length for layout-aware models.
public int MaxSequenceLength { get; protected set; }
Property Value
- int
Remarks
For models that process text tokens (like LayoutLM), this is the maximum number of tokens that can be processed. Typical values: 512, 1024, 2048.
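Token sequences longer than MaxSequenceLength must be cut down, and shorter ones are usually padded to a common length. The following sketch assumes three conventions: right-padding, a pad ID of 0, and a 1/0 attention mask. FitToLength is a hypothetical helper. Real tokenizers, and the models built on this class, may use different conventions.

```csharp
using System;

// Hypothetical helper (not an AiDotNet API): fits token IDs to maxSequenceLength by
// truncating long inputs and right-padding short ones, and builds a matching
// attention mask. Pad ID 0 and right-padding are assumptions; tokenizers differ.
static (int[] Ids, int[] Mask) FitToLength(int[] tokenIds, int maxSequenceLength, int padId = 0)
{
    var ids = new int[maxSequenceLength];
    var mask = new int[maxSequenceLength];
    int kept = Math.Min(tokenIds.Length, maxSequenceLength);
    for (int i = 0; i < maxSequenceLength; i++)
    {
        bool isReal = i < kept;
        ids[i] = isReal ? tokenIds[i] : padId;
        mask[i] = isReal ? 1 : 0;   // 1 = real token, 0 = padding
    }
    return (ids, mask);
}

// Arbitrary example IDs, fitted to a toy maximum length of 5.
var (paddedIds, attentionMask) = FitToLength(new[] { 101, 2023, 102 }, 5);
Console.WriteLine($"{string.Join(",", paddedIds)} | {string.Join(",", attentionMask)}");
```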
OnnxDecoder
Gets or sets the ONNX decoder model (for encoder-decoder architectures).
protected OnnxModel<T>? OnnxDecoder { get; set; }
Property Value
- OnnxModel<T>
OnnxEncoder
Gets or sets the ONNX encoder model (for encoder-decoder architectures).
protected OnnxModel<T>? OnnxEncoder { get; set; }
Property Value
- OnnxModel<T>
OnnxModel
Gets or sets the ONNX model (for single-model architectures).
protected OnnxModel<T>? OnnxModel { get; set; }
Property Value
- OnnxModel<T>
RequiresOCR
Gets whether this model requires OCR preprocessing.
public abstract bool RequiresOCR { get; }
Property Value
- bool
Remarks
Layout-aware models (LayoutLM, etc.) require OCR to provide text and bounding boxes. OCR-free models (Donut, Pix2Struct) process raw pixels directly.
SupportedDocumentTypes
Gets the supported document types for this model.
public abstract DocumentType SupportedDocumentTypes { get; }
Property Value
- DocumentType
SupportsTraining
Gets whether this network supports training.
public override bool SupportsTraining { get; }
Property Value
- bool
Remarks
In ONNX mode, training is not supported; the model is inference-only. In native mode, training is fully supported.
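The remarks tie SupportsTraining to the mode: ONNX mode is inference-only, and native mode is trainable. The sketch below encodes that relationship. It also makes two assumptions this page does not state: that ONNX mode means an ONNX model has been loaded, and that prediction dispatches to RunOnnxInference or Forward accordingly. Delegates stand in for both paths, and all names are illustrative.

```csharp
#nullable enable
using System;

// Illustrative stand-ins (not AiDotNet APIs): an optional ONNX session and native layers.
Func<double[], double[]>? onnxModel = null;
Func<double[], double[]> nativeForward = x => Array.ConvertAll(x, v => v * 2);

// Assumption: ONNX mode means an ONNX model has been loaded.
bool IsOnnxMode() => onnxModel is not null;

// Documented: ONNX mode is inference-only, native mode supports training.
bool SupportsTraining() => !IsOnnxMode();

// Assumed dispatch: ONNX weights when loaded, otherwise the native layers.
double[] Predict(double[] input) => IsOnnxMode() ? onnxModel!(input) : nativeForward(input);

Console.WriteLine($"native: trainable={SupportsTraining()}, output={string.Join(",", Predict(new[] { 1.0, 2.0 }))}");

onnxModel = x => x;  // pretend a pre-trained ONNX model was loaded (identity, for illustration)
Console.WriteLine($"onnx:   trainable={SupportsTraining()}, output={string.Join(",", Predict(new[] { 1.0, 2.0 }))}");
```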
Methods
ApplyDefaultPostprocessing(Tensor<T>)
Applies industry-standard postprocessing defaults for this specific model type.
protected abstract Tensor<T> ApplyDefaultPostprocessing(Tensor<T> modelOutput)
Parameters
- modelOutput (Tensor<T>): Raw model output tensor.
Returns
- Tensor<T>
Postprocessed output using model-specific defaults.
Remarks
Each model should implement this with its paper-recommended postprocessing. For example:
- Classification models: Softmax + argmax
- Detection models: NMS + confidence thresholding
- OCR models: CTC decoding or attention decoding
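For the classification case, "Softmax + argmax" first turns raw logits into a probability distribution and then picks the most likely class. The minimal sketch below works on a plain double[] rather than Tensor<T>. Softmax and ArgMax are illustrative helpers, not AiDotNet APIs.

```csharp
using System;
using System.Linq;

// Illustrative helpers (not AiDotNet APIs) operating on a plain double[] of logits.
static double[] Softmax(double[] logits)
{
    // Subtracting the largest logit avoids overflow in Math.Exp without changing the result.
    double max = logits.Max();
    double[] exps = logits.Select(z => Math.Exp(z - max)).ToArray();
    double sum = exps.Sum();
    return exps.Select(e => e / sum).ToArray();
}

static int ArgMax(double[] values)
{
    int best = 0;
    for (int i = 1; i < values.Length; i++)
    {
        if (values[i] > values[best]) best = i;
    }
    return best;
}

double[] probs = Softmax(new[] { 1.0, 3.0, 0.5 });
int predictedClass = ArgMax(probs);
Console.WriteLine($"class {predictedClass}, p = {probs[predictedClass]:F3}");
```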
ApplyDefaultPreprocessing(Tensor<T>)
Applies industry-standard preprocessing defaults for this specific model type.
protected abstract Tensor<T> ApplyDefaultPreprocessing(Tensor<T> rawImage)
Parameters
- rawImage (Tensor<T>): Raw document image tensor.
Returns
- Tensor<T>
Preprocessed image using model-specific defaults.
Remarks
Each model should implement this with its paper-recommended preprocessing. For example:
- TrOCR: Resize to 384x384, normalize with mean=0.5, std=0.5
- LayoutLMv3: Resize to 224x224, ImageNet normalization
- Donut: Resize to 2560x1920, normalize to [-1,1]
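Mean/std normalization is a per-channel affine map, x' = (x - mean) / std. If pixels are already scaled to [0, 1], mean = std = 0.5 (the TrOCR entry) maps them onto [-1, 1]. The sketch below assumes that scaling and a flat, row-major [channels, height, width] layout. For the ImageNet case it uses the standard RGB statistics: mean 0.485/0.456/0.406 and std 0.229/0.224/0.225. Normalize is a hypothetical helper.

```csharp
using System;

// Hypothetical helper (not an AiDotNet API): per-channel normalization
// x' = (x - mean[c]) / std[c] on a flat, row-major [channels, height, width] image
// whose pixel values are assumed to be scaled to [0, 1] already.
static double[] Normalize(double[] image, int channels, double[] mean, double[] std)
{
    if (mean.Length != channels || std.Length != channels)
        throw new ArgumentException("mean and std need one entry per channel.");
    if (image.Length % channels != 0)
        throw new ArgumentException("Image length must be a multiple of the channel count.");

    int plane = image.Length / channels;  // height * width
    var output = new double[image.Length];
    for (int c = 0; c < channels; c++)
    {
        for (int i = 0; i < plane; i++)
            output[c * plane + i] = (image[c * plane + i] - mean[c]) / std[c];
    }
    return output;
}

// TrOCR-style: mean = std = 0.5 maps the pixel values 0, 0.5, 1 to -1, 0, 1.
double[] trocr = Normalize(new[] { 0.0, 0.5, 1.0 }, 1, new[] { 0.5 }, new[] { 0.5 });

// ImageNet RGB statistics; a pixel equal to the mean normalizes to 0 in every channel.
double[] imagenetMean = { 0.485, 0.456, 0.406 };
double[] imagenetStd = { 0.229, 0.224, 0.225 };
double[] rgb = Normalize(new[] { 0.485, 0.456, 0.406 }, 3, imagenetMean, imagenetStd);

Console.WriteLine(string.Join(", ", trocr));
```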
Dispose(bool)
Disposes of resources used by this model.
protected override void Dispose(bool disposing)
Parameters
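- disposing (bool): true when called from Dispose() to release managed resources; false when called from a finalizer, in which case only unmanaged resources should be released.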
EnsureBatchDimension(Tensor<T>)
Adds a batch dimension to a 3D tensor if needed.
protected Tensor<T> EnsureBatchDimension(Tensor<T> tensor)
Parameters
- tensor (Tensor<T>): The input tensor.
Returns
- Tensor<T>
A 4D tensor with batch dimension.
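A leading dimension of size 1 does not change the row-major order of the elements, so adding the batch dimension only changes the shape. The sketch below works on a shape array. Throwing for ranks other than 3 and 4 is an assumption, since this page does not say what happens to other inputs.

```csharp
using System;

// Shape-level sketch of EnsureBatchDimension (hypothetical local version, not the
// library method): [C, H, W] becomes [1, C, H, W]; a 4D shape passes through.
// Throwing for any other rank is an assumption.
static int[] EnsureBatchDimension(int[] shape) => shape.Length switch
{
    3 => new[] { 1, shape[0], shape[1], shape[2] },
    4 => (int[])shape.Clone(),
    _ => throw new ArgumentException($"Expected a 3D or 4D shape, got rank {shape.Length}.")
};

int[] batched = EnsureBatchDimension(new[] { 3, 224, 224 });
Console.WriteLine($"[{string.Join(", ", batched)}]");  // [1, 3, 224, 224]
```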
Forward(Tensor<T>)
Performs a forward pass through the native neural network layers.
protected virtual Tensor<T> Forward(Tensor<T> input)
Parameters
- input (Tensor<T>): Preprocessed input tensor.
Returns
- Tensor<T>
Model output tensor.
PostprocessOutput(Tensor<T>)
Postprocesses model output into the final result format.
protected Tensor<T> PostprocessOutput(Tensor<T> modelOutput)
Parameters
- modelOutput (Tensor<T>): Raw output from the model.
Returns
- Tensor<T>
Postprocessed output in the expected format.
Remarks
Priority order:
1. If the user configured a pipeline via AiModelBuilder.ConfigurePostprocessing(), use it.
2. Otherwise, use the industry-standard defaults for this specific model type.
For Beginners: Model outputs often need to be transformed into a usable format. You can either let the model use its industry-standard defaults (recommended for most cases), or configure custom postprocessing:
var result = new AiModelBuilder<double, Tensor<double>, Tensor<double>>()
    .ConfigurePostprocessing(pipeline => pipeline
        .Add(new SoftmaxTransformer<double>())
        .Add(new LabelDecoder<double>(labels)))
    .Build(X, y);
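The priority order amounts to a null check. A configured pipeline replaces the default, and the default runs only when nothing was configured. In this sketch, Func delegates stand in for the user's pipeline and for ApplyDefaultPostprocessing. Postprocess and both placeholder steps are hypothetical. PreprocessDocument follows the same order, using ConfigurePreprocessing() and ApplyDefaultPreprocessing.

```csharp
#nullable enable
using System;

// Illustrative sketch (hypothetical names): a configured pipeline takes priority,
// otherwise the model-specific default (ApplyDefaultPostprocessing) runs.
static double[] Postprocess(double[] modelOutput,
                            Func<double[], double[]>? userPipeline,
                            Func<double[], double[]> modelDefault)
    => userPipeline is not null ? userPipeline(modelOutput) : modelDefault(modelOutput);

// Placeholder steps standing in for real transforms.
Func<double[], double[]> defaultStep = x => Array.ConvertAll(x, v => Math.Max(0.0, v));
Func<double[], double[]> customPipeline = x => Array.ConvertAll(x, v => v * 10);

double[] output = { -1.0, 2.0 };
Console.WriteLine(string.Join(",", Postprocess(output, null, defaultStep)));            // 0,2
Console.WriteLine(string.Join(",", Postprocess(output, customPipeline, defaultStep)));  // -10,20
```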
PreprocessDocument(Tensor<T>)
Preprocesses a raw document image for model input.
protected Tensor<T> PreprocessDocument(Tensor<T> rawImage)
Parameters
- rawImage (Tensor<T>): Raw document image tensor, shaped [channels, height, width] or [batch, channels, height, width].
Returns
- Tensor<T>
Preprocessed image suitable for model input.
Remarks
Priority order:
1. If the user configured a pipeline via AiModelBuilder.ConfigurePreprocessing(), use it.
2. Otherwise, use the industry-standard defaults for this specific model type.
For Beginners: Raw images need to be transformed before the model can process them. You can either let the model use its industry-standard defaults (recommended for most cases), or configure custom preprocessing:
var result = new AiModelBuilder<double, Tensor<double>, Tensor<double>>()
    .ConfigurePreprocessing(pipeline => pipeline
        .Add(new ImageResizer<double>(224, 224))
        .Add(new ImageNormalizer<double>()))
    .Build(X, y);
RunOnnxInference(Tensor<T>)
Runs inference using ONNX model(s).
protected virtual Tensor<T> RunOnnxInference(Tensor<T> input)
Parameters
- input (Tensor<T>): Preprocessed input tensor.
Returns
- Tensor<T>
Model output tensor.
Remarks
Override this method to implement ONNX-specific inference logic for models with complex encoder-decoder or multi-model architectures.
This method expects either OnnxModel or OnnxEncoder/OnnxDecoder to be configured, but not both. When only an encoder is set, the encoded output is returned.
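These configuration rules reduce to a small decision procedure. In the sketch below, delegates stand in for the OnnxModel, OnnxEncoder and OnnxDecoder sessions. Two details are assumptions this page does not state. The first is the exception type used for invalid configurations. The second is that the encoder output is fed directly into the decoder; autoregressive decoders usually also consume previously generated tokens, which an override would handle.

```csharp
#nullable enable
using System;

// Illustrative sketch of the RunOnnxInference configuration rules. Delegates stand in
// for the OnnxModel / OnnxEncoder / OnnxDecoder sessions; RunOnnx is hypothetical.
static double[] RunOnnx(double[] input,
                        Func<double[], double[]>? single,
                        Func<double[], double[]>? encoder,
                        Func<double[], double[]>? decoder)
{
    // Either a single model or an encoder/decoder pair, but not both.
    if (single is not null && (encoder is not null || decoder is not null))
        throw new InvalidOperationException("Configure a single model or an encoder/decoder, not both.");

    if (single is not null)
        return single(input);

    if (encoder is null)
        throw new InvalidOperationException("No ONNX model or encoder is configured.");

    double[] encoded = encoder(input);

    // Encoder only: the encoded output is the result.
    // Assumption: otherwise the decoder consumes the encoder output directly.
    return decoder is null ? encoded : decoder(encoded);
}

Func<double[], double[]> enc = x => Array.ConvertAll(x, v => v + 1);  // placeholder encoder
Func<double[], double[]> dec = x => Array.ConvertAll(x, v => v * 2);  // placeholder decoder

Console.WriteLine(string.Join(",", RunOnnx(new[] { 1.0 }, null, enc, null)));  // 2
Console.WriteLine(string.Join(",", RunOnnx(new[] { 1.0 }, null, enc, dec)));   // 4
```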
ValidateImageShape(Tensor<T>)
Validates that an input image tensor has the correct shape.
protected void ValidateImageShape(Tensor<T> image)
Parameters
- image (Tensor<T>): The tensor to validate.
Exceptions
- ArgumentNullException
If image is null.
- ArgumentException
If the tensor shape is invalid.
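This page lists the exceptions but not the exact validation rules. The sketch below checks a plain shape array rather than a Tensor<T>. It assumes a valid image shape has rank 3 ([channels, height, width]) or rank 4 ([batch, channels, height, width]), matching PreprocessDocument's input, with all dimensions positive. ArgumentNullException derives from ArgumentException, so a caller that catches ArgumentException also catches the null case.

```csharp
#nullable enable
using System;

// Hypothetical local version of ValidateImageShape working on a shape array.
// Assumed rules: rank 3 [C, H, W] or rank 4 [B, C, H, W], all dimensions positive.
static void ValidateImageShape(int[]? shape)
{
    if (shape is null)
        throw new ArgumentNullException(nameof(shape));
    if (shape.Length != 3 && shape.Length != 4)
        throw new ArgumentException($"Expected rank 3 or 4, got rank {shape.Length}.", nameof(shape));
    foreach (int dim in shape)
    {
        if (dim <= 0)
            throw new ArgumentException($"Dimensions must be positive: [{string.Join(", ", shape)}].", nameof(shape));
    }
}

ValidateImageShape(new[] { 3, 224, 224 });  // valid: returns normally

try
{
    ValidateImageShape(new[] { 224, 224 });  // rank 2 is rejected
}
catch (ArgumentException ex)
{
    Console.WriteLine(ex.GetType().Name);     // ArgumentException
}
```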