Interface IWeightLoadable<T>
Namespace: AiDotNet.Interfaces
Assembly: AiDotNet.dll
Defines the contract for models that support loading weights by name.
public interface IWeightLoadable<T>
Type Parameters
T: The numeric type used for computations.
Remarks
This interface enables loading pretrained weights from external sources like SafeTensors, HuggingFace, and ONNX files into AiDotNet models.
For Beginners: Think of this as a way to "transplant" knowledge from pretrained models. Each weight has a name (like "encoder.conv1.weight") and this interface lets us set those weights by their names.
Example:
// Load pretrained weights
var weights = safeTensorsLoader.Load("model.safetensors");
// Apply to model
if (model is IWeightLoadable<float> loadable)
{
    loadable.SetParameter("encoder.conv1.weight", weights["encoder.conv1.weight"]);
}
Properties
NamedParameterCount
Gets the total number of named parameters.
int NamedParameterCount { get; }
Property Value
- int
Methods
GetParameterNames()
Gets all parameter names in this model.
IEnumerable<string> GetParameterNames()
Returns
- IEnumerable<string>
A collection of all parameter names.
Remarks
Parameter names follow a hierarchical, dot-separated convention, for example:
- "encoder.down0.res0.conv1.weight"
- "encoder.down0.res0.conv1.bias"
- "decoder.up3.norm.gamma"
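Because the names are dot-separated paths, they can be grouped by their leading segment to get a quick overview of a model. The following is a minimal sketch using only standard .NET collections. The hard-coded list is a stand-in for the result of GetParameterNames(), and the grouping rule (first segment as the module) is an assumption for illustration, not something the interface prescribes.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Sample names taken from the remarks above. In real code these would
// come from model.GetParameterNames().
var names = new List<string>
{
    "encoder.down0.res0.conv1.weight",
    "encoder.down0.res0.conv1.bias",
    "decoder.up3.norm.gamma",
};

// Group by the first dot-separated segment (assumed to be the top-level module).
var byModule = names
    .GroupBy(n => n.Split('.')[0])
    .ToDictionary(g => g.Key, g => g.ToList());

foreach (var (module, members) in byModule)
{
    Console.WriteLine($"{module}: {members.Count} parameter(s)");
}
```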
GetParameterShape(string)
Gets the expected shape for a parameter.
int[]? GetParameterShape(string name)
Parameters
name (string): The parameter name.
Returns
- int[]
The expected shape, or null if the parameter doesn't exist.
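Since SetParameter is documented to throw when shapes differ, it can be useful to compare shapes up front. Below is a small sketch of such a check. ShapeMatches is a hypothetical helper, and the literal arrays stand in for the value returned by GetParameterShape and for the dimensions of the tensor being loaded; how those dimensions are read from a Tensor<T> is not covered on this page.

```csharp
#nullable enable
using System;
using System.Linq;

// Hypothetical helper: true only when the parameter exists (expected is not
// null) and both shapes have the same rank and the same size in every dimension.
static bool ShapeMatches(int[]? expected, int[] actual) =>
    expected is not null && expected.SequenceEqual(actual);

// Stand-in for model.GetParameterShape("encoder.conv1.weight").
int[]? expectedShape = new[] { 64, 3, 3, 3 };

// Stand-in for the dimensions of the tensor read from the weights file.
int[] candidateShape = { 64, 3, 3, 3 };

Console.WriteLine(ShapeMatches(expectedShape, candidateShape)
    ? "Shapes agree; SetParameter should not throw for a shape mismatch."
    : "Shape mismatch or unknown parameter; skip or report this weight.");
```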
LoadWeights(Dictionary<string, Tensor<T>>, Func<string, string?>?, bool)
Loads weights from a dictionary of tensors using optional name mapping.
WeightLoadResult LoadWeights(Dictionary<string, Tensor<T>> weights, Func<string, string?>? mapping = null, bool strict = false)
Parameters
weights (Dictionary<string, Tensor<T>>): Dictionary mapping each weight name to its tensor.
mapping (Func<string, string?>?): Optional function that maps source names to target names.
strict (bool): If true, throws an exception when any mapped weight fails to load.
Returns
- WeightLoadResult
Load result with statistics.
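The mapping argument is an ordinary delegate, so it can be written as a lambda. The sketch below shows one possible mapping that strips a "model." prefix from source names. The prefix and the skipped ".num_batches_tracked" suffix are illustrative assumptions about the source file, and treating a null return as "skip this weight" is inferred from the nullable return type rather than stated on this page.

```csharp
#nullable enable
using System;

// Assumed prefix used by the source checkpoint; adjust to the actual file.
const string Prefix = "model.";

Func<string, string?> mapping = sourceName =>
{
    // Assumption: returning null tells the loader to ignore this entry.
    if (sourceName.EndsWith(".num_batches_tracked", StringComparison.Ordinal))
    {
        return null;
    }

    // Strip the prefix when present; otherwise pass the name through unchanged.
    return sourceName.StartsWith(Prefix, StringComparison.Ordinal)
        ? sourceName.Substring(Prefix.Length)
        : sourceName;
};

Console.WriteLine(mapping("model.encoder.conv1.weight"));
```

A delegate like this would be passed as the second argument, for example loadable.LoadWeights(weights, mapping, strict: false), following the signature shown above.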
SetParameter(string, Tensor<T>)
Sets a parameter tensor by name.
bool SetParameter(string name, Tensor<T> value)
Parameters
name (string): The parameter name.
value (Tensor<T>): The tensor value to set.
Returns
- bool
True if the parameter was set successfully, false if the name was not found.
Exceptions
- ArgumentException
Thrown when the tensor shape doesn't match the expected shape.
TryGetParameter(string, out Tensor<T>?)
Tries to get a parameter tensor by name.
bool TryGetParameter(string name, out Tensor<T>? tensor)
Parameters
name (string): The parameter name.
tensor (Tensor<T>?): The parameter tensor, if found.
Returns
- bool
True if the parameter was found, false otherwise.
ValidateWeights(IEnumerable<string>, Func<string, string?>?)
Validates that a set of weight names can be loaded into this model.
WeightLoadValidation ValidateWeights(IEnumerable<string> weightNames, Func<string, string?>? mapping = null)
Parameters
weightNames (IEnumerable<string>): Names of the weights to validate.
mapping (Func<string, string?>?): Optional function that maps source names to target names.
Returns
- WeightLoadValidation
Validation result with matched and unmatched names.
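To make the idea of matched and unmatched names concrete, here is a rough sketch of the kind of comparison such a validation involves, written with plain collections. It is not the library's implementation. The name sets, the "model." prefix, and the variables matched, unmatched, and missing are all hypothetical, and the members of WeightLoadValidation are not documented on this page.

```csharp
#nullable enable
using System;
using System.Collections.Generic;
using System.Linq;

// Stand-in for model.GetParameterNames().
var modelNames = new HashSet<string> { "encoder.conv1.weight", "encoder.conv1.bias" };

// Stand-in for the keys found in a weights file.
var fileNames = new[] { "model.encoder.conv1.weight", "model.decoder.extra.weight" };

// Illustrative mapping: strip an assumed "model." prefix.
const string Prefix = "model.";
Func<string, string?>? mapping = n =>
    n.StartsWith(Prefix, StringComparison.Ordinal) ? n.Substring(Prefix.Length) : n;

var matched = new List<string>();    // target names that exist in the model
var unmatched = new List<string>();  // source names with no target in the model

foreach (var source in fileNames)
{
    // With no mapping supplied, the source name is used as the target name.
    var target = mapping is null ? source : mapping(source);
    if (target is not null && modelNames.Contains(target))
    {
        matched.Add(target);
    }
    else
    {
        unmatched.Add(source);
    }
}

// Model parameters that no entry in the file would fill.
var missing = modelNames.Except(matched).ToList();

Console.WriteLine($"matched={matched.Count}, unmatched={unmatched.Count}, missing={missing.Count}");
```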