Class Wav2Vec2Model<T>

Namespace: AiDotNet.Audio.SpeechRecognition
Assembly: AiDotNet.dll

Wav2Vec2 self-supervised speech recognition model.

public class Wav2Vec2Model<T> : AudioNeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable, ISpeechRecognizer<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>

Type Parameters

T

The numeric type used for calculations.

Inheritance
AudioNeuralNetworkBase<T>
Wav2Vec2Model<T>
Implements
INeuralNetworkModel<T>
INeuralNetwork<T>
IInterpretableModel<T>
IInputGradientComputable<T>
IDisposable
ISpeechRecognizer<T>
IFullModel<T, Tensor<T>, Tensor<T>>
IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>
IModelSerializer
ICheckpointableModel
IParameterizable<T, Tensor<T>, Tensor<T>>
IFeatureAware
IFeatureImportance<T>
ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>
IGradientComputable<T, Tensor<T>, Tensor<T>>
IJitCompilable<T>

Remarks

Wav2Vec2 is a self-supervised learning model for speech recognition developed by Meta AI. It learns representations from raw audio through contrastive learning, then can be fine-tuned for speech recognition tasks.

For Beginners: Wav2Vec2 works differently from traditional speech recognition:

  1. It processes raw audio directly (no mel spectrograms needed)
  2. It learns speech patterns from unlabeled audio data
  3. It can be fine-tuned with small amounts of labeled data

Architecture:

  • Convolutional feature encoder: Processes raw audio into features
  • Transformer encoder: Captures long-range dependencies in speech
  • CTC head: Aligns speech frames to text via Connectionist Temporal Classification (see the decoding sketch below)
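
To make the CTC step concrete, here is a minimal greedy CTC decoding sketch. It is illustrative only: the method name, the [timeSteps, vocabSize] logits layout, and the blank-at-index-0 convention are assumptions, not part of this API (Transcribe performs decoding internally).

// Hypothetical greedy CTC decode: take the best token per time step,
// collapse consecutive repeats, then drop blanks.
string GreedyCtcDecode(float[,] logits, string[] vocabulary, int blankIndex = 0)
{
    var text = new System.Text.StringBuilder();
    int previous = -1;
    for (int t = 0; t < logits.GetLength(0); t++)
    {
        // Pick the most likely vocabulary entry at this time step.
        int best = 0;
        for (int v = 1; v < logits.GetLength(1); v++)
            if (logits[t, v] > logits[t, best]) best = v;

        // CTC rule: emit only when the token is not blank and not a repeat.
        if (best != blankIndex && best != previous)
            text.Append(vocabulary[best]);
        previous = best;
    }
    return text.ToString();
}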

Two ways to use this class:

  1. ONNX Mode: Load pretrained Wav2Vec2 models for fast inference
  2. Native Mode: Train your own speech recognition model from scratch

ONNX Mode Example:

var wav2vec2 = new Wav2Vec2Model<float>(
    architecture,
    modelPath: "path/to/wav2vec2.onnx");
var result = wav2vec2.Transcribe(audioTensor);
Console.WriteLine(result.Text);

Training Mode Example:

var wav2vec2 = new Wav2Vec2Model<float>(architecture);
for (int epoch = 0; epoch < 100; epoch++)
{
    foreach (var (audio, tokens) in trainingData)
    {
        wav2vec2.Train(audio, tokens);
    }
}

Constructors

Wav2Vec2Model(NeuralNetworkArchitecture<T>, string?, int, int, int, int, int, int, string[]?, IOptimizer<T, Tensor<T>, Tensor<T>>?, ILossFunction<T>?)

Creates a Wav2Vec2 network for training from scratch using native layers.

public Wav2Vec2Model(NeuralNetworkArchitecture<T> architecture, string? language = "en", int sampleRate = 16000, int maxAudioLengthSeconds = 30, int hiddenDim = 768, int numTransformerLayers = 12, int numHeads = 12, int ffDim = 3072, string[]? vocabulary = null, IOptimizer<T, Tensor<T>, Tensor<T>>? optimizer = null, ILossFunction<T>? lossFunction = null)

Parameters

architecture NeuralNetworkArchitecture<T>

The neural network architecture configuration.

language string

Target language code (e.g., "en", "es"). Default is "en".

sampleRate int

Audio sample rate in Hz. Default is 16000.

maxAudioLengthSeconds int

Maximum audio length to process. Default is 30 seconds.

hiddenDim int

Hidden dimension for transformer. Default is 768.

numTransformerLayers int

Number of transformer layers. Default is 12.

numHeads int

Number of attention heads. Default is 12.

ffDim int

Feed-forward dimension. Default is 3072.

vocabulary string[]

CTC vocabulary for decoding. If null, uses default English alphabet.

optimizer IOptimizer<T, Tensor<T>, Tensor<T>>

Optimizer for training. If null, uses Adam with default settings.

lossFunction ILossFunction<T>

Loss function for training. If null, uses CTC loss.

Remarks

For Beginners: Use this constructor to train a speech recognition model from scratch.

Training Wav2Vec2 typically involves:

  1. Pre-training on unlabeled audio (self-supervised)
  2. Fine-tuning on labeled transcription data

Example:

var wav2vec2 = new Wav2Vec2Model<float>(
    architecture,
    language: "en",
    hiddenDim: 768,
    numTransformerLayers: 12);

// Training loop
for (int epoch = 0; epoch < numEpochs; epoch++)
{
    foreach (var (audio, tokens) in trainingData)
    {
        wav2vec2.Train(audio, tokens);
    }
}

Wav2Vec2Model(NeuralNetworkArchitecture<T>, string, string?, int, int, string[]?, OnnxModelOptions?)

Creates a Wav2Vec2 network using a pretrained ONNX model.

public Wav2Vec2Model(NeuralNetworkArchitecture<T> architecture, string modelPath, string? language = "en", int sampleRate = 16000, int maxAudioLengthSeconds = 30, string[]? vocabulary = null, OnnxModelOptions? onnxOptions = null)

Parameters

architecture NeuralNetworkArchitecture<T>

The neural network architecture configuration.

modelPath string

Path to the ONNX model file.

language string

Target language code (e.g., "en", "es"). Default is "en".

sampleRate int

Audio sample rate in Hz. Wav2Vec2 expects 16000.

maxAudioLengthSeconds int

Maximum audio length to process. Default is 30 seconds.

vocabulary string[]

CTC vocabulary for decoding. If null, uses default English alphabet.

onnxOptions OnnxModelOptions

ONNX runtime options.

Remarks

For Beginners: Use this constructor when you have a pretrained Wav2Vec2 ONNX model.

You can get ONNX models from:

  • HuggingFace: facebook/wav2vec2-base-960h, etc.
  • Convert from PyTorch using ONNX export tools

Example:

var wav2vec2 = new Wav2Vec2Model<float>(
    architecture,
    modelPath: "wav2vec2-base.onnx",
    language: "en");

Properties

IsReady

Gets whether the model is ready for inference.

public bool IsReady { get; }

Property Value

bool

Language

Gets the target language for transcription.

public string? Language { get; }

Property Value

string

MaxAudioLengthSeconds

Gets the maximum audio length in seconds.

public int MaxAudioLengthSeconds { get; }

Property Value

int

SupportedLanguages

Gets the list of languages supported by this model.

public IReadOnlyList<string> SupportedLanguages { get; }

Property Value

IReadOnlyList<string>

SupportsStreaming

Gets whether this model supports real-time streaming transcription.

public bool SupportsStreaming { get; }

Property Value

bool

SupportsWordTimestamps

Gets whether this model can identify timestamps for each word.

public bool SupportsWordTimestamps { get; }

Property Value

bool
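
Before transcribing, you can check what a given instance supports using only the properties above:

if (wav2vec2.IsReady)
{
    Console.WriteLine($"Language: {wav2vec2.Language}");
    Console.WriteLine($"Max audio: {wav2vec2.MaxAudioLengthSeconds}s");
    Console.WriteLine($"Streaming: {wav2vec2.SupportsStreaming}");
    Console.WriteLine($"Word timestamps: {wav2vec2.SupportsWordTimestamps}");
    Console.WriteLine($"Languages: {string.Join(", ", wav2vec2.SupportedLanguages)}");
}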

Methods

CreateNewInstance()

Creates a new instance of this model for cloning.

protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()

Returns

IFullModel<T, Tensor<T>, Tensor<T>>

DeserializeNetworkSpecificData(BinaryReader)

Deserializes network-specific data.

protected override void DeserializeNetworkSpecificData(BinaryReader reader)

Parameters

reader BinaryReader

DetectLanguage(Tensor<T>)

Detects the language spoken in the audio.

public string DetectLanguage(Tensor<T> audio)

Parameters

audio Tensor<T>

The raw audio tensor to analyze.

Returns

string

The detected language code (e.g., "en").

DetectLanguageProbabilities(Tensor<T>)

Gets language detection probabilities for the audio.

public IReadOnlyDictionary<string, T> DetectLanguageProbabilities(Tensor<T> audio)

Parameters

audio Tensor<T>

The raw audio tensor to analyze.

Returns

IReadOnlyDictionary<string, T>

A dictionary mapping language codes to their detection probabilities.
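
Usage sketch for both detection methods (audioTensor is assumed to hold raw audio at the model's sample rate):

// Single best guess, or the full probability distribution.
string detected = wav2vec2.DetectLanguage(audioTensor);
var probabilities = wav2vec2.DetectLanguageProbabilities(audioTensor);
foreach (var pair in probabilities)
{
    Console.WriteLine($"{pair.Key}: {pair.Value}");
}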

Dispose(bool)

Disposes the model and releases resources.

protected override void Dispose(bool disposing)

Parameters

disposing bool

True when called from Dispose(); false when called from a finalizer.

GetModelMetadata()

Gets metadata about the model.

public override ModelMetadata<T> GetModelMetadata()

Returns

ModelMetadata<T>

InitializeLayers()

Initializes layers for ONNX inference mode.

protected override void InitializeLayers()

PostprocessOutput(Tensor<T>)

Postprocesses model output.

protected override Tensor<T> PostprocessOutput(Tensor<T> modelOutput)

Parameters

modelOutput Tensor<T>

Returns

Tensor<T>

Predict(Tensor<T>)

Makes a prediction using the model.

public override Tensor<T> Predict(Tensor<T> input)

Parameters

input Tensor<T>

The input audio tensor.

Returns

Tensor<T>

The model output tensor (typically CTC logits over the vocabulary).
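
For lower-level access than Transcribe, Predict returns the raw output tensor. A minimal sketch (the Shape member and the [timeSteps, vocabSize] interpretation are assumptions for illustration):

// Run the model directly; Transcribe normally decodes this for you.
var logits = wav2vec2.Predict(audioTensor);
// For a CTC model this is typically [timeSteps, vocabSize] scores.
Console.WriteLine($"Output shape: [{string.Join(", ", logits.Shape)}]");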

PreprocessAudio(Tensor<T>)

Preprocesses raw audio for model input.

protected override Tensor<T> PreprocessAudio(Tensor<T> rawAudio)

Parameters

rawAudio Tensor<T>

Returns

Tensor<T>

SerializeNetworkSpecificData(BinaryWriter)

Serializes network-specific data.

protected override void SerializeNetworkSpecificData(BinaryWriter writer)

Parameters

writer BinaryWriter

StartStreamingSession(string?)

Starts a streaming transcription session.

public IStreamingTranscriptionSession<T> StartStreamingSession(string? language = null)

Parameters

language string

Optional language code for the session. If null, uses the model's default language.

Returns

IStreamingTranscriptionSession<T>

A session for incremental, real-time transcription.
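
A minimal sketch of starting a session; check SupportsStreaming first. Feeding audio chunks and reading partial results go through the IStreamingTranscriptionSession<T> members, which are not documented on this page:

if (wav2vec2.SupportsStreaming)
{
    var session = wav2vec2.StartStreamingSession(language: "en");
    // Feed audio chunks and read partial transcripts through the
    // session's members (see IStreamingTranscriptionSession<T>).
}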

Train(Tensor<T>, Tensor<T>)

Trains the model on a single batch.

public override void Train(Tensor<T> input, Tensor<T> expectedOutput)

Parameters

input Tensor<T>

The audio tensor for this batch.

expectedOutput Tensor<T>

The expected token sequence (indices into the CTC vocabulary).
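
A single-step sketch. EncodeTranscript is a hypothetical helper: how you map a transcript to token indices depends on your CTC vocabulary:

// One training step: raw audio in, target token indices as the label.
// EncodeTranscript (hypothetical) maps characters to vocabulary indices.
Tensor<float> tokens = EncodeTranscript("hello world", vocabulary);
wav2vec2.Train(audioTensor, tokens);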

Transcribe(Tensor<T>, string?, bool)

Transcribes audio to text.

public TranscriptionResult<T> Transcribe(Tensor<T> audio, string? language = null, bool includeTimestamps = false)

Parameters

audio Tensor<T>

The audio tensor to transcribe.

language string

Optional language code. If null, uses the model's default language.

includeTimestamps bool

Whether to include word timestamps in the result (see SupportsWordTimestamps).

Returns

TranscriptionResult<T>

The transcription result, including the recognized text.
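
Usage sketch (only the Text member of the result is shown on this page):

var result = wav2vec2.Transcribe(
    audioTensor,
    language: "en",
    includeTimestamps: true);
Console.WriteLine(result.Text);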

TranscribeAsync(Tensor<T>, string?, bool, CancellationToken)

Transcribes audio to text asynchronously.

public Task<TranscriptionResult<T>> TranscribeAsync(Tensor<T> audio, string? language = null, bool includeTimestamps = false, CancellationToken cancellationToken = default)

Parameters

audio Tensor<T>

The audio tensor to transcribe.

language string

Optional language code. If null, uses the model's default language.

includeTimestamps bool

Whether to include word timestamps in the result.

cancellationToken CancellationToken

A token to cancel the operation.

Returns

Task<TranscriptionResult<T>>

A task that resolves to the transcription result.
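
An async sketch with a timeout-based cancellation token:

using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(60));
var result = await wav2vec2.TranscribeAsync(
    audioTensor,
    language: "en",
    cancellationToken: cts.Token);
Console.WriteLine(result.Text);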

UpdateParameters(Vector<T>)

Updates model parameters by applying gradient descent.

public override void UpdateParameters(Vector<T> gradients)

Parameters

gradients Vector<T>

The gradient vector to apply to the model's parameters.