Class HuggingFaceModelLoader<T>
Namespace: AiDotNet.ModelLoading
Assembly: AiDotNet.dll
Downloads and caches models from HuggingFace Hub.
public class HuggingFaceModelLoader<T>
Type Parameters
T
The numeric type used for calculations.
Inheritance: object → HuggingFaceModelLoader<T>
Remarks
For Beginners: HuggingFace Hub is like a library of pretrained AI models.
Instead of training models yourself (which requires huge amounts of data and compute), you can download models that others have already trained. This class handles:
- Downloading model files from HuggingFace
- Caching them locally so you don't re-download every time
- Loading the weights into your model
Example usage:
```csharp
using AiDotNet.ModelLoading;

var loader = new HuggingFaceModelLoader<float>();

// Download and cache a pretrained VAE
var files = await loader.DownloadModelAsync("stabilityai/sd-vae-ft-mse");

// Load weights into your model (VAEEncoder<float> must implement IWeightLoadable<float>)
var vae = new VAEEncoder<float>();
var result = loader.LoadWeights(vae, files["diffusion_pytorch_model.safetensors"]);
```
Constructors
HuggingFaceModelLoader(string?, string?, bool)
Initializes a new instance of the HuggingFaceModelLoader class.
public HuggingFaceModelLoader(string? cacheDir = null, string? apiToken = null, bool verbose = false)
Parameters
cacheDir string?
Directory in which to cache downloaded models. Defaults to ~/.cache/huggingface/hub.
apiToken string?
Optional HuggingFace API token for private models.
verbose bool
Whether to log download progress.
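A minimal sketch of how the documented default for cacheDir could be resolved when null is passed. The class CacheDirSketch and its method ResolveCacheDir are hypothetical and not part of the AiDotNet API; the sketch only restates the default path given above using standard .NET calls.

```csharp
using System;
using System.IO;

public static class CacheDirSketch
{
    // Hypothetical helper: falls back to ~/.cache/huggingface/hub when no
    // directory is supplied, matching the default documented for cacheDir.
    public static string ResolveCacheDir(string? cacheDir)
    {
        if (!string.IsNullOrEmpty(cacheDir))
        {
            return cacheDir;
        }

        string home = Environment.GetFolderPath(Environment.SpecialFolder.UserProfile);
        return Path.Combine(home, ".cache", "huggingface", "hub");
    }
}
```

Passing an explicit cacheDir, for example a shared volume on a build server, bypasses the default entirely.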
Methods
ClearAllCache()
Clears all cached models.
public void ClearAllCache()
ClearCache(string)
Clears the cache for a specific model.
public void ClearCache(string repoId)
Parameters
repoId string
Repository ID.
DownloadAndLoadAsync(IWeightLoadable<T>, string, string, Func<string, string?>?, string, bool, CancellationToken)
Downloads a model and loads its weights in one operation.
public Task<LoadResult> DownloadAndLoadAsync(IWeightLoadable<T> model, string repoId, string weightsFile = "diffusion_pytorch_model.safetensors", Func<string, string?>? mapping = null, string revision = "main", bool strict = false, CancellationToken cancellationToken = default)
Parameters
model IWeightLoadable<T>
The model to load weights into.
repoId string
HuggingFace repository ID.
weightsFile string
Name of the weights file within the repository (default: "diffusion_pytorch_model.safetensors").
mapping Func<string, string?>?
Optional weight-name mapping.
revision string
Git revision (default: "main").
strict bool
If true, fails when weights cannot be loaded.
cancellationToken CancellationToken
Cancellation token.
Returns
Task<LoadResult>
Load result with statistics.
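The mapping parameter is a Func<string, string?>. The sketch below shows one plausible shape for such a function under two assumptions that this reference does not state: that the input is a weight name as stored in the SafeTensors file and the output is the model's parameter name, and that returning null means "skip this weight". The "encoder." and "decoder." prefixes and the MappingSketch class are purely illustrative.

```csharp
using System;

public static class MappingSketch
{
    // Hypothetical mapping: strips an illustrative "encoder." prefix and
    // returns null for decoder weights.
    // Assumption: a null result tells the loader to skip that weight.
    public static readonly Func<string, string?> EncoderOnly = name =>
    {
        if (name.StartsWith("decoder.", StringComparison.Ordinal))
        {
            return null;
        }

        return name.StartsWith("encoder.", StringComparison.Ordinal)
            ? name.Substring("encoder.".Length)
            : name;
    };
}
```

Such a function would be passed as the mapping argument, for example mapping: MappingSketch.EncoderOnly.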
DownloadFileAsync(string, string, string, CancellationToken)
Downloads a single file from HuggingFace Hub.
public Task<string> DownloadFileAsync(string repoId, string revision, string fileName, CancellationToken cancellationToken = default)
Parameters
repoId string
Repository ID.
revision string
Git revision.
fileName string
File name within the repository.
cancellationToken CancellationToken
Cancellation token.
Returns
Task<string>
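For orientation, Hugging Face Hub serves individual repository files at https://huggingface.co/{repoId}/resolve/{revision}/{fileName}, and private or gated repositories accept the token in an "Authorization: Bearer" header. The sketch below builds such a request with HttpRequestMessage. Whether DownloadFileAsync constructs its requests exactly this way is an assumption, and HubRequestSketch.BuildRequest is a hypothetical helper.

```csharp
using System;
using System.Net.Http;
using System.Net.Http.Headers;

public static class HubRequestSketch
{
    // Hypothetical helper: builds a GET request for one file in a Hub repository.
    // fileName may contain '/' for files in subfolders, so it is not escaped;
    // names are assumed to be URL-safe.
    public static HttpRequestMessage BuildRequest(string repoId, string revision, string fileName, string? apiToken)
    {
        var url = $"https://huggingface.co/{repoId}/resolve/{revision}/{fileName}";
        var request = new HttpRequestMessage(HttpMethod.Get, url);

        if (!string.IsNullOrEmpty(apiToken))
        {
            // Private and gated repositories expect a bearer token.
            request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", apiToken);
        }

        return request;
    }
}
```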
DownloadModelAsync(string, string, IEnumerable<string>?, CancellationToken)
Downloads a model from HuggingFace Hub.
public Task<Dictionary<string, string>> DownloadModelAsync(string repoId, string revision = "main", IEnumerable<string>? filePatterns = null, CancellationToken cancellationToken = default)
Parameters
repoId string
Repository ID (e.g., "stabilityai/sd-vae-ft-mse").
revision string
Git revision or branch (default: "main").
filePatterns IEnumerable<string>?
Optional patterns to filter files (e.g., "*.safetensors").
cancellationToken CancellationToken
Cancellation token.
Returns
Task<Dictionary<string, string>>
Dictionary mapping file names to local paths.
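filePatterns takes shell-style wildcards such as "*.safetensors". The following is a minimal sketch of one way such patterns could be matched against repository file names. The exact matching rules DownloadModelAsync applies are not documented here, so the glob-to-regex translation (including the choice that * also matches /) is an assumption, and PatternFilterSketch is a hypothetical class.

```csharp
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;

public static class PatternFilterSketch
{
    // Translates a glob into an anchored regex: '*' -> ".*", '?' -> ".".
    // In this sketch '*' also matches '/', so "*.safetensors" matches files in subfolders.
    private static Regex ToRegex(string glob) =>
        new Regex("^" + Regex.Escape(glob).Replace(@"\*", ".*").Replace(@"\?", ".") + "$");

    // Returns every file name when no patterns are given, mirroring the optional parameter.
    public static List<string> Filter(IEnumerable<string> fileNames, IEnumerable<string>? patterns)
    {
        if (patterns == null)
        {
            return fileNames.ToList();
        }

        var regexes = patterns.Select(ToRegex).ToList();
        return fileNames.Where(name => regexes.Any(r => r.IsMatch(name))).ToList();
    }
}
```

With the documented API, the corresponding call would be loader.DownloadModelAsync("stabilityai/sd-vae-ft-mse", filePatterns: new[] { "*.safetensors" }).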
GetCachePath(string, string)
Gets the local cache path for a repository.
public string GetCachePath(string repoId, string revision = "main")
Parameters
repoId string
Repository ID.
revision string
Git revision (default: "main").
Returns
string
Local cache directory path.
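The layout of the cache directory is not specified in this reference. For comparison, the Python huggingface_hub library stores each model repository under a folder named models--{org}--{name}. The sketch below assumes the same folder naming plus a hypothetical per-revision subfolder; treat it as an illustration only, not as the path GetCachePath actually returns. CachePathSketch is a hypothetical class.

```csharp
using System.IO;

public static class CachePathSketch
{
    // Hypothetical layout: <cacheDir>/models--<org>--<name>/<revision>.
    // The folder naming mirrors the Python huggingface_hub cache; the
    // revision subfolder is an assumption made for this sketch only.
    public static string GetCachePath(string cacheDir, string repoId, string revision = "main")
    {
        string repoFolder = "models--" + repoId.Replace("/", "--");
        return Path.Combine(cacheDir, repoFolder, revision);
    }
}
```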
IsCached(string, string, string)
Checks if a model is already cached locally.
public bool IsCached(string repoId, string weightsFile, string revision = "main")
Parameters
repoId string
Repository ID.
weightsFile string
Name of the weights file to check.
revision string
Git revision (default: "main").
Returns
bool
True if the weights file exists in the cache.
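The documented contract of IsCached is that it returns true when the weights file exists in the cache. A minimal sketch of that check follows, assuming the file sits directly under the repository's cache directory; the actual lookup path is not specified here, and CacheCheckSketch is a hypothetical class.

```csharp
using System.IO;

public static class CacheCheckSketch
{
    // Hypothetical equivalent of IsCached: the weights file counts as cached
    // when it exists directly inside the repository's cache directory
    // (for example, the path returned by GetCachePath).
    public static bool IsCached(string repoCachePath, string weightsFile) =>
        File.Exists(Path.Combine(repoCachePath, weightsFile));
}
```

A check like this lets a caller skip network access entirely when the weights are already on disk.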
LoadWeights(IWeightLoadable<T>, string, WeightMapping, bool)
Loads weights using a WeightMapping instance.
public LoadResult LoadWeights(IWeightLoadable<T> model, string weightsPath, WeightMapping mapping, bool strict = false)
Parameters
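model IWeightLoadable<T>
The model to load weights into.
weightsPath string
Path to the .safetensors file.
mapping WeightMapping
The weight mapping to apply.
strict bool
If true, fails when weights cannot be loaded.
Returns
LoadResult
Load result with statistics.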
LoadWeights(IWeightLoadable<T>, string, Func<string, string?>?, bool)
Loads weights from a downloaded SafeTensors file into a model.
public LoadResult LoadWeights(IWeightLoadable<T> model, string weightsPath, Func<string, string?>? mapping = null, bool strict = false)
Parameters
model IWeightLoadable<T>
The model to load weights into.
weightsPath string
Path to the .safetensors file.
mapping Func<string, string?>?
Optional weight-name mapping.
strict bool
If true, fails when weights cannot be loaded.
Returns
LoadResult
Load result with statistics.
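The strict flag decides whether weights that cannot be loaded are an error. The sketch below illustrates that distinction on plain name sets. LoadPlan, StrictLoadSketch.Plan, and the counts they report are hypothetical stand-ins (the members of the real LoadResult are not listed in this reference), and the rule that a null mapping result skips a weight is an assumption.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for LoadResult; the real type's members are not documented here.
public sealed record LoadPlan(int Loaded, int Skipped, List<string> Unmatched);

public static class StrictLoadSketch
{
    public static LoadPlan Plan(
        IEnumerable<string> fileWeights,
        ISet<string> modelParameters,
        Func<string, string?>? mapping,
        bool strict)
    {
        int loaded = 0, skipped = 0;
        var unmatched = new List<string>();

        foreach (var name in fileWeights)
        {
            // Without a mapping, file names are used as-is (assumption).
            string? target = mapping == null ? name : mapping(name);
            if (target == null) { skipped++; continue; } // assumed: null means skip
            if (modelParameters.Contains(target)) loaded++;
            else unmatched.Add(name); // no matching model parameter
        }

        // Strict mode turns any unmatched weight into an error; lenient mode only reports it.
        if (strict && unmatched.Count > 0)
        {
            throw new InvalidOperationException("Unmatched weights: " + string.Join(", ", unmatched));
        }

        return new LoadPlan(loaded, skipped, unmatched);
    }
}
```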