-
Notifications
You must be signed in to change notification settings - Fork 216
Description
Describe the bug
Unable to load the weights for Wav2Vec2ForSequenceClassification model. Despite meticulously creating a C# class structure that mirrors the official Hugging Face PyTorch implementation, we consistently receive a System.ArgumentException: 'Mismatched state_dict sizes...' error upon calling model.load().
To Reproduce
-
Get the Model: The model is a fine-tuned Wav2Vec2 for voicemail detection from Bland AI, available on Hugging Face: blandai/wav2vec-vm-finetune. The model's config.json confirms its architecture is Wav2Vec2ForSequenceClassification.
-
C# Model Definition: Create the following C# classes, which are designed to be a direct replica of the Hugging Face Wav2Vec2ForSequenceClassification source code.
// File: Wav2Vec2ForSequenceClassificationSharp.cs
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;

/// <summary>
/// C# replica of Hugging Face's Wav2Vec2ForSequenceClassification: a Wav2Vec2
/// backbone followed by a linear projector and a linear classification head.
/// </summary>
public class Wav2Vec2ForSequenceClassificationSharp : Module<Tensor, Tensor>
{
    // The field names ("wav2vec2", "projector", "classifier") must exactly
    // match the prefixes in the state_dict for load() to resolve the weights.
    private readonly Wav2Vec2Model wav2vec2;
    private readonly Linear projector;
    private readonly Linear classifier;

    public Wav2Vec2ForSequenceClassificationSharp(
        torchaudio.models.FeatureExtractorNormMode extractor_mode,
        long[][] extractor_conv_layer_config,
        bool extractor_conv_bias,
        int encoder_embed_dim,
        double encoder_projection_dropout,
        int encoder_pos_conv_kernel,
        int encoder_pos_conv_groups,
        int encoder_num_layers,
        int encoder_num_heads,
        double encoder_attention_dropout,
        int encoder_ff_interm_features,
        double encoder_ff_interm_dropout,
        double encoder_dropout,
        bool encoder_layer_norm_first,
        double encoder_layer_drop,
        int classifier_proj_size,
        int num_labels)
        : base(nameof(Wav2Vec2ForSequenceClassificationSharp))
    {
        // Backbone built via TorchSharp's torchaudio factory; aux_num_out is
        // null because the classification head is defined separately below.
        this.wav2vec2 = torchaudio.models.wav2vec2_model(
            extractor_mode,
            extractor_conv_layer_config,
            extractor_conv_bias,
            encoder_embed_dim,
            encoder_projection_dropout,
            encoder_pos_conv_kernel,
            encoder_pos_conv_groups,
            encoder_num_layers,
            encoder_num_heads,
            encoder_attention_dropout,
            encoder_ff_interm_features,
            encoder_ff_interm_dropout,
            encoder_dropout,
            encoder_layer_norm_first,
            encoder_layer_drop,
            aux_num_out: null);
        this.projector = Linear(encoder_embed_dim, classifier_proj_size);
        this.classifier = Linear(classifier_proj_size, num_labels);
        RegisterComponents();
    }

    /// <summary>
    /// Runs the backbone, projects each frame, mean-pools over time, and
    /// returns the per-class logits.
    /// </summary>
    public override Tensor forward(Tensor input)
    {
        var (features, _) = this.wav2vec2.forward(input);
        var projected = this.projector.forward(features);
        // Mean over dim 1 (time) collapses frames into one vector per clip.
        var pooled = projected.mean(new long[] { 1 });
        return this.classifier.forward(pooled);
    }
}
- C# Loading Code: Attempt to load the converted .bin file into an instance of the model.
// File: Program.cs
using TorchSharp;
using static TorchSharp.torch;

/// <summary>
/// Minimal repro: constructs the C# Wav2Vec2 classifier and attempts to load
/// the converted Hugging Face checkpoint, printing any load failure.
/// </summary>
public class Program
{
    static void Main(string[] args)
    {
        Console.WriteLine("Attempting to load Wav2Vec2ForSequenceClassification model...");

        // Model configuration, transcribed from the checkpoint's config.json.
        var extractor_mode = torchaudio.models.FeatureExtractorNormMode.layer_norm;
        var extractor_conv_layer_config = new long[][]
        {
            new long[] { 512, 10, 5 },
            new long[] { 512, 3, 2 },
            new long[] { 512, 3, 2 },
            new long[] { 512, 3, 2 },
            new long[] { 512, 3, 2 },
            new long[] { 512, 2, 2 },
            new long[] { 512, 2, 2 },
        };
        var extractor_conv_bias = true;
        var encoder_embed_dim = 1024;
        var encoder_projection_dropout = 0.1;
        var encoder_pos_conv_kernel = 128;
        var encoder_pos_conv_groups = 16;
        var encoder_num_layers = 24;
        var encoder_num_heads = 16;
        var encoder_attention_dropout = 0.1;
        var encoder_ff_interm_features = 4096;
        var encoder_ff_interm_dropout = 0.0;
        var encoder_dropout = 0.1;
        var encoder_layer_norm_first = true;
        var encoder_layer_drop = 0.1;
        var classifier_proj_size = 256;
        var num_labels = 2;

        try
        {
            var model = new Wav2Vec2ForSequenceClassificationSharp(
                extractor_mode,
                extractor_conv_layer_config,
                extractor_conv_bias,
                encoder_embed_dim,
                encoder_projection_dropout,
                encoder_pos_conv_kernel,
                encoder_pos_conv_groups,
                encoder_num_layers,
                encoder_num_heads,
                encoder_attention_dropout,
                encoder_ff_interm_features,
                encoder_ff_interm_dropout,
                encoder_dropout,
                encoder_layer_norm_first,
                encoder_layer_drop,
                classifier_proj_size,
                num_labels);

            var modelPath = "path/to/your/converted_model.bin";
            // This is where the exception is thrown.
            model.load(modelPath, strict: true);
            Console.WriteLine("Model loaded successfully!");
        }
        catch (Exception ex)
        {
            // Deliberately broad: this repro reports any load failure verbatim.
            Console.WriteLine("\nERROR: Failed to load model.");
            Console.WriteLine(ex.ToString());
        }
    }
}
Environment
- TorchSharp Version: TorchSharp-cpu v0.101.5
- .NET Version: .NET 9.0
- Operating System: Windows 10 x64
- CPU/GPU: CPU-only
Thanks for your contribution! ⭐