Skip to content

Commit 2f8ae11

Browse files
committed
rebase & code review #1
1 parent 6525201 commit 2f8ae11

File tree

10 files changed

+88
-103
lines changed

10 files changed

+88
-103
lines changed

backends/qualcomm/passes/expand_broadcast_tensor_shape.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,16 @@ class ExpandBroadcastTensorShape(ExportPass):
1717

1818
def __init__(self):
1919
super(ExpandBroadcastTensorShape, self).__init__()
20-
self.binary_op_targets = [
20+
self.broadcast_op_targets = [
2121
exir_ops.edge.aten.add.Tensor,
2222
exir_ops.edge.aten.sub.Tensor,
2323
exir_ops.edge.aten.mul.Tensor,
2424
exir_ops.edge.aten.div.Tensor,
2525
]
2626

27-
def _traverse_binary_node(self, graph_module: torch.fx.GraphModule):
27+
def traverse_broadcast_node(self, graph_module: torch.fx.GraphModule):
2828
for node in graph_module.graph.nodes:
29-
if node.target in self.binary_op_targets:
29+
if node.target in self.broadcast_op_targets:
3030
for arg in node.args:
3131
input_rank = len(arg.meta["val"].shape)
3232
output_rank = len(node.meta["val"].shape)
@@ -52,7 +52,7 @@ def _traverse_binary_node(self, graph_module: torch.fx.GraphModule):
5252
user.replace_input_with(arg, reshape_node)
5353

5454
def call(self, graph_module: torch.fx.GraphModule):
55-
self._traverse_binary_node(graph_module)
55+
self.traverse_broadcast_node(graph_module)
5656
graph_module.recompile()
5757
dead_code_elimination_pass(graph_module)
5858
return PassResult(graph_module, True)

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 69 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1742,13 +1742,12 @@ def required_envs(self, conditions=None) -> bool:
17421742
]
17431743
)
17441744

1745-
def test_fbnet(self):
1745+
def test_dino_v2(self):
17461746
if not self.required_envs([self.image_dataset]):
17471747
self.skipTest("missing required envs")
1748-
17491748
cmds = [
17501749
"python",
1751-
f"{self.executorch_root}/examples/qualcomm/oss_scripts/fbnet.py",
1750+
f"{self.executorch_root}/examples/qualcomm/oss_scripts/dino_v2.py",
17521751
"--dataset",
17531752
self.image_dataset,
17541753
"--artifact",
@@ -1775,18 +1774,16 @@ def test_fbnet(self):
17751774
if "Error" in msg:
17761775
self.fail(msg["Error"])
17771776
else:
1778-
self.assertGreaterEqual(msg["top_1"], 60)
1779-
self.assertGreaterEqual(msg["top_5"], 90)
1777+
self.assertGreaterEqual(msg["top_1"], 70)
1778+
self.assertGreaterEqual(msg["top_5"], 85)
17801779

1781-
def test_gMLP(self):
1782-
if not self.required_envs([self.image_dataset]):
1780+
def test_esrgan(self):
1781+
if not self.required_envs():
17831782
self.skipTest("missing required envs")
17841783

17851784
cmds = [
17861785
"python",
1787-
f"{self.executorch_root}/examples/qualcomm/oss_scripts/gMLP_image_classification.py",
1788-
"--dataset",
1789-
self.image_dataset,
1786+
f"{self.executorch_root}/examples/qualcomm/oss_scripts/esrgan.py",
17901787
"--artifact",
17911788
self.artifact_dir,
17921789
"--build_folder",
@@ -1795,6 +1792,9 @@ def test_gMLP(self):
17951792
self.device,
17961793
"--model",
17971794
self.model,
1795+
"--default_dataset",
1796+
"--oss_repo",
1797+
self.oss_repo,
17981798
"--ip",
17991799
self.ip,
18001800
"--port",
@@ -1811,17 +1811,17 @@ def test_gMLP(self):
18111811
if "Error" in msg:
18121812
self.fail(msg["Error"])
18131813
else:
1814-
self.assertGreaterEqual(msg["top_1"], 60)
1815-
self.assertGreaterEqual(msg["top_5"], 90)
1814+
self.assertGreaterEqual(msg["PSNR"], 24)
1815+
self.assertGreaterEqual(msg["SSIM"], 0.8)
18161816

1817-
def test_regnet(self):
1818-
if not self.required_envs([self.image_dataset]):
1817+
def test_fastvit(self):
1818+
if not self.required_envs(
1819+
[self.image_dataset, self.pretrained_weight, self.oss_repo]
1820+
):
18191821
self.skipTest("missing required envs")
1820-
1821-
weights = ["regnet_y_400mf", "regnet_x_400mf"]
18221822
cmds = [
18231823
"python",
1824-
f"{self.executorch_root}/examples/qualcomm/oss_scripts/regnet.py",
1824+
f"{self.executorch_root}/examples/qualcomm/oss_scripts/fastvit.py",
18251825
"--dataset",
18261826
self.image_dataset,
18271827
"--artifact",
@@ -1832,6 +1832,10 @@ def test_regnet(self):
18321832
self.device,
18331833
"--model",
18341834
self.model,
1835+
"--oss_repo",
1836+
self.oss_repo,
1837+
"--pretrained_weight",
1838+
self.pretrained_weight,
18351839
"--ip",
18361840
self.ip,
18371841
"--port",
@@ -1840,27 +1844,26 @@ def test_regnet(self):
18401844
if self.host:
18411845
cmds.extend(["--host", self.host])
18421846

1843-
for weight in weights:
1844-
p = subprocess.Popen(
1845-
cmds + ["--weights", weight], stdout=subprocess.DEVNULL
1846-
)
1847-
with Listener((self.ip, self.port)) as listener:
1848-
conn = listener.accept()
1849-
p.communicate()
1850-
msg = json.loads(conn.recv())
1851-
if "Error" in msg:
1852-
self.fail(msg["Error"])
1853-
else:
1854-
self.assertGreaterEqual(msg["top_1"], 60)
1855-
self.assertGreaterEqual(msg["top_5"], 85)
1847+
p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
1848+
with Listener((self.ip, self.port)) as listener:
1849+
conn = listener.accept()
1850+
p.communicate()
1851+
msg = json.loads(conn.recv())
1852+
if "Error" in msg:
1853+
self.fail(msg["Error"])
1854+
else:
1855+
self.assertGreaterEqual(msg["top_1"], 60)
1856+
self.assertGreaterEqual(msg["top_5"], 80)
18561857

1857-
def test_ssd300_vgg16(self):
1858-
if not self.required_envs([self.pretrained_weight, self.oss_repo]):
1858+
def test_fbnet(self):
1859+
if not self.required_envs([self.image_dataset]):
18591860
self.skipTest("missing required envs")
18601861

18611862
cmds = [
18621863
"python",
1863-
f"{self.executorch_root}/examples/qualcomm/oss_scripts/ssd300_vgg16.py",
1864+
f"{self.executorch_root}/examples/qualcomm/oss_scripts/fbnet.py",
1865+
"--dataset",
1866+
self.image_dataset,
18641867
"--artifact",
18651868
self.artifact_dir,
18661869
"--build_folder",
@@ -1869,10 +1872,6 @@ def test_ssd300_vgg16(self):
18691872
self.device,
18701873
"--model",
18711874
self.model,
1872-
"--oss_repo",
1873-
self.oss_repo,
1874-
"--pretrained_weight",
1875-
self.pretrained_weight,
18761875
"--ip",
18771876
self.ip,
18781877
"--port",
@@ -1889,14 +1888,16 @@ def test_ssd300_vgg16(self):
18891888
if "Error" in msg:
18901889
self.fail(msg["Error"])
18911890
else:
1892-
self.assertGreaterEqual(msg["mAP"], 0.70)
1891+
self.assertGreaterEqual(msg["top_1"], 60)
1892+
self.assertGreaterEqual(msg["top_5"], 90)
18931893

1894-
def test_dino_v2(self):
1894+
def test_gMLP(self):
18951895
if not self.required_envs([self.image_dataset]):
18961896
self.skipTest("missing required envs")
1897+
18971898
cmds = [
18981899
"python",
1899-
f"{self.executorch_root}/examples/qualcomm/oss_scripts/dino_v2.py",
1900+
f"{self.executorch_root}/examples/qualcomm/oss_scripts/gMLP_image_classification.py",
19001901
"--dataset",
19011902
self.image_dataset,
19021903
"--artifact",
@@ -1923,16 +1924,19 @@ def test_dino_v2(self):
19231924
if "Error" in msg:
19241925
self.fail(msg["Error"])
19251926
else:
1926-
self.assertGreaterEqual(msg["top_1"], 70)
1927-
self.assertGreaterEqual(msg["top_5"], 85)
1927+
self.assertGreaterEqual(msg["top_1"], 60)
1928+
self.assertGreaterEqual(msg["top_5"], 90)
19281929

1929-
def test_esrgan(self):
1930-
if not self.required_envs():
1930+
def test_regnet(self):
1931+
if not self.required_envs([self.image_dataset]):
19311932
self.skipTest("missing required envs")
19321933

1934+
weights = ["regnet_y_400mf", "regnet_x_400mf"]
19331935
cmds = [
19341936
"python",
1935-
f"{self.executorch_root}/examples/qualcomm/oss_scripts/esrgan.py",
1937+
f"{self.executorch_root}/examples/qualcomm/oss_scripts/regnet.py",
1938+
"--dataset",
1939+
self.image_dataset,
19361940
"--artifact",
19371941
self.artifact_dir,
19381942
"--build_folder",
@@ -1941,9 +1945,6 @@ def test_esrgan(self):
19411945
self.device,
19421946
"--model",
19431947
self.model,
1944-
"--default_dataset",
1945-
"--oss_repo",
1946-
self.oss_repo,
19471948
"--ip",
19481949
self.ip,
19491950
"--port",
@@ -1952,16 +1953,19 @@ def test_esrgan(self):
19521953
if self.host:
19531954
cmds.extend(["--host", self.host])
19541955

1955-
p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
1956-
with Listener((self.ip, self.port)) as listener:
1957-
conn = listener.accept()
1958-
p.communicate()
1959-
msg = json.loads(conn.recv())
1960-
if "Error" in msg:
1961-
self.fail(msg["Error"])
1962-
else:
1963-
self.assertGreaterEqual(msg["PSNR"], 24)
1964-
self.assertGreaterEqual(msg["SSIM"], 0.8)
1956+
for weight in weights:
1957+
p = subprocess.Popen(
1958+
cmds + ["--weights", weight], stdout=subprocess.DEVNULL
1959+
)
1960+
with Listener((self.ip, self.port)) as listener:
1961+
conn = listener.accept()
1962+
p.communicate()
1963+
msg = json.loads(conn.recv())
1964+
if "Error" in msg:
1965+
self.fail(msg["Error"])
1966+
else:
1967+
self.assertGreaterEqual(msg["top_1"], 60)
1968+
self.assertGreaterEqual(msg["top_5"], 85)
19651969

19661970
def test_squeezenet(self):
19671971
if not self.required_envs([self.image_dataset]):
@@ -1996,19 +2000,16 @@ def test_squeezenet(self):
19962000
if "Error" in msg:
19972001
self.fail(msg["Error"])
19982002
else:
1999-
self.assertGreaterEqual(msg["top_1"], 50)
2000-
self.assertGreaterEqual(msg["top_5"], 75)
2003+
self.assertGreaterEqual(msg["top_1"], 45)
2004+
self.assertGreaterEqual(msg["top_5"], 70)
20012005

2002-
def test_fastvit(self):
2003-
if not self.required_envs(
2004-
[self.image_dataset, self.pretrained_weight, self.oss_repo]
2005-
):
2006+
def test_ssd300_vgg16(self):
2007+
if not self.required_envs([self.pretrained_weight, self.oss_repo]):
20062008
self.skipTest("missing required envs")
2009+
20072010
cmds = [
20082011
"python",
2009-
f"{self.executorch_root}/examples/qualcomm/oss_scripts/fastvit.py",
2010-
"--dataset",
2011-
self.image_dataset,
2012+
f"{self.executorch_root}/examples/qualcomm/oss_scripts/ssd300_vgg16.py",
20122013
"--artifact",
20132014
self.artifact_dir,
20142015
"--build_folder",
@@ -2037,8 +2038,7 @@ def test_fastvit(self):
20372038
if "Error" in msg:
20382039
self.fail(msg["Error"])
20392040
else:
2040-
self.assertGreaterEqual(msg["top_1"], 60)
2041-
self.assertGreaterEqual(msg["top_5"], 80)
2041+
self.assertGreaterEqual(msg["mAP"], 0.70)
20422042

20432043

20442044
class TestExampleQaihubScript(TestQNN):

examples/qualcomm/oss_scripts/dino_v2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def main(args):
6565
skip_node_id_set=skip_node_id_set,
6666
skip_node_op_set=skip_node_op_set,
6767
quant_dtype=QuantDtype.use_8a8w,
68+
shared_buffer=args.shared_buffer,
6869
)
6970

7071
if args.compile_only:

examples/qualcomm/oss_scripts/esrgan.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def main(args):
6868
skip_node_id_set=skip_node_id_set,
6969
skip_node_op_set=skip_node_op_set,
7070
quant_dtype=QuantDtype.use_8a8w,
71+
shared_buffer=args.shared_buffer,
7172
)
7273

7374
if args.compile_only:

examples/qualcomm/oss_scripts/fastvit.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,10 @@
2525
from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d
2626
from executorch.examples.qualcomm.utils import (
2727
build_executorch_binary,
28-
convert_pt2e,
2928
get_imagenet_dataset,
3029
make_output_dir,
3130
make_quantizer,
3231
parse_skip_delegation_node,
33-
prepare_pt2e,
3432
setup_common_args_and_variables,
3533
SimpleADB,
3634
topk_accuracy,
@@ -72,10 +70,7 @@ def main(args):
7270
)
7371

7472
pte_filename = "fastvit_qnn"
75-
quantizer = make_quantizer(
76-
quant_dtype=QuantDtype.use_8a8w,
77-
per_channel_conv=True,
78-
)
73+
quantizer = make_quantizer(quant_dtype=QuantDtype.use_8a8w)
7974

8075
# there are lots of outliers appearing in fastvit parameters
8176
# we need to apply special configuration to saturate their impact
@@ -111,27 +106,19 @@ def main(args):
111106
weight=q_config.weight,
112107
bias=q_config.bias,
113108
)
114-
115-
# perform ptq
116-
model = convert_linear_to_conv2d(
117-
get_instance(args.oss_repo, args.pretrained_weight)
118-
)
119-
captured_model = torch.export.export(model, inputs[0]).module()
120-
annotated_model = prepare_pt2e(captured_model, quantizer)
121-
for input in inputs:
122-
annotated_model(*input)
123-
quantized_model = convert_pt2e(annotated_model)
124-
125109
# lower to QNN
126110
build_executorch_binary(
127-
quantized_model,
111+
convert_linear_to_conv2d(get_instance(args.oss_repo, args.pretrained_weight)),
128112
inputs[0],
129113
args.model,
130114
f"{args.artifact}/{pte_filename}",
131-
dataset=None,
115+
dataset=inputs,
132116
skip_node_id_set=skip_node_id_set,
133117
skip_node_op_set=skip_node_op_set,
118+
quant_dtype=QuantDtype.use_8a8w,
119+
custom_quantizer=quantizer,
134120
custom_pass_config={QCOM_PASS_EXPAND_BROADCAST_SHAPE},
121+
shared_buffer=args.shared_buffer,
135122
)
136123

137124
if args.compile_only:
@@ -181,7 +168,7 @@ def main(args):
181168
"-a",
182169
"--artifact",
183170
help="path for storing generated artifacts by this example. Default ./fastvit",
184-
default="./esrgan",
171+
default="./fastvit",
185172
type=str,
186173
)
187174

examples/qualcomm/oss_scripts/fbnet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def main(args):
5050
f"{args.artifact}/{pte_filename}",
5151
inputs,
5252
quant_dtype=QuantDtype.use_8a8w,
53+
shared_buffer=args.shared_buffer,
5354
)
5455

5556
if args.compile_only:

examples/qualcomm/oss_scripts/regnet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def main(args):
6666
skip_node_id_set=skip_node_id_set,
6767
skip_node_op_set=skip_node_op_set,
6868
quant_dtype=QuantDtype.use_8a8w,
69+
shared_buffer=args.shared_buffer,
6970
)
7071

7172
if args.compile_only:

examples/qualcomm/oss_scripts/squeezenet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def main(args):
5757
skip_node_id_set=skip_node_id_set,
5858
skip_node_op_set=skip_node_op_set,
5959
quant_dtype=QuantDtype.use_8a8w,
60+
shared_buffer=args.shared_buffer,
6061
)
6162

6263
if args.compile_only:

0 commit comments

Comments
 (0)