Class Wav2Vec2Model<T>
- Namespace
- AiDotNet.Audio.SpeechRecognition
- Assembly
- AiDotNet.dll
Wav2Vec2 self-supervised speech recognition model.
public class Wav2Vec2Model<T> : AudioNeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable, ISpeechRecognizer<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>
Type Parameters
T: The numeric type used for calculations.
- Inheritance
- AudioNeuralNetworkBase<T> → Wav2Vec2Model<T>
- Implements
- Inherited Members
- Extension Methods
Remarks
Wav2Vec2 is a self-supervised learning model for speech recognition developed by Meta AI. It learns representations from raw audio through contrastive learning, then can be fine-tuned for speech recognition tasks.
For Beginners: Wav2Vec2 works differently from traditional speech recognition:
- It processes raw audio directly (no mel spectrograms needed)
- It learns speech patterns from unlabeled audio data
- It can be fine-tuned with small amounts of labeled data
Architecture:
- Convolutional feature encoder: Processes raw audio into features
- Transformer encoder: Captures long-range dependencies in speech
- CTC head: Aligns speech to text (Connectionist Temporal Classification)
Two ways to use this class:
- ONNX Mode: Load pretrained Wav2Vec2 models for fast inference
- Native Mode: Train your own speech recognition model from scratch
ONNX Mode Example:
var wav2vec2 = new Wav2Vec2Model<float>(
    architecture,
    modelPath: "path/to/wav2vec2.onnx");
var result = wav2vec2.Transcribe(audioTensor);
Console.WriteLine(result.Text);
Training Mode Example:
var wav2vec2 = new Wav2Vec2Model<float>(architecture);
for (int epoch = 0; epoch < 100; epoch++)
{
    foreach (var (audio, tokens) in trainingData)
    {
        wav2vec2.Train(audio, tokens);
    }
}
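In the training example above, each tokens tensor holds CTC target token indices rather than raw text. Below is a minimal sketch of encoding a transcript against a character vocabulary; the vocabulary layout and the blank token at index 0 are assumptions, not confirmed by the API, and packing the indices into the Tensor<T> passed to Train depends on the library's tensor construction API.
// Hypothetical character vocabulary: index 0 is assumed to be the
// CTC blank, followed by space and the lower-case English alphabet.
string[] vocabulary =
{
    "<blank>", " ", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
    "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w",
    "x", "y", "z"
};

// Map a transcript to CTC target indices, skipping unknown characters.
int[] EncodeTranscript(string transcript)
{
    var indices = new List<int>();
    foreach (char c in transcript.ToLowerInvariant())
    {
        int index = Array.IndexOf(vocabulary, c.ToString());
        if (index >= 0)
        {
            indices.Add(index);
        }
    }
    return indices.ToArray();
}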
Constructors
Wav2Vec2Model(NeuralNetworkArchitecture<T>, string?, int, int, int, int, int, int, string[]?, IOptimizer<T, Tensor<T>, Tensor<T>>?, ILossFunction<T>?)
Creates a Wav2Vec2 network for training from scratch using native layers.
public Wav2Vec2Model(NeuralNetworkArchitecture<T> architecture, string? language = "en", int sampleRate = 16000, int maxAudioLengthSeconds = 30, int hiddenDim = 768, int numTransformerLayers = 12, int numHeads = 12, int ffDim = 3072, string[]? vocabulary = null, IOptimizer<T, Tensor<T>, Tensor<T>>? optimizer = null, ILossFunction<T>? lossFunction = null)
Parameters
architecture (NeuralNetworkArchitecture<T>): The neural network architecture configuration.
language (string): Target language code (e.g., "en", "es"). Default is "en".
sampleRate (int): Audio sample rate in Hz. Default is 16000.
maxAudioLengthSeconds (int): Maximum audio length to process. Default is 30 seconds.
hiddenDim (int): Hidden dimension for the transformer. Default is 768.
numTransformerLayers (int): Number of transformer layers. Default is 12.
numHeads (int): Number of attention heads. Default is 12.
ffDim (int): Feed-forward dimension. Default is 3072.
vocabulary (string[]): CTC vocabulary for decoding. If null, uses the default English alphabet.
optimizer (IOptimizer<T, Tensor<T>, Tensor<T>>): Optimizer for training. If null, uses Adam with default settings.
lossFunction (ILossFunction<T>): Loss function for training. If null, uses CTC loss.
Remarks
For Beginners: Use this constructor to train a speech recognition model from scratch.
Training Wav2Vec2 typically involves:
- Pre-training on unlabeled audio (self-supervised)
- Fine-tuning on labeled transcription data
Example:
var wav2vec2 = new Wav2Vec2Model<float>(
    architecture,
    language: "en",
    hiddenDim: 768,
    numTransformerLayers: 12);

// Training loop
for (int epoch = 0; epoch < numEpochs; epoch++)
{
    foreach (var (audio, tokens) in trainingData)
    {
        wav2vec2.Train(audio, tokens);
    }
}
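The vocabulary parameter lets you supply your own CTC character set, for example for a non-English language. A minimal sketch follows, assuming a Spanish character set; whether a blank or padding token must be included explicitly is an assumption about how the library builds its CTC head.
// Hypothetical Spanish character vocabulary.
string[] spanishVocabulary =
{
    " ", "a", "á", "b", "c", "d", "e", "é", "f", "g", "h", "i", "í",
    "j", "k", "l", "m", "n", "ñ", "o", "ó", "p", "q", "r", "s", "t",
    "u", "ú", "ü", "v", "w", "x", "y", "z"
};

var wav2vec2Es = new Wav2Vec2Model<float>(
    architecture,
    language: "es",
    vocabulary: spanishVocabulary);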
Wav2Vec2Model(NeuralNetworkArchitecture<T>, string, string?, int, int, string[]?, OnnxModelOptions?)
Creates a Wav2Vec2 network using a pretrained ONNX model.
public Wav2Vec2Model(NeuralNetworkArchitecture<T> architecture, string modelPath, string? language = "en", int sampleRate = 16000, int maxAudioLengthSeconds = 30, string[]? vocabulary = null, OnnxModelOptions? onnxOptions = null)
Parameters
architecture (NeuralNetworkArchitecture<T>): The neural network architecture configuration.
modelPath (string): Path to the ONNX model file.
language (string): Target language code (e.g., "en", "es"). Default is "en".
sampleRate (int): Audio sample rate in Hz. Wav2Vec2 expects 16000.
maxAudioLengthSeconds (int): Maximum audio length to process. Default is 30 seconds.
vocabulary (string[]): CTC vocabulary for decoding. If null, uses the default English alphabet.
onnxOptions (OnnxModelOptions): ONNX runtime options.
Remarks
For Beginners: Use this constructor when you have a pretrained Wav2Vec2 ONNX model.
You can get ONNX models from:
- HuggingFace: facebook/wav2vec2-base-960h, etc.
- Convert from PyTorch using ONNX export tools
Example:
var wav2vec2 = new Wav2Vec2Model<float>(
    architecture,
    modelPath: "wav2vec2-base.onnx",
    language: "en");
Properties
IsReady
Gets whether the model is ready for inference.
public bool IsReady { get; }
Property Value
- bool
Language
Gets the target language for transcription.
public string? Language { get; }
Property Value
- string
MaxAudioLengthSeconds
Gets the maximum audio length in seconds.
public int MaxAudioLengthSeconds { get; }
Property Value
- int
SupportedLanguages
Gets the list of languages supported by this model.
public IReadOnlyList<string> SupportedLanguages { get; }
Property Value
- IReadOnlyList<string>
SupportsStreaming
Gets whether this model supports real-time streaming transcription.
public bool SupportsStreaming { get; }
Property Value
- bool
SupportsWordTimestamps
Gets whether this model can identify timestamps for each word.
public bool SupportsWordTimestamps { get; }
Property Value
- bool
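A small capability check before requesting timestamps (a sketch; assumes the wav2vec2 and audioTensor variables from the remarks above):
bool wantTimestamps = wav2vec2.SupportsWordTimestamps;
var result = wav2vec2.Transcribe(audioTensor, includeTimestamps: wantTimestamps);
Console.WriteLine(result.Text);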
Methods
CreateNewInstance()
Creates a new instance of this model for cloning.
protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
DeserializeNetworkSpecificData(BinaryReader)
Deserializes network-specific data.
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
Parameters
reader (BinaryReader)
DetectLanguage(Tensor<T>)
Detects the language spoken in the audio.
public string DetectLanguage(Tensor<T> audio)
Parameters
audio (Tensor<T>)
Returns
- string
DetectLanguageProbabilities(Tensor<T>)
Gets language detection probabilities for the audio.
public IReadOnlyDictionary<string, T> DetectLanguageProbabilities(Tensor<T> audio)
Parameters
audio (Tensor<T>)
Returns
- IReadOnlyDictionary<string, T>
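A minimal usage sketch for the two detection methods (assumes an existing wav2vec2 instance and a Tensor<float> named audioTensor; how many languages a single-language model reports is an assumption about its configuration):
// Most likely language as a code such as "en".
string detected = wav2vec2.DetectLanguage(audioTensor);
Console.WriteLine($"Detected language: {detected}");

// Per-language probabilities for the same audio.
var probabilities = wav2vec2.DetectLanguageProbabilities(audioTensor);
foreach (var entry in probabilities)
{
    Console.WriteLine($"{entry.Key}: {entry.Value}");
}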
Dispose(bool)
Disposes the model and releases resources.
protected override void Dispose(bool disposing)
Parameters
disposing (bool)
GetModelMetadata()
Gets metadata about the model.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
InitializeLayers()
Initializes layers for ONNX inference mode.
protected override void InitializeLayers()
PostprocessOutput(Tensor<T>)
Postprocesses model output.
protected override Tensor<T> PostprocessOutput(Tensor<T> modelOutput)
Parameters
modelOutput (Tensor<T>)
Returns
- Tensor<T>
Predict(Tensor<T>)
Makes a prediction using the model.
public override Tensor<T> Predict(Tensor<T> input)
Parameters
input (Tensor<T>)
Returns
- Tensor<T>
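A minimal sketch of calling Predict directly (assumes audioTensor from the earlier examples; interpreting the output, for example as per-frame scores over the CTC vocabulary, is an assumption about this implementation, and Transcribe remains the higher-level path to plain text):
// Forward pass; the returned tensor is the network output,
// not decoded text (Transcribe performs the decoding step).
Tensor<float> output = wav2vec2.Predict(audioTensor);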
PreprocessAudio(Tensor<T>)
Preprocesses raw audio for model input.
protected override Tensor<T> PreprocessAudio(Tensor<T> rawAudio)
Parameters
rawAudio (Tensor<T>)
Returns
- Tensor<T>
SerializeNetworkSpecificData(BinaryWriter)
Serializes network-specific data.
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
Parameters
writer (BinaryWriter)
StartStreamingSession(string?)
Starts a streaming transcription session.
public IStreamingTranscriptionSession<T> StartStreamingSession(string? language = null)
Parameters
language (string)
Returns
- IStreamingTranscriptionSession<T>
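A minimal sketch of starting a streaming session (hypothetical usage; the members of IStreamingTranscriptionSession<T> used to feed audio chunks and read partial results are not shown because they depend on that interface's definition):
// Only start a streaming session if the model supports it.
if (wav2vec2.SupportsStreaming)
{
    var session = wav2vec2.StartStreamingSession(language: "en");
    // Feed audio chunks and read partial transcripts through the
    // IStreamingTranscriptionSession<T> members (not shown here).
}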
Train(Tensor<T>, Tensor<T>)
Trains the model on a single batch.
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)
Parameters
input (Tensor<T>)
expectedOutput (Tensor<T>)
Transcribe(Tensor<T>, string?, bool)
Transcribes audio to text.
public TranscriptionResult<T> Transcribe(Tensor<T> audio, string? language = null, bool includeTimestamps = false)
Parameters
audio (Tensor<T>)
language (string)
includeTimestamps (bool)
Returns
- TranscriptionResult<T>
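A slightly fuller call than the one in the class remarks, passing an explicit language and requesting timestamps (a sketch; only result.Text is shown because the other members of TranscriptionResult<T> are not documented in this section):
var result = wav2vec2.Transcribe(
    audioTensor,
    language: "en",
    includeTimestamps: true);

Console.WriteLine(result.Text);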
TranscribeAsync(Tensor<T>, string?, bool, CancellationToken)
Transcribes audio to text asynchronously.
public Task<TranscriptionResult<T>> TranscribeAsync(Tensor<T> audio, string? language = null, bool includeTimestamps = false, CancellationToken cancellationToken = default)
Parameters
audio (Tensor<T>)
language (string)
includeTimestamps (bool)
cancellationToken (CancellationToken)
Returns
- Task<TranscriptionResult<T>>
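An asynchronous sketch with a timeout-based cancellation token (assumes audioTensor from the earlier examples and an async calling context):
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(60));

var result = await wav2vec2.TranscribeAsync(
    audioTensor,
    language: "en",
    includeTimestamps: false,
    cancellationToken: cts.Token);

Console.WriteLine(result.Text);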
UpdateParameters(Vector<T>)
Updates model parameters by applying gradient descent.
public override void UpdateParameters(Vector<T> gradients)
Parameters
gradients (Vector<T>)