Class PretrainedModelLoader<T>
- Namespace
- AiDotNet.ModelLoading
- Assembly
- AiDotNet.dll
Loads pretrained models from various sources.
public class PretrainedModelLoader<T>
Type Parameters
- T
The numeric type used for calculations.
- Inheritance
- object
- PretrainedModelLoader<T>
Remarks
This class provides methods for loading pretrained model weights into AiDotNet model classes. It supports loading from local SafeTensors files and handles weight name mapping between different model formats.
For Beginners: This is your gateway to using pretrained models.
Instead of training models from scratch (which requires massive datasets and compute resources), you can load pretrained weights that others have trained.
Example usage:
var loader = new PretrainedModelLoader<float>();
// Load a pretrained VAE
var vae = new StandardVAE<float>();
await loader.LoadVAEWeights(vae, "sd-vae-ft-mse/diffusion_pytorch_model.safetensors");
// Now your VAE is ready for image encoding/decoding!
Constructors
PretrainedModelLoader(bool)
Initializes a new instance of the PretrainedModelLoader class.
public PretrainedModelLoader(bool verbose = false)
Parameters
- verbose bool
Whether to log loading progress (default: false).
Methods
GetTensorInfo(string)
Gets information about tensors in a SafeTensors file.
public List<TensorMetadata> GetTensorInfo(string path)
Parameters
- path string
Path to the .safetensors file.
Returns
- List<TensorMetadata>
List of tensor metadata.
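For orientation, a SafeTensors file begins with an 8-byte little-endian header length, followed by a UTF-8 JSON header that records each tensor's dtype, shape, and byte offsets. The following is a minimal, self-contained sketch of reading that header. It is not AiDotNet's implementation: `ReadHeader` and the tuple it returns are hypothetical stand-ins, since the members of TensorMetadata are not documented on this page.

```csharp
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using System.Text.Json;

// Hypothetical helper (not part of AiDotNet): reads only the JSON header
// of a SafeTensors stream and returns name, dtype, and shape per tensor.
List<(string Name, string DType, long[] Shape)> ReadHeader(Stream stream)
{
    // BinaryReader always reads integers as little-endian.
    using var reader = new BinaryReader(stream, Encoding.UTF8, leaveOpen: true);
    ulong headerLength = reader.ReadUInt64();
    byte[] headerBytes = reader.ReadBytes(checked((int)headerLength));
    using JsonDocument doc = JsonDocument.Parse(Encoding.UTF8.GetString(headerBytes));

    var result = new List<(string Name, string DType, long[] Shape)>();
    foreach (JsonProperty entry in doc.RootElement.EnumerateObject())
    {
        // "__metadata__" is an optional free-form entry, not a tensor.
        if (entry.Name == "__metadata__") continue;

        string dtype = entry.Value.GetProperty("dtype").GetString() ?? "";
        var shape = new List<long>();
        foreach (JsonElement dim in entry.Value.GetProperty("shape").EnumerateArray())
        {
            shape.Add(dim.GetInt64());
        }
        result.Add((entry.Name, dtype, shape.ToArray()));
    }
    return result;
}

// Build a tiny in-memory payload so the sketch runs without a real file:
// header length, JSON header, then 16 bytes of (zeroed) tensor data.
string headerJson = "{\"w\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}";
byte[] header = Encoding.UTF8.GetBytes(headerJson);
using var payload = new MemoryStream();
using (var writer = new BinaryWriter(payload, Encoding.UTF8, leaveOpen: true))
{
    writer.Write((ulong)header.Length);
    writer.Write(header);
    writer.Write(new byte[16]);
}
payload.Position = 0;

var infos = ReadHeader(payload);
string shapeText = string.Join(", ", infos[0].Shape);
Console.WriteLine($"{infos[0].Name}: {infos[0].DType} [{shapeText}]"); // prints "w: F32 [2, 2]"
```

Because only the header is parsed, this kind of inspection does not read any tensor data, which keeps it cheap for large files.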
LoadAllTensors(string)
Loads all tensors from a SafeTensors file.
public Dictionary<string, Tensor<T>> LoadAllTensors(string path)
Parameters
- path string
Path to the .safetensors file.
Returns
- Dictionary<string, Tensor<T>>
Dictionary of all loaded tensors.
LoadTensors(string, IEnumerable<string>)
Loads specific tensors by name from a SafeTensors file.
public Dictionary<string, Tensor<T>> LoadTensors(string path, IEnumerable<string> tensorNames)
Parameters
- path string
Path to the .safetensors file.
- tensorNames IEnumerable<string>
Names of tensors to load.
Returns
- Dictionary<string, Tensor<T>>
Dictionary of loaded tensors.
LoadWeights(IWeightLoadable<T>, string, WeightMapping, bool)
Loads weights using a WeightMapping instance.
public LoadResult LoadWeights(IWeightLoadable<T> model, string weightsPath, WeightMapping mapping, bool strict = false)
Parameters
- model IWeightLoadable<T>
The model to load weights into.
- weightsPath string
Path to the .safetensors file.
- mapping WeightMapping
Weight mapping to use for name translation.
- strict bool
If true, fails when weights can't be loaded.
Returns
- LoadResult
Load result with statistics.
LoadWeights(IWeightLoadable<T>, string, Func<string, string?>?, bool)
Loads weights from a SafeTensors file into any IWeightLoadable model.
public LoadResult LoadWeights(IWeightLoadable<T> model, string weightsPath, Func<string, string?>? mapping = null, bool strict = false)
Parameters
- model IWeightLoadable<T>
The model to load weights into (must implement IWeightLoadable).
- weightsPath string
Path to the .safetensors file.
- mapping Func<string, string?>?
Optional custom weight mapping function. Maps source names to target names.
- strict bool
If true, fails when weights can't be loaded. If false, skips missing weights.
Returns
- LoadResult
Load result with statistics about applied weights.
Remarks
For Beginners: This method takes pretrained weights from a file and loads them into your model. The mapping function translates between the weight names in the file (like "encoder.conv1.weight") and the names your model uses.
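The name translation described above can be sketched as a plain `Func<string, string?>`. This is an illustrative example only: the prefix table and the target names such as "Encoder.Conv1." are hypothetical, and treating a null return as "no matching target, skip this tensor" is an assumption that this page does not confirm.

```csharp
using System;

// Hypothetical prefix table: names as they appear in the weights file on the
// left, names the target model is assumed to use on the right.
var prefixTable = new (string SourcePrefix, string TargetPrefix)[]
{
    ("encoder.conv1.", "Encoder.Conv1."),
    ("decoder.conv1.", "Decoder.Conv1."),
};

// Mapping function with the same shape as the `mapping` parameter.
Func<string, string?> mapName = sourceName =>
{
    foreach (var (sourcePrefix, targetPrefix) in prefixTable)
    {
        if (sourceName.StartsWith(sourcePrefix, StringComparison.Ordinal))
        {
            // Swap the prefix and keep the remainder ("weight", "bias", ...).
            return targetPrefix + sourceName.Substring(sourcePrefix.Length);
        }
    }
    // Assumption: null signals that this tensor has no target in the model.
    return null;
};

Console.WriteLine(mapName("encoder.conv1.weight"));          // prints "Encoder.Conv1.weight"
Console.WriteLine(mapName("optimizer.step") ?? "(skipped)"); // prints "(skipped)"
```

A function of this shape could be passed as the mapping argument of LoadWeights. How the loader treats a null result, particularly when strict is true, should be checked against the library itself.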
Exceptions
- ArgumentNullException
Thrown when model or weightsPath is null.
- FileNotFoundException
Thrown when the weights file doesn't exist.
ValidateWeights(string, IEnumerable<string>)
Validates that required tensors exist in a weights file.
public ValidationResult ValidateWeights(string path, IEnumerable<string> requiredTensorPatterns)
Parameters
- path string
Path to the .safetensors file.
- requiredTensorPatterns IEnumerable<string>
Patterns for required tensor names.
Returns
- ValidationResult
Validation result.
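This page does not specify the pattern syntax, so the following sketch only illustrates the general idea of checking required patterns against the tensor names in a file. It assumes glob-style patterns in which `*` matches any run of characters. `MatchesPattern` and `FindMissingPatterns` are hypothetical helpers, not AiDotNet APIs, and the tensor names are made up for the example.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;

// Assumption: '*' is a wildcard; every other character matches literally.
bool MatchesPattern(string name, string pattern)
{
    string regex = "^" + Regex.Escape(pattern).Replace("\\*", ".*") + "$";
    return Regex.IsMatch(name, regex);
}

// Returns every required pattern that no available tensor name satisfies.
List<string> FindMissingPatterns(IEnumerable<string> tensorNames, IEnumerable<string> requiredPatterns)
{
    List<string> names = tensorNames.ToList();
    return requiredPatterns
        .Where(pattern => !names.Any(name => MatchesPattern(name, pattern)))
        .ToList();
}

// Illustrative tensor names, as might be listed for a weights file.
var available = new[]
{
    "encoder.conv_in.weight",
    "encoder.conv_in.bias",
    "decoder.conv_out.weight",
};
var required = new[] { "encoder.*.weight", "decoder.*.bias" };

List<string> missing = FindMissingPatterns(available, required);
Console.WriteLine(string.Join(", ", missing)); // prints "decoder.*.bias"
```

Running a check like this before loading can surface an incompatible checkpoint early, rather than partway through applying weights.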