Skip to content

Commit e650dd9

Browse files
haowhsu-quic and facebook-github-bot
authored and committed
Qualcomm AI Engine Direct - SqueezeNet Enablement (#3748)
Summary: Summary - OSS model enablement: squeezenet. Pull Request resolved: #3748. Reviewed By: kirklandsign. Differential Revision: D57896519. Pulled By: cccclai. fbshipit-source-id: 2afacd4272891f0914a39855e3e7d67d166d2076
1 parent 9e86860 commit e650dd9

File tree

2 files changed

+194
-0
lines changed

2 files changed

+194
-0
lines changed

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,6 +1524,7 @@ def test_dino_v2(self):
15241524
]
15251525
if self.host:
15261526
cmds.extend(["--host", self.host])
1527+
15271528
p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
15281529
with Listener((self.ip, self.port)) as listener:
15291530
conn = listener.accept()
@@ -1566,6 +1567,39 @@ def test_esrgan(self):
15661567
self.assertGreaterEqual(msg["PSNR"], 24)
15671568
self.assertGreaterEqual(msg["SSIM"], 0.8)
15681569

1570+
def test_squeezenet(self):
    """Run the SqueezeNet OSS example script and check its reported accuracy.

    The test is skipped when ``required_envs`` rejects the image dataset.
    Otherwise the example script is started as a subprocess, the JSON
    message it sends back over the listener connection is decoded, and the
    reported ``top_1`` / ``top_5`` values must be at least 40 / 70.
    """
    if not self.required_envs([self.image_dataset]):
        self.skipTest("missing required envs")

    script = f"{self.executorch_root}/examples/qualcomm/oss_scripts/squeezenet.py"
    cmds = ["python", script]
    cmds += ["--dataset", self.image_dataset]
    cmds += ["--artifact", self.artifact_dir]
    cmds += ["--build_folder", self.build_folder]
    cmds += ["--device", self.device]
    cmds += ["--model", self.model]
    cmds += ["--ip", self.ip]
    cmds += ["--port", str(self.port)]
    if self.host:
        cmds += ["--host", self.host]

    # The script is launched first; it reports back to this listener
    # once it has finished its evaluation.
    proc = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
    with Listener((self.ip, self.port)) as listener:
        conn = listener.accept()
        proc.communicate()
        result = json.loads(conn.recv())
        self.assertGreaterEqual(result["top_1"], 40)
        self.assertGreaterEqual(result["top_5"], 70)
1602+
15691603

15701604
class TestExampleScript(TestQNN):
15711605
def required_envs(self, conditions=None) -> bool:
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import json
8+
import os
9+
import sys
10+
from multiprocessing.connection import Client
11+
12+
import numpy as np
13+
14+
import torch
15+
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
16+
from executorch.examples.qualcomm.scripts.utils import (
17+
build_executorch_binary,
18+
make_output_dir,
19+
parse_skip_delegation_node,
20+
setup_common_args_and_variables,
21+
SimpleADB,
22+
topk_accuracy,
23+
)
24+
25+
26+
def get_dataset(dataset_path, data_size):
    """Collect at most ``data_size`` preprocessed samples from an image folder.

    Args:
        dataset_path: Root directory handed to ``torchvision.datasets.ImageFolder``.
        data_size: Upper bound on the number of samples to collect.

    Returns:
        A tuple ``(inputs, targets, input_list)`` where ``inputs`` is a list of
        one-element tuples ``(feature,)``, ``targets`` is the list of matching
        labels as yielded by the data loader, and ``input_list`` is a string
        with one ``input_{index}_0.raw`` line per collected sample.
    """
    from torchvision import datasets, transforms

    def get_data_loader():
        # Resize, center-crop to 224, convert to tensor, then normalize
        # per channel with the mean/std constants below.
        steps = [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
        pipeline = transforms.Compose(steps)
        image_folder = datasets.ImageFolder(dataset_path, transform=pipeline)
        # Samples are drawn in shuffled order.
        return torch.utils.data.DataLoader(image_folder, shuffle=True)

    # prepare input data
    inputs, targets, file_names = [], [], []
    for index, sample in enumerate(get_data_loader()):
        if index >= data_size:
            break
        feature, target = sample
        inputs.append((feature,))
        targets.append(target)
        file_names.append(f"input_{index}_0.raw\n")

    return inputs, targets, "".join(file_names)
58+
59+
60+
if __name__ == "__main__":
    # Script flow: parse arguments, collect calibration/evaluation samples,
    # build the quantized SqueezeNet binary, optionally run it on a device
    # through SimpleADB, then report top-1 / top-5 accuracy.
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-d",
        "--dataset",
        help=(
            "path to the validation folder of ImageNet dataset. "
            "e.g. --dataset imagenet-mini/val "
            "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)"
        ),
        type=str,
        required=True,
    )

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts by this example. "
        "Default ./squeezenet",
        default="./squeezenet",
        type=str,
    )

    args = parser.parse_args()

    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

    # ensure the working directory exists.
    os.makedirs(args.artifact, exist_ok=True)

    if not args.compile_only and args.device is None:
        raise RuntimeError(
            "device serial is required if not compile only. "
            "Please specify a device serial by -s/--device argument."
        )

    # Upper bound on the number of samples requested from the dataset.
    data_num = 100
    inputs, targets, input_list = get_dataset(
        dataset_path=f"{args.dataset}",
        data_size=data_num,
    )
    pte_filename = "squeezenet_qnn"
    instance = torch.hub.load(
        "pytorch/vision:v0.10.0", "squeezenet1_1", pretrained=True
    )
    build_executorch_binary(
        instance.eval(),
        (torch.randn(1, 3, 224, 224),),
        args.model,
        f"{args.artifact}/{pte_filename}",
        inputs,
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        quant_dtype=QuantDtype.use_16a16w,
    )

    if args.compile_only:
        sys.exit(0)

    # setup required paths accordingly
    # qnn_sdk       : QNN SDK path setup in environment variable
    # artifact_path : path where artifacts were built
    # pte_path      : path where executorch binary was stored
    # device_id     : serial number of android device
    # workspace     : folder for storing artifacts on android device
    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        artifact_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)

    adb.pull(output_path=args.artifact)

    # top-k analysis
    # Iterate over the samples actually collected rather than data_num:
    # get_dataset returns fewer entries when the dataset holds fewer than
    # data_num images, and only those inputs were pushed to the device.
    predictions = [
        np.fromfile(
            os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
        )
        for i in range(len(inputs))
    ]

    k_val = [1, 5]
    topk = [topk_accuracy(predictions, targets, k).item() for k in k_val]
    if args.ip and args.port != -1:
        # Report the results to the listening test harness as a JSON string.
        with Client((args.ip, args.port)) as conn:
            conn.send(json.dumps({f"top_{k}": acc for k, acc in zip(k_val, topk)}))
    else:
        for k, acc in zip(k_val, topk):
            print(f"top_{k}->{acc}%")

0 commit comments

Comments
 (0)