Failed to load Wav2Vec2ForSequenceClassification model #1487

Open
Labels: bug (Something isn't working)
@abdofallah

Description

Describe the bug
Unable to load the weights for a Wav2Vec2ForSequenceClassification model. Despite creating a C# class structure that mirrors the official Hugging Face PyTorch implementation, we consistently receive a System.ArgumentException: 'Mismatched state_dict sizes...' error when calling model.load().
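
A quick sanity check (a diagnostic sketch only, not part of the repro; it assumes model is the Wav2Vec2ForSequenceClassificationSharp instance constructed in step 3 below) is to print the child modules that RegisterComponents() registered and the total number of entries the module expects, and to compare that count with the numbers quoted in the exception message:

// Diagnostic sketch: "model" is assumed to be the instance built in step 3 below.
foreach (var (name, child) in model.named_children())
    Console.WriteLine($"child module: {name} ({child.GetType().Name})");
Console.WriteLine($"entries expected by model.load(): {model.state_dict().Count}");

If the child names are anything other than wav2vec2, projector, and classifier, the prefixes in the converted state_dict cannot line up.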

To Reproduce

  1. Get the Model: The model is a fine-tuned Wav2Vec2 for voicemail detection from Bland AI, available on Hugging Face: blandai/wav2vec-vm-finetune. The model's config.json confirms its architecture is Wav2Vec2ForSequenceClassification.

  2. C# Model Definition: Create the following C# classes, which are designed to be a direct replica of the Hugging Face Wav2Vec2ForSequenceClassification source code.

// File: Wav2Vec2ForSequenceClassificationSharp.cs
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
public class Wav2Vec2ForSequenceClassificationSharp : Module<Tensor, Tensor>
{
    // Direct children to match the Python model's structure.
    private readonly Wav2Vec2Model wav2vec2;
    private readonly Linear projector;
    private readonly Linear classifier;

    public Wav2Vec2ForSequenceClassificationSharp(
        torchaudio.models.FeatureExtractorNormMode extractor_mode,
        long[][] extractor_conv_layer_config, bool extractor_conv_bias,
        int encoder_embed_dim, double encoder_projection_dropout,
        int encoder_pos_conv_kernel, int encoder_pos_conv_groups,
        int encoder_num_layers, int encoder_num_heads,
        double encoder_attention_dropout, int encoder_ff_interm_features,
        double encoder_ff_interm_dropout, double encoder_dropout,
        bool encoder_layer_norm_first, double encoder_layer_drop,
        int classifier_proj_size, int num_labels
    ) : base(nameof(Wav2Vec2ForSequenceClassificationSharp))
    {
        // The names of these member variables ("wav2vec2", "projector", "classifier")
        // must exactly match the prefixes in the state_dict.
        this.wav2vec2 = torchaudio.models.wav2vec2_model(
            extractor_mode, extractor_conv_layer_config, extractor_conv_bias,
            encoder_embed_dim, encoder_projection_dropout, encoder_pos_conv_kernel,
            encoder_pos_conv_groups, encoder_num_layers, encoder_num_heads,
            encoder_attention_dropout, encoder_ff_interm_features, encoder_ff_interm_dropout,
            encoder_dropout, encoder_layer_norm_first, encoder_layer_drop,
            aux_num_out: null
        );

        this.projector = Linear(encoder_embed_dim, classifier_proj_size);
        this.classifier = Linear(classifier_proj_size, num_labels);
        RegisterComponents();
    }

    public override Tensor forward(Tensor input)
    {
        var (hidden_states, _) = this.wav2vec2.forward(input);
        var projected_states = this.projector.forward(hidden_states);
        var pooled_output = projected_states.mean(new long[] { 1 });
        var logits = this.classifier.forward(pooled_output);
        return logits;
    }
}
  3. C# Loading Code: Attempt to load the converted .bin file into an instance of the model.
// File: Program.cs
using TorchSharp;
using static TorchSharp.torch;
public class Program
{
    static void Main(string[] args)
    {
        Console.WriteLine("Attempting to load Wav2Vec2ForSequenceClassification model...");

        // Model configuration from config.json
        var extractor_mode = torchaudio.models.FeatureExtractorNormMode.layer_norm;
        var extractor_conv_layer_config = new long[][] {
            new long[] { 512, 10, 5 }, new long[] { 512, 3, 2 }, new long[] { 512, 3, 2 },
            new long[] { 512, 3, 2 }, new long[] { 512, 3, 2 }, new long[] { 512, 2, 2 },
            new long[] { 512, 2, 2 }
        };
        var extractor_conv_bias = true;
        var encoder_embed_dim = 1024;
        var encoder_projection_dropout = 0.1;
        var encoder_pos_conv_kernel = 128;
        var encoder_pos_conv_groups = 16;
        var encoder_num_layers = 24;
        var encoder_num_heads = 16;
        var encoder_attention_dropout = 0.1;
        var encoder_ff_interm_features = 4096;
        var encoder_ff_interm_dropout = 0.0;
        var encoder_dropout = 0.1;
        var encoder_layer_norm_first = true;
        var encoder_layer_drop = 0.1;
        var classifier_proj_size = 256;
        var num_labels = 2;

        try
        {
            var model = new Wav2Vec2ForSequenceClassificationSharp(
                extractor_mode, extractor_conv_layer_config, extractor_conv_bias,
                encoder_embed_dim, encoder_projection_dropout, encoder_pos_conv_kernel,
                encoder_pos_conv_groups, encoder_num_layers, encoder_num_heads,
                encoder_attention_dropout, encoder_ff_interm_features,
                encoder_ff_interm_dropout, encoder_dropout,
                encoder_layer_norm_first, encoder_layer_drop,
                classifier_proj_size, num_labels
            );

            var modelPath = "path/to/your/converted_model.bin";

            // This is where the exception is thrown.
            model.load(modelPath, strict: true);
            Console.WriteLine("Model loaded successfully!");
        }
        catch (Exception ex)
        {
            Console.WriteLine("\nERROR: Failed to load model.");
            Console.WriteLine(ex.ToString());
        }
    }
}
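
For further narrowing (again a diagnostic sketch, not part of the repro, reusing the model and modelPath variables from the try block above), the entry names and shapes the TorchSharp module expects can be dumped and diffed against the keys of the original PyTorch state_dict, and a non-strict load can show whether the failure is only about the number of entries:

// Diagnostic sketch: dump every entry name and shape the TorchSharp module expects,
// so the list can be diffed against the exported checkpoint's keys.
foreach (var (name, tensor) in model.state_dict())
    Console.WriteLine($"{name}: [{string.Join(", ", tensor.shape)}]");

// A non-strict load skips the strict entry-count check and entries whose names the
// module does not contain; shape mismatches on matching names may still throw.
model.load(modelPath, strict: false);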

Environment

  • TorchSharp Version: TorchSharp-cpu v0.101.5
  • .NET Version: .NET 9.0
  • Operating System: Windows 10 x64
  • CPU/GPU: CPU-only

Thanks for your contribution! ⭐
