Skip to content

Commit a79b1a6

Browse files
authored
Support regnet_x_400mf and regnet_y_400mf (#4925)
1 parent 5395ae6 commit a79b1a6

File tree

15 files changed

+222
-78
lines changed

15 files changed

+222
-78
lines changed

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1668,6 +1668,46 @@ def test_gMLP(self):
16681668
self.assertGreaterEqual(msg["top_1"], 60)
16691669
self.assertGreaterEqual(msg["top_5"], 90)
16701670

1671+
def test_regnet(self):
1672+
if not self.required_envs([self.image_dataset]):
1673+
self.skipTest("missing required envs")
1674+
1675+
weights = ["regnet_y_400mf", "regnet_x_400mf"]
1676+
cmds = [
1677+
"python",
1678+
f"{self.executorch_root}/examples/qualcomm/oss_scripts/regnet.py",
1679+
"--dataset",
1680+
self.image_dataset,
1681+
"--artifact",
1682+
self.artifact_dir,
1683+
"--build_folder",
1684+
self.build_folder,
1685+
"--device",
1686+
self.device,
1687+
"--model",
1688+
self.model,
1689+
"--ip",
1690+
self.ip,
1691+
"--port",
1692+
str(self.port),
1693+
]
1694+
if self.host:
1695+
cmds.extend(["--host", self.host])
1696+
1697+
for weight in weights:
1698+
p = subprocess.Popen(
1699+
cmds + ["--weights", weight], stdout=subprocess.DEVNULL
1700+
)
1701+
with Listener((self.ip, self.port)) as listener:
1702+
conn = listener.accept()
1703+
p.communicate()
1704+
msg = json.loads(conn.recv())
1705+
if "Error" in msg:
1706+
self.fail(msg["Error"])
1707+
else:
1708+
self.assertGreaterEqual(msg["top_1"], 60)
1709+
self.assertGreaterEqual(msg["top_5"], 85)
1710+
16711711
def test_ssd300_vgg16(self):
16721712
if not self.required_envs([self.pretrained_weight, self.oss_repo]):
16731713
self.skipTest("missing required envs")

examples/qualcomm/oss_scripts/dino_v2.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,12 +105,6 @@ def main(args):
105105
if args.compile_only:
106106
sys.exit(0)
107107

108-
# setup required paths accordingly
109-
# qnn_sdk : QNN SDK path setup in environment variable
110-
# build_path : path where QNN delegate artifacts were built
111-
# pte_path : path where executorch binary was stored
112-
# device_id : serial number of android device
113-
# workspace : folder for storing artifacts on android device
114108
adb = SimpleADB(
115109
qnn_sdk=os.getenv("QNN_SDK_ROOT"),
116110
build_path=f"{args.build_folder}",

examples/qualcomm/oss_scripts/esrgan.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,6 @@ def main(args):
7474
if args.compile_only:
7575
sys.exit(0)
7676

77-
# setup required paths accordingly
78-
# qnn_sdk : QNN SDK path setup in environment variable
79-
# build_path : path where QNN delegate artifacts were built
80-
# pte_path : path where executorch binary was stored
81-
# device_id : serial number of android device
82-
# workspace : folder for storing artifacts on android device
8377
adb = SimpleADB(
8478
qnn_sdk=os.getenv("QNN_SDK_ROOT"),
8579
build_path=f"{args.build_folder}",

examples/qualcomm/oss_scripts/gMLP_image_classification.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,6 @@ def main(args):
9696
if args.compile_only:
9797
sys.exit(0)
9898

99-
# setup required paths accordingly
100-
# qnn_sdk : QNN SDK path setup in environment variable
101-
# build_path : path where artifacts were built
102-
# pte_path : path where QNN delegate executorch binary was stored
103-
# device_id : serial number of android device
104-
# workspace : folder for storing artifacts on android device
10599
adb = SimpleADB(
106100
qnn_sdk=os.getenv("QNN_SDK_ROOT"),
107101
build_path=f"{args.build_folder}",
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import json
8+
import os
9+
import sys
10+
from multiprocessing.connection import Client
11+
12+
import numpy as np
13+
import torch
14+
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
15+
from executorch.examples.qualcomm.utils import (
16+
build_executorch_binary,
17+
make_output_dir,
18+
parse_skip_delegation_node,
19+
setup_common_args_and_variables,
20+
SimpleADB,
21+
topk_accuracy,
22+
)
23+
24+
from torchvision.models import (
25+
regnet_x_400mf,
26+
RegNet_X_400MF_Weights,
27+
regnet_y_400mf,
28+
RegNet_Y_400MF_Weights,
29+
)
30+
31+
32+
def get_dataset(dataset_path, data_size):
33+
from torchvision import datasets, transforms
34+
35+
def get_data_loader():
36+
preprocess = transforms.Compose(
37+
[
38+
transforms.Resize(256),
39+
transforms.CenterCrop(224),
40+
transforms.ToTensor(),
41+
transforms.Normalize(
42+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
43+
),
44+
]
45+
)
46+
imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess)
47+
return torch.utils.data.DataLoader(
48+
imagenet_data,
49+
shuffle=True,
50+
)
51+
52+
# prepare input data
53+
inputs, targets, input_list = [], [], ""
54+
data_loader = get_data_loader()
55+
for index, data in enumerate(data_loader):
56+
if index >= data_size:
57+
break
58+
feature, target = data
59+
inputs.append((feature,))
60+
for element in target:
61+
targets.append(element)
62+
input_list += f"input_{index}_0.raw\n"
63+
64+
return inputs, targets, input_list
65+
66+
67+
def main(args):
68+
skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)
69+
70+
# ensure the working directory exist.
71+
os.makedirs(args.artifact, exist_ok=True)
72+
73+
if not args.compile_only and args.device is None:
74+
raise RuntimeError(
75+
"device serial is required if not compile only. "
76+
"Please specify a device serial by -s/--device argument."
77+
)
78+
79+
data_num = 100
80+
inputs, targets, input_list = get_dataset(
81+
dataset_path=f"{args.dataset}",
82+
data_size=data_num,
83+
)
84+
85+
if args.weights == "regnet_y_400mf":
86+
weights = RegNet_Y_400MF_Weights.DEFAULT
87+
model = regnet_y_400mf(weights=weights).eval()
88+
pte_filename = "regnet_y_400mf"
89+
else:
90+
weights = RegNet_X_400MF_Weights.DEFAULT
91+
model = regnet_x_400mf(weights=weights).eval()
92+
pte_filename = "regnet_x_400mf"
93+
94+
build_executorch_binary(
95+
model,
96+
inputs[0],
97+
args.model,
98+
f"{args.artifact}/{pte_filename}",
99+
inputs,
100+
quant_dtype=QuantDtype.use_8a8w,
101+
)
102+
103+
if args.compile_only:
104+
sys.exit(0)
105+
106+
adb = SimpleADB(
107+
qnn_sdk=os.getenv("QNN_SDK_ROOT"),
108+
build_path=f"{args.build_folder}",
109+
pte_path=f"{args.artifact}/{pte_filename}.pte",
110+
workspace=f"/data/local/tmp/executorch/{pte_filename}",
111+
device_id=args.device,
112+
host_id=args.host,
113+
soc_model=args.model,
114+
)
115+
adb.push(inputs=inputs, input_list=input_list)
116+
adb.execute()
117+
118+
# collect output data
119+
output_data_folder = f"{args.artifact}/outputs"
120+
make_output_dir(output_data_folder)
121+
122+
adb.pull(output_path=args.artifact)
123+
124+
# top-k analysis
125+
predictions = []
126+
for i in range(data_num):
127+
predictions.append(
128+
np.fromfile(
129+
os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
130+
)
131+
)
132+
133+
k_val = [1, 5]
134+
topk = [topk_accuracy(predictions, targets, k).item() for k in k_val]
135+
if args.ip and args.port != -1:
136+
with Client((args.ip, args.port)) as conn:
137+
conn.send(json.dumps({f"top_{k}": topk[i] for i, k in enumerate(k_val)}))
138+
else:
139+
for i, k in enumerate(k_val):
140+
print(f"top_{k}->{topk[i]}%")
141+
142+
143+
if __name__ == "__main__":
144+
parser = setup_common_args_and_variables()
145+
parser.add_argument(
146+
"-a",
147+
"--artifact",
148+
help="path for storing generated artifacts by this example. Default ./regnet",
149+
default="./regnet",
150+
type=str,
151+
)
152+
153+
parser.add_argument(
154+
"-d",
155+
"--dataset",
156+
help=(
157+
"path to the validation folder of ImageNet dataset. "
158+
"e.g. --dataset imagenet-mini/val "
159+
"for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)"
160+
),
161+
type=str,
162+
required=True,
163+
)
164+
165+
parser.add_argument(
166+
"--weights",
167+
type=str,
168+
choices=["regnet_y_400mf", "regnet_x_400mf"],
169+
help="Specify which regent weights/model to execute",
170+
required=True,
171+
)
172+
173+
args = parser.parse_args()
174+
try:
175+
main(args)
176+
except Exception as e:
177+
if args.ip and args.port != -1:
178+
with Client((args.ip, args.port)) as conn:
179+
conn.send(json.dumps({"Error": str(e)}))
180+
else:
181+
raise Exception(e)

examples/qualcomm/oss_scripts/squeezenet.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,6 @@ def main(args):
9292
if args.compile_only:
9393
sys.exit(0)
9494

95-
# setup required paths accordingly
96-
# qnn_sdk : QNN SDK path setup in environment variable
97-
# build_path : path where QNN delegate artifacts were built
98-
# pte_path : path where executorch binary was stored
99-
# device_id : serial number of android device
100-
# workspace : folder for storing artifacts on android device
10195
adb = SimpleADB(
10296
qnn_sdk=os.getenv("QNN_SDK_ROOT"),
10397
build_path=f"{args.build_folder}",

examples/qualcomm/oss_scripts/ssd300_vgg16.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,6 @@ def main(args):
155155
if args.compile_only:
156156
sys.exit(0)
157157

158-
# setup required paths accordingly
159-
# qnn_sdk : QNN SDK path setup in environment variable
160-
# build_path : path where QNN delegate artifacts were built
161-
# pte_path : path where executorch binary was stored
162-
# device_id : serial number of android device
163-
# workspace : folder for storing artifacts on android device
164158
adb = SimpleADB(
165159
qnn_sdk=os.getenv("QNN_SDK_ROOT"),
166160
build_path=f"{args.build_folder}",

examples/qualcomm/scripts/deeplab_v3.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,12 +95,6 @@ def main(args):
9595
if args.compile_only:
9696
sys.exit(0)
9797

98-
# setup required paths accordingly
99-
# qnn_sdk : QNN SDK path setup in environment variable
100-
# build_path : path where QNN delegate artifacts were built
101-
# pte_path : path where executorch binary was stored
102-
# device_id : serial number of android device
103-
# workspace : folder for storing artifacts on android device
10498
adb = SimpleADB(
10599
qnn_sdk=os.getenv("QNN_SDK_ROOT"),
106100
build_path=f"{args.build_folder}",

examples/qualcomm/scripts/edsr.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,6 @@ def main(args):
126126
if args.compile_only:
127127
sys.exit(0)
128128

129-
# setup required paths accordingly
130-
# qnn_sdk : QNN SDK path setup in environment variable
131-
# build_path : path where QNN delegate artifacts were built
132-
# pte_path : path where executorch binary was stored
133-
# device_id : serial number of android device
134-
# workspace : folder for storing artifacts on android device
135129
adb = SimpleADB(
136130
qnn_sdk=os.getenv("QNN_SDK_ROOT"),
137131
build_path=f"{args.build_folder}",

examples/qualcomm/scripts/inception_v3.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,6 @@ def main(args):
9292
if args.compile_only:
9393
sys.exit(0)
9494

95-
# setup required paths accordingly
96-
# qnn_sdk : QNN SDK path setup in environment variable
97-
# build_path : path where QNN delegate artifacts were built
98-
# pte_path : path where executorch binary was stored
99-
# device_id : serial number of android device
100-
# workspace : folder for storing artifacts on android device
10195
adb = SimpleADB(
10296
qnn_sdk=os.getenv("QNN_SDK_ROOT"),
10397
build_path=f"{args.build_folder}",

examples/qualcomm/scripts/inception_v4.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,6 @@ def main(args):
9191
if args.compile_only:
9292
sys.exit(0)
9393

94-
# setup required paths accordingly
95-
# qnn_sdk : QNN SDK path setup in environment variable
96-
# build_path : path where QNN delegate artifacts were built
97-
# pte_path : path where executorch binary was stored
98-
# device_id : serial number of android device
99-
# workspace : folder for storing artifacts on android device
10094
adb = SimpleADB(
10195
qnn_sdk=os.getenv("QNN_SDK_ROOT"),
10296
build_path=f"{args.build_folder}",

examples/qualcomm/scripts/mobilebert_fine_tune.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -268,12 +268,6 @@ def main(args):
268268
if args.compile_only:
269269
sys.exit(0)
270270

271-
# setup required paths accordingly
272-
# qnn_sdk : QNN SDK path setup in environment variable
273-
# build_path : path where QNN delegate artifacts were built
274-
# pte_path : path where executorch binary was stored
275-
# device_id : serial number of android device
276-
# workspace : folder for storing artifacts on android device
277271
adb = SimpleADB(
278272
qnn_sdk=os.getenv("QNN_SDK_ROOT"),
279273
build_path=f"{args.build_folder}",

examples/qualcomm/scripts/mobilenet_v2.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,6 @@ def main(args):
9292
if args.compile_only:
9393
sys.exit(0)
9494

95-
# setup required paths accordingly
96-
# qnn_sdk : QNN SDK path setup in environment variable
97-
# build_path : path where QNN delegate artifacts were built
98-
# pte_path : path where executorch binary was stored
99-
# device_id : serial number of android device
100-
# workspace : folder for storing artifacts on android device
10195
adb = SimpleADB(
10296
qnn_sdk=os.getenv("QNN_SDK_ROOT"),
10397
build_path=f"{args.build_folder}",

examples/qualcomm/scripts/mobilenet_v3.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,6 @@ def main(args):
9090
if args.compile_only:
9191
sys.exit(0)
9292

93-
# setup required paths accordingly
94-
# qnn_sdk : QNN SDK path setup in environment variable
95-
# build_path : path where QNN delegate artifacts were built
96-
# pte_path : path where executorch binary was stored
97-
# device_id : serial number of android device
98-
# workspace : folder for storing artifacts on android device
9993
adb = SimpleADB(
10094
qnn_sdk=os.getenv("QNN_SDK_ROOT"),
10195
build_path=f"{args.build_folder}",

examples/qualcomm/scripts/torchvision_vit.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,7 @@ def main(args):
7676
quant_dtype=QuantDtype.use_8a8w,
7777
shared_buffer=args.shared_buffer,
7878
)
79-
# setup required paths accordingly
80-
# qnn_sdk : QNN SDK path setup in environment variable
81-
# build_path : path where QNN delegate artifacts were built
82-
# pte_path : path where executorch binary was stored
83-
# device_id : serial number of android device
84-
# workspace : folder for storing artifacts on android device
79+
8580
adb = SimpleADB(
8681
qnn_sdk=os.getenv("QNN_SDK_ROOT"),
8782
build_path=f"{args.build_folder}",

0 commit comments

Comments
 (0)