
Commit 4483bb6

mcremon-meta authored and facebook-github-bot committed
Add Wav2Vec2 base model (#4513)
Add Wav2Vec2 base model (#4513)

Summary:
Pull Request resolved: #4513

As titled.

Reviewed By: zonglinpengmeta

Differential Revision: D60619295

fbshipit-source-id: 00fd48029bc2413cf2a4a1453c80bbf65d29c57f
1 parent: 1090bcd · commit: 4483bb6

1 file changed: +65 −0 lines


examples/cadence/models/wav2vec2.py

@@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Example script for exporting simple models to flatbuffer

import logging

from executorch.backends.cadence.aot.ops_registrations import *  # noqa

import torch

from executorch.backends.cadence.aot.export_example import export_model
from torchaudio.models.wav2vec2.model import wav2vec2_model, Wav2Vec2Model

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


def main() -> None:
    # The wrapper is needed to avoid issues with the optional second output
    # of Wav2Vec2Model's forward, which returns a (features, lengths) tuple.
    class Wav2Vec2ModelWrapper(torch.nn.Module):
        def __init__(self, model: Wav2Vec2Model):
            super().__init__()
            self.model = model

        def forward(self, x):
            out, _ = self.model(x)
            return out

    _model = wav2vec2_model(
        extractor_mode="layer_norm",
        extractor_conv_layer_config=None,
        extractor_conv_bias=False,
        encoder_embed_dim=768,
        encoder_projection_dropout=0.1,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=12,
        encoder_num_heads=12,
        encoder_attention_dropout=0.1,
        encoder_ff_interm_features=3072,
        encoder_ff_interm_dropout=0.0,
        encoder_dropout=0.1,
        encoder_layer_norm_first=False,
        encoder_layer_drop=0.1,
        aux_num_out=None,
    )
    _model.eval()

    model = Wav2Vec2ModelWrapper(_model)
    model.eval()

    # test input
    audio_len = 1680
    example_inputs = (torch.rand(1, audio_len),)

    export_model(model, example_inputs)


if __name__ == "__main__":
    main()
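For context on why the wrapper exists: torchaudio's Wav2Vec2Model.forward(waveforms, lengths=None) returns a (features, lengths) tuple, and the wrapper keeps only the features tensor so the exported graph has a single output. The following is a minimal eager-mode sanity check, not part of this commit, that one might run to confirm this tuple behavior before export; it builds the same base architecture as the script (randomly initialized, no pretrained weights), and the printed feature shape is an assumption that depends on torchaudio's default conv-extractor config.

import torch
from torchaudio.models.wav2vec2.model import wav2vec2_model

# Hypothetical sanity check (not in the commit): instantiate the same
# wav2vec2 base architecture as examples/cadence/models/wav2vec2.py.
model = wav2vec2_model(
    extractor_mode="layer_norm",
    extractor_conv_layer_config=None,
    extractor_conv_bias=False,
    encoder_embed_dim=768,
    encoder_projection_dropout=0.1,
    encoder_pos_conv_kernel=128,
    encoder_pos_conv_groups=16,
    encoder_num_layers=12,
    encoder_num_heads=12,
    encoder_attention_dropout=0.1,
    encoder_ff_interm_features=3072,
    encoder_ff_interm_dropout=0.0,
    encoder_dropout=0.1,
    encoder_layer_norm_first=False,
    encoder_layer_drop=0.1,
    aux_num_out=None,
).eval()

with torch.no_grad():
    # forward returns a (features, lengths) tuple; lengths is None
    # when no lengths argument is passed, which is what the wrapper
    # in the commit's script discards.
    features, lengths = model(torch.rand(1, 1680))

print(features.shape)  # e.g. torch.Size([1, 5, 768]) for 1680 input samples
print(lengths)         # None

Assuming executorch (with the Cadence backend) and torchaudio are installed, the commit's script can then be run directly to lower the wrapped model through export_model.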
