Commit 7895982

Qualcomm AI Engine Direct - wav2letter e2e example
Differential Revision: D65734745
Pull Request resolved: #5924
1 parent bec0625 commit 7895982

File tree

3 files changed: +266 -0 lines changed

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 38 additions & 0 deletions
@@ -2918,6 +2918,44 @@ def test_ptq_mobilebert(self):
        for k, v in cpu.items():
            self.assertLessEqual(abs(v[0] - htp[k][0]), 5)

    def test_wav2letter(self):
        if not self.required_envs([self.pretrained_weight]):
            self.skipTest("missing required envs")

        cmds = [
            "python",
            f"{self.executorch_root}/examples/qualcomm/scripts/wav2letter.py",
            "--artifact",
            self.artifact_dir,
            "--build_folder",
            self.build_folder,
            "--device",
            self.device,
            "--model",
            self.model,
            "--pretrained_weight",
            self.pretrained_weight,
            "--ip",
            self.ip,
            "--port",
            str(self.port),
        ]
        if self.host:
            cmds.extend(["--host", self.host])
        if self.shared_buffer:
            cmds.extend(["--shared_buffer"])

        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
        with Listener((self.ip, self.port)) as listener:
            conn = listener.accept()
            p.communicate()
            msg = json.loads(conn.recv())
            if "Error" in msg:
                self.fail(msg["Error"])
            else:
                self.assertLessEqual(msg["wer"], 0.5)
                self.assertLessEqual(msg["cer"], 0.25)

    def test_export_example(self):
        if not self.required_envs([self.model_name]):
            self.skipTest("missing required envs")
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
pip install soundfile
pip install torchmetrics
Lines changed: 226 additions & 0 deletions
@@ -0,0 +1,226 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import sys
from multiprocessing.connection import Client

import numpy as np

import torch
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.examples.models.wav2letter import Wav2LetterModel
from executorch.examples.qualcomm.utils import (
    build_executorch_binary,
    make_output_dir,
    parse_skip_delegation_node,
    setup_common_args_and_variables,
    SimpleADB,
)


class Conv2D(torch.nn.Module):
    def __init__(self, stride, padding, weight, bias=None):
        super().__init__()
        use_bias = bias is not None
        self.conv = torch.nn.Conv2d(
            in_channels=weight.shape[1],
            out_channels=weight.shape[0],
            kernel_size=[weight.shape[2], 1],
            stride=[*stride, 1],
            padding=[*padding, 0],
            bias=use_bias,
        )
        self.conv.weight = torch.nn.Parameter(weight.unsqueeze(-1))
        if use_bias:
            self.conv.bias = torch.nn.Parameter(bias)

    def forward(self, x):
        return self.conv(x)
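
# A minimal sanity-check sketch (toy layer sizes assumed for illustration):
# the wrapper above should reproduce nn.Conv1d once the input is reshaped to
# NCHW with W == 1, which is exactly how main() feeds the model.
_conv1d = torch.nn.Conv1d(1, 4, kernel_size=3, stride=2, padding=1)
_conv2d = Conv2D(_conv1d.stride, _conv1d.padding, _conv1d.weight, _conv1d.bias)
_x = torch.randn(1, 1, 16)  # N, C, L
assert torch.allclose(_conv1d(_x), _conv2d(_x.unsqueeze(-1)).squeeze(-1))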


def get_dataset(data_size, artifact_dir):
    from torch.utils.data import DataLoader
    from torchaudio.datasets import LIBRISPEECH

    def collate_fun(batch):
        waves, labels = [], []

        for wave, _, text, *_ in batch:
            waves.append(wave.squeeze(0))
            labels.append(text)
        # padding is needed here for a static output shape
        waves = torch.nn.utils.rnn.pad_sequence(waves, batch_first=True)
        return waves, labels

    dataset = LIBRISPEECH(artifact_dir, url="test-clean", download=True)
    data_loader = DataLoader(
        dataset=dataset,
        batch_size=data_size,
        shuffle=True,
        collate_fn=lambda x: collate_fun(x),
    )
    # prepare input data
    inputs, targets, input_list = [], [], ""
    for wave, label in data_loader:
        for index in range(data_size):
            # reshape input tensor to NCHW
            inputs.append((wave[index].reshape(1, 1, -1, 1),))
            targets.append(label[index])
            input_list += f"input_{index}_0.raw\n"
        # take only the first batch, i.e. 'data_size' tensors
        break

    return inputs, targets, input_list
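
# A quick toy demonstration of the padding relied on above: pad_sequence
# right-pads the shorter waveforms with zeros so every sample in the batch
# shares one static length, as the fixed-shape compiled graph requires.
_padded = torch.nn.utils.rnn.pad_sequence(
    [torch.ones(3), torch.ones(5)], batch_first=True
)
assert _padded.tolist() == [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]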


def eval_metric(pred, target_str):
    from torchmetrics.text import CharErrorRate, WordErrorRate

    def parse(ids):
        vocab = " abcdefghijklmnopqrstuvwxyz'*"
        return ["".join([vocab[c] for c in id]).replace("*", "").upper() for id in ids]

    pred_str = parse(
        [
            torch.unique_consecutive(pred[i, :, :].argmax(0))
            for i in range(pred.shape[0])
        ]
    )
    wer, cer = WordErrorRate(), CharErrorRate()
    return wer(pred_str, target_str), cer(pred_str, target_str)
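
# A worked toy example of the greedy CTC decode performed by eval_metric:
# take the argmax symbol per frame, collapse consecutive repeats, then drop
# the blank token '*' (the last of the 29 vocabulary entries).
_vocab = " abcdefghijklmnopqrstuvwxyz'*"
_frames = torch.tensor([3, 3, 28, 1, 1, 20])  # per-frame argmax ids: "cc*aat"
_collapsed = torch.unique_consecutive(_frames)  # tensor([ 3, 28,  1, 20])
assert "".join(_vocab[i] for i in _collapsed).replace("*", "").upper() == "CAT"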


def main(args):
    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

    # ensure the working directory exists
    os.makedirs(args.artifact, exist_ok=True)

    if not args.compile_only and args.device is None:
        raise RuntimeError(
            "device serial is required if not compile only. "
            "Please specify a device serial by -s/--device argument."
        )

    instance = Wav2LetterModel()
    # target labels " abcdefghijklmnopqrstuvwxyz'*"
    instance.vocab_size = 29
    model = instance.get_eager_model().eval()
    model.load_state_dict(torch.load(args.pretrained_weight, weights_only=True))

    # converting conv1d to conv2d at the nn.Module level introduces only 2
    # permute nodes around the input & output, which is more quantization friendly
    for i in range(len(model.acoustic_model)):
        for j in range(len(model.acoustic_model[i])):
            module = model.acoustic_model[i][j]
            if isinstance(module, torch.nn.Conv1d):
                model.acoustic_model[i][j] = Conv2D(
                    stride=module.stride,
                    padding=module.padding,
                    weight=module.weight,
                    bias=module.bias,
                )

    # retrieve the dataset; the download will take some time
    data_num = 100
    inputs, targets, input_list = get_dataset(
        data_size=data_num, artifact_dir=args.artifact
    )
    pte_filename = "w2l_qnn"
    build_executorch_binary(
        model,
        inputs[0],
        args.model,
        f"{args.artifact}/{pte_filename}",
        inputs,
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        quant_dtype=QuantDtype.use_8a8w,
        shared_buffer=args.shared_buffer,
    )

    if args.compile_only:
        sys.exit(0)

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)
    adb.pull(output_path=args.artifact)

    predictions = []
    for i in range(data_num):
        predictions.append(
            np.fromfile(
                os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
            )
        )

    # evaluate metrics
    wer, cer = 0, 0
    for i, pred in enumerate(predictions):
        pred = torch.from_numpy(pred).reshape(1, instance.vocab_size, -1)
        wer_eval, cer_eval = eval_metric(pred, targets[i])
        wer += wer_eval
        cer += cer_eval

    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(
                json.dumps({"wer": wer.item() / data_num, "cer": cer.item() / data_num})
            )
    else:
        print(f"wer: {wer / data_num}\ncer: {cer / data_num}")


if __name__ == "__main__":
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing artifacts generated by this example. "
        "Default: ./wav2letter",
        default="./wav2letter",
        type=str,
    )

    parser.add_argument(
        "-p",
        "--pretrained_weight",
        help=(
            "Location of the pretrained weight; please download the "
            "torchaudio.models.Wav2Letter version via "
            "https://github.com/nipponjo/wav2letter-ctc-pytorch/tree/main?tab=readme-ov-file#wav2letter-ctc-pytorch"
        ),
        default=None,
        type=str,
        required=True,
    )

    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            raise Exception(e)
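
The on-device runner writes each output tensor as flat float32 bytes, which the
host reads back with np.fromfile and reshapes to (batch, vocab_size, frames)
before decoding. A minimal sketch of that round trip, using the script's
output_{i}_0.raw naming for i = 0 and a toy logits tensor:

    import numpy as np
    import torch

    logits = torch.randn(1, 29, 7)  # toy (batch, vocab_size, frames) output
    logits.numpy().tofile("output_0_0.raw")
    restored = torch.from_numpy(
        np.fromfile("output_0_0.raw", dtype=np.float32)
    ).reshape(1, 29, -1)
    assert torch.equal(logits, restored)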
