Class PretrainedModelLoader<T>

Namespace
AiDotNet.ModelLoading
Assembly
AiDotNet.dll

Loads pretrained models from various sources.

public class PretrainedModelLoader<T>

Type Parameters

T

The numeric type used for calculations.

Inheritance
PretrainedModelLoader<T>
Remarks

This class provides methods for loading pretrained model weights into AiDotNet model classes. It loads weights from local SafeTensors files and handles weight name mapping between different model formats.

For Beginners: This is your gateway to using pretrained models.

Instead of training models from scratch (which requires massive datasets and compute resources), you can load pretrained weights that others have trained.

Example usage:

var loader = new PretrainedModelLoader<float>();

// Load a pretrained VAE
var vae = new StandardVAE<float>();
await loader.LoadVAEWeights(vae, "sd-vae-ft-mse/diffusion_pytorch_model.safetensors");

// Now your VAE is ready for image encoding/decoding!

Constructors

PretrainedModelLoader(bool)

Initializes a new instance of the PretrainedModelLoader class.

public PretrainedModelLoader(bool verbose = false)

Parameters

verbose bool

Whether to log loading progress (default: false).

Methods

GetTensorInfo(string)

Gets information about tensors in a SafeTensors file.

public List<TensorMetadata> GetTensorInfo(string path)

Parameters

path string

Path to the .safetensors file.

Returns

List<TensorMetadata>

List of tensor metadata.
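
A quick way to check a weights file before loading it is to list its tensor metadata. The sketch below is illustrative only: the file path is a placeholder, and it relies solely on the signature shown above and on standard List<T> members, since the members of TensorMetadata are not described in this section.

```csharp
using System;
using AiDotNet.ModelLoading;

var loader = new PretrainedModelLoader<float>(verbose: true);

// Placeholder path (an assumption); substitute a real .safetensors file.
var info = loader.GetTensorInfo("model.safetensors");

// Only List<T>.Count is used here; TensorMetadata's own members are
// not documented in this section, so none are accessed.
Console.WriteLine($"The file describes {info.Count} tensors.");
```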

LoadAllTensors(string)

Loads all tensors from a SafeTensors file.

public Dictionary<string, Tensor<T>> LoadAllTensors(string path)

Parameters

path string

Path to the .safetensors file.

Returns

Dictionary<string, Tensor<T>>

Dictionary of all loaded tensors.

LoadTensors(string, IEnumerable<string>)

Loads specific tensors by name from a SafeTensors file.

public Dictionary<string, Tensor<T>> LoadTensors(string path, IEnumerable<string> tensorNames)

Parameters

path string

Path to the .safetensors file.

tensorNames IEnumerable<string>

Names of tensors to load.

Returns

Dictionary<string, Tensor<T>>

Dictionary of loaded tensors.
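
The following sketch shows one way to request a subset of tensors and check which ones came back. The path and the ".bias" tensor name are hypothetical. How the method treats a requested name that is absent from the file is not stated here, so the code checks the returned dictionary rather than assuming every name is present.

```csharp
using System;
using AiDotNet.ModelLoading;

var loader = new PretrainedModelLoader<float>();

// Hypothetical tensor names; use names that exist in your file.
var wanted = new[] { "encoder.conv1.weight", "encoder.conv1.bias" };

// Placeholder path (an assumption).
var loaded = loader.LoadTensors("model.safetensors", wanted);

// Report which requested names appear in the returned dictionary.
foreach (var name in wanted)
{
    var status = loaded.ContainsKey(name) ? "loaded" : "not returned";
    Console.WriteLine($"{name}: {status}");
}
```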

LoadWeights(IWeightLoadable<T>, string, WeightMapping, bool)

Loads weights using a WeightMapping instance.

public LoadResult LoadWeights(IWeightLoadable<T> model, string weightsPath, WeightMapping mapping, bool strict = false)

Parameters

model IWeightLoadable<T>

The model to load weights into.

weightsPath string

Path to the .safetensors file.

mapping WeightMapping

Weight mapping to use for name translation.

strict bool

If true, fails when weights can't be loaded. If false, skips missing weights.

Returns

LoadResult

Load result with statistics.

LoadWeights(IWeightLoadable<T>, string, Func<string, string?>?, bool)

Loads weights from a SafeTensors file into any IWeightLoadable model.

public LoadResult LoadWeights(IWeightLoadable<T> model, string weightsPath, Func<string, string?>? mapping = null, bool strict = false)

Parameters

model IWeightLoadable<T>

The model to load weights into (must implement IWeightLoadable).

weightsPath string

Path to the .safetensors file.

mapping Func<string, string?>?

Optional custom weight mapping function. Maps source names to target names.

strict bool

If true, fails when weights can't be loaded. If false, skips missing weights.

Returns

LoadResult

Load result with statistics about applied weights.

Remarks

For Beginners: This method takes pretrained weights from a file and loads them into your model. The mapping function translates between the weight names in the file (like "encoder.conv1.weight") and the names your model uses.

Exceptions

ArgumentNullException

Thrown when model or weightsPath is null.

FileNotFoundException

Thrown when the weights file doesn't exist.
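
As an illustration of the mapping parameter, here is a minimal sketch of a name-translation function. The "encoder." to "enc." renaming is an invented naming scheme, and treating a null return value as "no target for this tensor" is an assumption based only on the nullable return type, not on documented behavior. The function itself uses no library types, so it can be tried on its own.

```csharp
using System;

// Hypothetical scheme: the file prefixes encoder weights with "encoder."
// while the target model is assumed to expect "enc." instead.
static string? MapName(string sourceName)
{
    const string sourcePrefix = "encoder.";
    if (!sourceName.StartsWith(sourcePrefix, StringComparison.Ordinal))
    {
        // Assumption: null signals that this tensor has no target name.
        return null;
    }
    return "enc." + sourceName.Substring(sourcePrefix.Length);
}

Console.WriteLine(MapName("encoder.conv1.weight")); // prints "enc.conv1.weight"

// Passing it to the loader (model and path are placeholders):
// var result = loader.LoadWeights(model, "weights.safetensors", MapName, strict: false);
```

Wrapping the LoadWeights call in a try/catch for ArgumentNullException and FileNotFoundException covers the two exceptions listed above.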

ValidateWeights(string, IEnumerable<string>)

Validates that required tensors exist in a weights file.

public ValidationResult ValidateWeights(string path, IEnumerable<string> requiredTensorPatterns)

Parameters

path string

Path to the .safetensors file.

requiredTensorPatterns IEnumerable<string>

Patterns for required tensor names.

Returns

ValidationResult

Validation result.