Skip to content

Commit 3708e74

Browse files
chuntlfacebook-github-bot
authored andcommitted
Qualcomm AI Engine Direct - HardSigmoid follow up for FP16 / Test cases complement (#2790)
Summary: - make HardSigmoid more compact in FP16 - add online prepare utitlity test case - add test case for export_example.py Pull Request resolved: #2790 Reviewed By: kirklandsign Differential Revision: D55617337 Pulled By: cccclai fbshipit-source-id: bb419aedf167a1f9bf2d7ea289e0e3311cd806f9
1 parent 8cabeac commit 3708e74

File tree

5 files changed

+85
-18
lines changed

5 files changed

+85
-18
lines changed

backends/qualcomm/builders/op_dequantize.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -56,20 +56,16 @@ def define_node(
5656

5757

5858
@register_node_visitor
59-
class PerTensorDequantizeDefault(DequantizeOpBase):
60-
target = ["quantized_decomposed.dequantize_per_tensor.default"]
59+
class PerTensorDequantize(DequantizeOpBase):
60+
target = [
61+
"quantized_decomposed.dequantize_per_tensor.default",
62+
"quantized_decomposed.dequantize_per_tensor.tensor",
63+
]
6164

6265

6366
@register_node_visitor
64-
class PerTensorDequantizeTensor(DequantizeOpBase):
65-
target = ["quantized_decomposed.dequantize_per_tensor.tensor"]
66-
67-
68-
@register_node_visitor
69-
class PerChannelDequantizeDefault(DequantizeOpBase):
70-
target = ["quantized_decomposed.dequantize_per_channel.default"]
71-
72-
73-
@register_node_visitor
74-
class PerChannelDequantizeTensor(DequantizeOpBase):
75-
target = ["quantized_decomposed.dequantize_per_channel.tensor"]
67+
class PerChannelDequantize(DequantizeOpBase):
68+
target = [
69+
"quantized_decomposed.dequantize_per_channel.default",
70+
"quantized_decomposed.dequantize_per_channel.tensor",
71+
]

backends/qualcomm/passes/convert_hardsigmoid.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ def call(self, graph_module: torch.fx.GraphModule):
2525
partitions = get_source_partitions(graph, [torch.nn.Hardsigmoid])
2626
for _, src_partitions in partitions.items():
2727
for src_partition in src_partitions:
28+
if exir_ops.edge.aten.hardswish.default in [
29+
node.target for node in src_partition.nodes
30+
]:
31+
continue
2832
if self.quantization_capture:
2933
# only one hardsigmoid op will be seen
3034
input_nodes = src_partition.input_nodes
@@ -34,8 +38,6 @@ def call(self, graph_module: torch.fx.GraphModule):
3438
else:
3539
in_ops_target = exir_ops.edge.aten.add.Tensor
3640
out_ops_target = exir_ops.edge.aten.div.Tensor
37-
# see the reverse engineering logic hardswish
38-
# https://shorturl.at/pACEL
3941
input_nodes = [
4042
n for n in src_partition.nodes if n.target is in_ops_target
4143
]

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
import json
77
import subprocess
88
import sys
9+
import tempfile
910
import unittest
1011
from multiprocessing.connection import Listener
12+
from pathlib import Path
1113

1214
import torch
1315
from executorch.backends.qualcomm.tests.utils import (
@@ -1102,6 +1104,19 @@ def test_qnn_backend_shared_buffer(self):
11021104
expected_partitions=1,
11031105
)
11041106

1107+
def test_qnn_backend_online_prepare(self):
1108+
backend_options = generate_htp_compiler_spec(use_fp16=True)
1109+
TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec(
1110+
soc_model=self.arch_table[TestQNN.model],
1111+
backend_options=backend_options,
1112+
debug=False,
1113+
saver=False,
1114+
online_prepare=True,
1115+
)
1116+
module = SimpleModel() # noqa: F405
1117+
sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
1118+
self.lower_module_and_test_output(module, sample_input)
1119+
11051120

11061121
class TestQNNQuantizedUtils(TestQNN):
11071122
# TODO: refactor to support different backends
@@ -1223,6 +1238,20 @@ def test_qnn_backend_shared_buffer(self):
12231238
expected_partitions=1,
12241239
)
12251240

1241+
def test_qnn_backend_online_prepare(self):
1242+
backend_options = generate_htp_compiler_spec(use_fp16=False)
1243+
TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec(
1244+
soc_model=self.arch_table[TestQNN.model],
1245+
backend_options=backend_options,
1246+
debug=False,
1247+
saver=False,
1248+
online_prepare=True,
1249+
)
1250+
module = SimpleModel() # noqa: F405
1251+
sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
1252+
module = self.get_qdq_module(module, sample_input)
1253+
self.lower_module_and_test_output(module, sample_input)
1254+
12261255

12271256
class TestExampleOssScript(TestQNN):
12281257
def required_envs(self, conditions=None) -> bool:
@@ -1640,6 +1669,29 @@ def test_ptq_mobilebert(self):
16401669
for k, v in cpu.items():
16411670
self.assertLessEqual(abs(v[0] - htp[k][0]), 5)
16421671

1672+
def test_export_example(self):
1673+
if not self.required_envs([self.model_name]):
1674+
self.skipTest("missing required envs")
1675+
1676+
with tempfile.TemporaryDirectory() as tmp_dir:
1677+
cmds = [
1678+
"python",
1679+
"qualcomm/scripts/export_example.py",
1680+
"--model_name",
1681+
self.model_name,
1682+
"--output_folder",
1683+
"{}/".format(tmp_dir),
1684+
"--generate_etrecord",
1685+
]
1686+
1687+
p = subprocess.Popen(
1688+
cmds, stdout=subprocess.DEVNULL, cwd=f"{self.executorch_root}/examples"
1689+
)
1690+
p.communicate()
1691+
self.assertTrue(
1692+
Path("{0}/{1}.pte".format(tmp_dir, self.model_name)).exists()
1693+
)
1694+
16431695

16441696
def setup_environment():
16451697
parser = setup_common_args_and_variables()
@@ -1669,6 +1721,12 @@ def setup_environment():
16691721
default="",
16701722
type=str,
16711723
)
1724+
parser.add_argument(
1725+
"-n",
1726+
"--model_name",
1727+
help="Input the model to export",
1728+
type=str,
1729+
)
16721730
parser.add_argument(
16731731
"-o",
16741732
"--online_prepare",
@@ -1697,6 +1755,7 @@ def setup_environment():
16971755
TestQNN.artifact_dir = args.artifact_dir
16981756
TestQNN.image_dataset = args.image_dataset
16991757
TestQNN.pretrained_weight = args.pretrained_weight
1758+
TestQNN.model_name = args.model_name
17001759
TestQNN.online_prepare = args.online_prepare
17011760
TestQNN.enable_profile = args.enable_profile
17021761
TestQNN.error_only = args.error_only

backends/qualcomm/utils/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
ConvertBinaryOpsWithScalar,
2020
)
2121
from executorch.backends.qualcomm.passes.convert_bmm_to_matmul import ConvertBmmToMatmul
22+
from executorch.backends.qualcomm.passes.convert_hardsigmoid import ConvertHardsigmoid
2223
from executorch.backends.qualcomm.passes.convert_interpolate_with_upsample2d import (
2324
ConvertInterpolateWithUpsample2D,
2425
)
@@ -103,6 +104,7 @@ def _transform(edge_program: ExportedProgram) -> None:
103104
graph_module = edge_program.graph_module
104105
RemoveClone()(graph_module)
105106
ConvertToLinear()(graph_module)
107+
ConvertHardsigmoid()(graph_module)
106108
ConvertBmmToMatmul()(graph_module)
107109
ConvertInterpolateWithUpsample2D()(graph_module)
108110
I64toI32(edge_program)(graph_module)

examples/qualcomm/scripts/export_example.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,14 @@
4040
help="Generate ETRecord metadata to link with runtime results (used for profiling)",
4141
)
4242

43+
parser.add_argument(
44+
"-f",
45+
"--output_folder",
46+
type=str,
47+
default="",
48+
help="The folder to store the exported program",
49+
)
50+
4351
args = parser.parse_args()
4452

4553
if args.model_name not in MODEL_NAME_TO_MODEL:
@@ -92,7 +100,7 @@
92100
)
93101

94102
if args.generate_etrecord:
95-
etrecord_path = "etrecord.bin"
103+
etrecord_path = args.output_folder + "etrecord.bin"
96104
generate_etrecord(etrecord_path, edge_copy, executorch_program)
97105

98-
save_pte_program(executorch_program, args.model_name)
106+
save_pte_program(executorch_program, args.model_name, args.output_folder)

0 commit comments

Comments
 (0)