
Class HuggingFaceModelLoader<T>

Namespace
AiDotNet.ModelLoading
Assembly
AiDotNet.dll

Downloads and caches models from HuggingFace Hub.

public class HuggingFaceModelLoader<T>

Type Parameters

T

The numeric type used for calculations.

Inheritance
HuggingFaceModelLoader<T>

Remarks

For Beginners: HuggingFace Hub is like a library of pretrained AI models.

Instead of training models yourself (which requires huge amounts of data and compute), you can download models that others have already trained. This class handles:

  1. Downloading model files from HuggingFace
  2. Caching them locally so you don't re-download every time
  3. Loading the weights into your model

Example usage:

var loader = new HuggingFaceModelLoader<float>();

// Download and cache a pretrained VAE
var files = await loader.DownloadModelAsync("stabilityai/sd-vae-ft-mse");

// Load weights into your model
var vae = new VAEEncoder<float>();
loader.LoadWeights(vae, files["diffusion_pytorch_model.safetensors"]);

Constructors

HuggingFaceModelLoader(string?, string?, bool)

Initializes a new instance of the HuggingFaceModelLoader class.

public HuggingFaceModelLoader(string? cacheDir = null, string? apiToken = null, bool verbose = false)

Parameters

cacheDir string

Directory in which to cache downloaded models. When null, defaults to ~/.cache/huggingface/hub.

apiToken string

Optional HuggingFace API token for private models.

verbose bool

Whether to log download progress.
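Below is a minimal sketch of how the documented cacheDir default could be resolved. It assumes that "~" means the current user's profile directory on every platform. CacheDirSketch and ResolveCacheDir are hypothetical names used for illustration and are not part of AiDotNet.

```csharp
#nullable enable
using System;
using System.IO;

public static class CacheDirSketch
{
    // Hypothetical helper mirroring the documented default (~/.cache/huggingface/hub).
    // Assumption: "~" resolves to the user profile directory on all platforms.
    public static string ResolveCacheDir(string? cacheDir)
    {
        if (!string.IsNullOrWhiteSpace(cacheDir))
            return Path.GetFullPath(cacheDir); // an explicit directory always wins

        string home = Environment.GetFolderPath(Environment.SpecialFolder.UserProfile);
        return Path.Combine(home, ".cache", "huggingface", "hub");
    }
}
```

Passing an explicit cacheDir keeps downloads out of the home directory. This is useful on build servers or in containers where the home directory is read-only or ephemeral.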

Methods

ClearAllCache()

Clears all cached models.

public void ClearAllCache()

ClearCache(string)

Clears the cache for a specific model.

public void ClearCache(string repoId)

Parameters

repoId string

Repository ID whose cached files should be removed (e.g., "stabilityai/sd-vae-ft-mse").

DownloadAndLoadAsync(IWeightLoadable<T>, string, string, Func<string, string?>?, string, bool, CancellationToken)

Downloads a model and loads its weights in one operation.

public Task<LoadResult> DownloadAndLoadAsync(IWeightLoadable<T> model, string repoId, string weightsFile = "diffusion_pytorch_model.safetensors", Func<string, string?>? mapping = null, string revision = "main", bool strict = false, CancellationToken cancellationToken = default)

Parameters

model IWeightLoadable<T>

The model to load weights into.

repoId string

HuggingFace repository ID.

weightsFile string

Name of the weights file (e.g., "diffusion_pytorch_model.safetensors").

mapping Func<string, string?>

Optional function that maps tensor names in the weights file to the model's parameter names.

revision string

Git revision (default: "main").

strict bool

If true, fails when weights can't be loaded.

cancellationToken CancellationToken

Cancellation token.

Returns

Task<LoadResult>

Load result with statistics.
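The mapping parameter is a Func<string, string?>. The minimal sketch below shows one way to write such a function and how a loader might apply it to the tensor names in a file. The "vae." prefix, the skipped "model_ema." tensors, and the rule that a null result means "skip this tensor" are all illustrative assumptions, not documented AiDotNet behavior. MappingSketch is a hypothetical name.

```csharp
#nullable enable
using System;
using System.Collections.Generic;

public static class MappingSketch
{
    // Hypothetical mapping: strips a "vae." prefix and drops EMA tensors.
    // Assumption: a null result means "this tensor has no counterpart in the model".
    public static readonly Func<string, string?> StripPrefix = name =>
    {
        if (name.StartsWith("model_ema.", StringComparison.Ordinal))
            return null;
        return name.StartsWith("vae.", StringComparison.Ordinal)
            ? name.Substring("vae.".Length)
            : name;
    };

    // Applies a mapping to every tensor name, separating mapped names from skipped ones.
    public static (Dictionary<string, string> mapped, List<string> skipped) Apply(
        IEnumerable<string> tensorNames, Func<string, string?> mapping)
    {
        var mapped = new Dictionary<string, string>();
        var skipped = new List<string>();
        foreach (var name in tensorNames)
        {
            string? target = mapping(name);
            if (target is null) skipped.Add(name);
            else mapped[name] = target;
        }
        return (mapped, skipped);
    }
}
```

Keeping the mapping as a plain delegate makes it easy to unit-test against a list of tensor names before any download happens.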

DownloadFileAsync(string, string, string, CancellationToken)

Downloads a single file from HuggingFace Hub.

public Task<string> DownloadFileAsync(string repoId, string revision, string fileName, CancellationToken cancellationToken = default)

Parameters

repoId string

Repository ID.

revision string

Git revision (branch name, tag, or commit hash).

fileName string

File name within the repository.

cancellationToken CancellationToken

Cancellation token.

Returns

Task<string>

Local path to the downloaded file.
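Below is a minimal sketch of the HTTP request a single-file download could issue. It assumes that the loader uses the public Hub "resolve" URL pattern (https://huggingface.co/{repoId}/resolve/{revision}/{fileName}) and sends apiToken as a bearer token. Neither assumption is documented here. HubUrlSketch is a hypothetical name.

```csharp
#nullable enable
using System;
using System.Linq;
using System.Net.Http;
using System.Net.Http.Headers;

public static class HubUrlSketch
{
    // Assumed endpoint: the public Hub "resolve" URL pattern.
    public static string ResolveUrl(string repoId, string revision, string fileName)
    {
        // Escape the revision as one segment so "refs/pr/1" stays a single path component.
        string rev = Uri.EscapeDataString(revision);
        // Escape each segment of the file name but keep "/" between subfolders.
        string file = string.Join("/", fileName.Split('/').Select(s => Uri.EscapeDataString(s)));
        return $"https://huggingface.co/{repoId}/resolve/{rev}/{file}";
    }

    // Builds a GET request and attaches the API token (if any) as a bearer token.
    public static HttpRequestMessage BuildRequest(string repoId, string revision, string fileName, string? apiToken)
    {
        var request = new HttpRequestMessage(HttpMethod.Get, ResolveUrl(repoId, revision, fileName));
        if (!string.IsNullOrEmpty(apiToken))
            request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", apiToken);
        return request;
    }
}
```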

DownloadModelAsync(string, string, IEnumerable<string>?, CancellationToken)

Downloads a model from HuggingFace Hub.

public Task<Dictionary<string, string>> DownloadModelAsync(string repoId, string revision = "main", IEnumerable<string>? filePatterns = null, CancellationToken cancellationToken = default)

Parameters

repoId string

Repository ID (e.g., "stabilityai/sd-vae-ft-mse").

revision string

Git revision/branch (default: "main").

filePatterns IEnumerable<string>

Optional patterns to filter files (e.g., "*.safetensors").

cancellationToken CancellationToken

Cancellation token.

Returns

Task<Dictionary<string, string>>

Dictionary mapping file names to local paths.
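The filePatterns parameter takes patterns such as "*.safetensors". The minimal sketch below shows one way to implement this kind of filtering. It assumes simple glob semantics ("*" matches any run of characters, "?" matches one character, matching is case-sensitive) and that a null or empty list keeps every file. AiDotNet's actual matching rules are not documented here, and FilePatternSketch is a hypothetical name.

```csharp
#nullable enable
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;

public static class FilePatternSketch
{
    // Converts a simple glob to an anchored regular expression.
    public static Regex GlobToRegex(string pattern)
    {
        string body = Regex.Escape(pattern).Replace(@"\*", ".*").Replace(@"\?", ".");
        return new Regex("^" + body + "$");
    }

    // Keeps files that match at least one pattern; no patterns means keep everything.
    public static List<string> Filter(IEnumerable<string> files, IEnumerable<string>? patterns)
    {
        if (patterns is null) return files.ToList();
        var regexes = patterns.Select(GlobToRegex).ToList();
        if (regexes.Count == 0) return files.ToList();
        return files.Where(f => regexes.Any(r => r.IsMatch(f))).ToList();
    }
}
```

Restricting downloads to "*.safetensors" avoids fetching duplicate .bin checkpoints that many repositories ship alongside the SafeTensors weights.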

GetCachePath(string, string)

Gets the local cache path for a repository.

public string GetCachePath(string repoId, string revision = "main")

Parameters

repoId string

Repository ID.

revision string

Git revision (default: "main").

Returns

string

Local cache directory path.
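This page does not document the directory layout behind GetCachePath. The sketch below is purely illustrative. It borrows the "models--{owner}--{name}" folder naming used by the Python huggingface_hub library and adds a hypothetical per-revision subfolder. Whether AiDotNet uses this layout is an assumption, and CachePathSketch is a hypothetical name.

```csharp
using System;
using System.IO;

public static class CachePathSketch
{
    // Hypothetical layout: <cacheDir>/models--<owner>--<name>/<revision>.
    // The "models--" naming follows the Python huggingface_hub convention;
    // the per-revision folder is an illustrative assumption.
    public static string GetCachePath(string cacheDir, string repoId, string revision = "main")
    {
        if (string.IsNullOrWhiteSpace(repoId))
            throw new ArgumentException("A repository ID is required.", nameof(repoId));

        string repoFolder = "models--" + repoId.Replace("/", "--");
        // Revisions such as "refs/pr/1" contain '/', so flatten them into one folder name.
        string revisionFolder = revision.Replace('/', '_');
        return Path.Combine(cacheDir, repoFolder, revisionFolder);
    }
}
```

Keying the folder on both repository and revision lets two revisions of the same model coexist in the cache.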

IsCached(string, string, string)

Checks if a model is already cached locally.

public bool IsCached(string repoId, string weightsFile, string revision = "main")

Parameters

repoId string

Repository ID.

weightsFile string

Name of the weights file to check.

revision string

Git revision (default: "main").

Returns

bool

True if the weights file exists in cache.
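IsCached answers the "cache hit?" question. The minimal sketch below shows the cache-or-download pattern that such a check enables. It downloads to a temporary ".incomplete" file and renames the file only on success, so that an interrupted download is never mistaken for a cached file. This is an assumed design, not AiDotNet's documented behavior, and CacheOrDownloadSketch and the downloadTo delegate are hypothetical.

```csharp
#nullable enable
using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;

public static class CacheOrDownloadSketch
{
    // Returns the cached file if present; otherwise downloads it first.
    public static async Task<string> GetOrDownloadAsync(
        string cachedPath,
        Func<string, CancellationToken, Task> downloadTo,
        CancellationToken cancellationToken = default)
    {
        if (File.Exists(cachedPath))
            return cachedPath; // cache hit: no network access

        string? dir = Path.GetDirectoryName(cachedPath);
        if (!string.IsNullOrEmpty(dir))
            Directory.CreateDirectory(dir);

        // Write to a temporary name, then rename once the download has completed.
        string tempPath = cachedPath + ".incomplete";
        await downloadTo(tempPath, cancellationToken);
        File.Move(tempPath, cachedPath, overwrite: true);
        return cachedPath;
    }
}
```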

LoadWeights(IWeightLoadable<T>, string, WeightMapping, bool)

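Loads weights from a SafeTensors file into a model, using a WeightMapping instance to translate tensor names.

public LoadResult LoadWeights(IWeightLoadable<T> model, string weightsPath, WeightMapping mapping, bool strict = false)

Parameters

model IWeightLoadable<T>

The model to load weights into.

weightsPath string

Path to the .safetensors file.

mapping WeightMapping

The mapping that translates tensor names in the weights file to the model's parameter names.

strict bool

If true, fails when weights can't be loaded.

Returns

LoadResult

Load result with statistics.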

LoadWeights(IWeightLoadable<T>, string, Func<string, string?>?, bool)

Loads weights from a downloaded SafeTensors file into a model.

public LoadResult LoadWeights(IWeightLoadable<T> model, string weightsPath, Func<string, string?>? mapping = null, bool strict = false)

Parameters

model IWeightLoadable<T>

The model to load weights into.

weightsPath string

Path to the .safetensors file.

mapping Func<string, string?>

Optional function that maps tensor names in the weights file to the model's parameter names.

strict bool

If true, fails when weights can't be loaded.

Returns

LoadResult

Load result with statistics.
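A .safetensors file begins with an 8-byte little-endian unsigned length, followed by a UTF-8 JSON header of that length and then the raw tensor bytes. The header maps each tensor name to its dtype, shape, and data_offsets, plus an optional "__metadata__" entry. The minimal sketch below reads only that header. It lists the tensor names that a mapping function would need to handle. It follows the published format and is not AiDotNet's own parser. SafeTensorsHeaderSketch and TensorInfo are hypothetical names, and the 100,000,000-byte header limit is a sanity check chosen for this sketch.

```csharp
#nullable enable
using System.Buffers.Binary;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text.Json;

public sealed record TensorInfo(string Name, string DType, long[] Shape, long Begin, long End);

public static class SafeTensorsHeaderSketch
{
    // Layout: [u64 little-endian header length][UTF-8 JSON header][raw tensor bytes].
    public static List<TensorInfo> ReadHeader(Stream stream)
    {
        byte[] lenBytes = new byte[8];
        ReadFully(stream, lenBytes);
        ulong headerLen = BinaryPrimitives.ReadUInt64LittleEndian(lenBytes);
        if (headerLen > 100_000_000) throw new InvalidDataException("Header is implausibly large.");

        byte[] json = new byte[(int)headerLen];
        ReadFully(stream, json);

        using JsonDocument doc = JsonDocument.Parse(json);
        var tensors = new List<TensorInfo>();
        foreach (JsonProperty p in doc.RootElement.EnumerateObject())
        {
            if (p.Name == "__metadata__") continue; // optional string-to-string map
            JsonElement offsets = p.Value.GetProperty("data_offsets"); // relative to the data section
            tensors.Add(new TensorInfo(
                p.Name,
                p.Value.GetProperty("dtype").GetString()!,
                p.Value.GetProperty("shape").EnumerateArray().Select(d => d.GetInt64()).ToArray(),
                offsets[0].GetInt64(),
                offsets[1].GetInt64()));
        }
        return tensors;
    }

    // Reads exactly buffer.Length bytes or throws at end of stream.
    private static void ReadFully(Stream s, byte[] buffer)
    {
        int read = 0;
        while (read < buffer.Length)
        {
            int n = s.Read(buffer, read, buffer.Length - read);
            if (n == 0) throw new EndOfStreamException();
            read += n;
        }
    }
}
```

Printing the names returned by ReadHeader is a quick way to see which tensors a mapping must rename before calling LoadWeights.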