Skip to content

Commit fd88d97

Browse files
committed
Qualcomm AI Engine Direct - HardSigmoid follow up for FP16 / Test cases complement
Summary: - make HardSigmoid more compact in FP16 - add online prepare utility test case - add test case for export_example.py
1 parent 24477df commit fd88d97

File tree

5 files changed

+85
-18
lines changed

5 files changed

+85
-18
lines changed

backends/qualcomm/builders/op_dequantize.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -54,20 +54,16 @@ def define_node(
5454

5555

5656
@register_node_visitor
57-
class PerTensorDequantizeDefault(DequantizeOpBase):
58-
target = ["quantized_decomposed.dequantize_per_tensor.default"]
57+
class PerTensorDequantize(DequantizeOpBase):
58+
target = [
59+
"quantized_decomposed.dequantize_per_tensor.default",
60+
"quantized_decomposed.dequantize_per_tensor.tensor",
61+
]
5962

6063

6164
@register_node_visitor
62-
class PerTensorDequantizeTensor(DequantizeOpBase):
63-
target = ["quantized_decomposed.dequantize_per_tensor.tensor"]
64-
65-
66-
@register_node_visitor
67-
class PerChannelDequantizeDefault(DequantizeOpBase):
68-
target = ["quantized_decomposed.dequantize_per_channel.default"]
69-
70-
71-
@register_node_visitor
72-
class PerChannelDequantizeTensor(DequantizeOpBase):
73-
target = ["quantized_decomposed.dequantize_per_channel.tensor"]
65+
class PerChannelDequantize(DequantizeOpBase):
66+
target = [
67+
"quantized_decomposed.dequantize_per_channel.default",
68+
"quantized_decomposed.dequantize_per_channel.tensor",
69+
]

backends/qualcomm/passes/convert_hardsigmoid.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ def call(self, graph_module: torch.fx.GraphModule):
2525
partitions = get_source_partitions(graph, [torch.nn.Hardsigmoid])
2626
for _, src_partitions in partitions.items():
2727
for src_partition in src_partitions:
28+
if exir_ops.edge.aten.hardswish.default in [
29+
node.target for node in src_partition.nodes
30+
]:
31+
continue
2832
if self.quantization_capture:
2933
# only one hardsigmoid op will be seen
3034
input_nodes = src_partition.input_nodes
@@ -34,8 +38,6 @@ def call(self, graph_module: torch.fx.GraphModule):
3438
else:
3539
in_ops_target = exir_ops.edge.aten.add.Tensor
3640
out_ops_target = exir_ops.edge.aten.div.Tensor
37-
# see the reverse engineering logic hardswish
38-
# https://shorturl.at/pACEL
3941
input_nodes = [
4042
n for n in src_partition.nodes if n.target is in_ops_target
4143
]

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
import json
77
import subprocess
88
import sys
9+
import tempfile
910
import unittest
1011
from multiprocessing.connection import Listener
12+
from pathlib import Path
1113

1214
import torch
1315
from executorch.backends.qualcomm.tests.utils import (
@@ -1099,6 +1101,19 @@ def test_qnn_backend_shared_buffer(self):
10991101
expected_partitions=1,
11001102
)
11011103

1104+
def test_qnn_backend_online_prepare(self):
1105+
backend_options = generate_htp_compiler_spec(use_fp16=True)
1106+
TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec(
1107+
soc_model=self.arch_table[TestQNN.model],
1108+
backend_options=backend_options,
1109+
debug=False,
1110+
saver=False,
1111+
online_prepare=True,
1112+
)
1113+
module = SimpleModel() # noqa: F405
1114+
sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
1115+
self.lower_module_and_test_output(module, sample_input)
1116+
11021117

11031118
class TestQNNQuantizedUtils(TestQNN):
11041119
# TODO: refactor to support different backends
@@ -1220,6 +1235,20 @@ def test_qnn_backend_shared_buffer(self):
12201235
expected_partitions=1,
12211236
)
12221237

1238+
def test_qnn_backend_online_prepare(self):
1239+
backend_options = generate_htp_compiler_spec(use_fp16=False)
1240+
TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec(
1241+
soc_model=self.arch_table[TestQNN.model],
1242+
backend_options=backend_options,
1243+
debug=False,
1244+
saver=False,
1245+
online_prepare=True,
1246+
)
1247+
module = SimpleModel() # noqa: F405
1248+
sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
1249+
module = self.get_qdq_module(module, sample_input)
1250+
self.lower_module_and_test_output(module, sample_input)
1251+
12231252

12241253
class TestExampleScript(TestQNN):
12251254
def required_envs(self, conditions=None) -> bool:
@@ -1586,6 +1615,29 @@ def test_ptq_mobilebert(self):
15861615
for k, v in cpu.items():
15871616
self.assertLessEqual(abs(v[0] - htp[k][0]), 5)
15881617

1618+
def test_export_example(self):
1619+
if not self.required_envs([self.model_name]):
1620+
self.skipTest("missing required envs")
1621+
1622+
with tempfile.TemporaryDirectory() as tmp_dir:
1623+
cmds = [
1624+
"python",
1625+
"qualcomm/scripts/export_example.py",
1626+
"--model_name",
1627+
self.model_name,
1628+
"--output_folder",
1629+
"{}/".format(tmp_dir),
1630+
"--generate_etrecord",
1631+
]
1632+
1633+
p = subprocess.Popen(
1634+
cmds, stdout=subprocess.DEVNULL, cwd=f"{self.executorch_root}/examples"
1635+
)
1636+
p.communicate()
1637+
self.assertTrue(
1638+
Path("{0}/{1}.pte".format(tmp_dir, self.model_name)).exists()
1639+
)
1640+
15891641

15901642
def setup_environment():
15911643
parser = setup_common_args_and_variables()
@@ -1615,6 +1667,12 @@ def setup_environment():
16151667
default="",
16161668
type=str,
16171669
)
1670+
parser.add_argument(
1671+
"-n",
1672+
"--model_name",
1673+
help="Input the model to export",
1674+
type=str,
1675+
)
16181676
parser.add_argument(
16191677
"-o",
16201678
"--online_prepare",
@@ -1643,6 +1701,7 @@ def setup_environment():
16431701
TestQNN.artifact_dir = args.artifact_dir
16441702
TestQNN.image_dataset = args.image_dataset
16451703
TestQNN.pretrained_weight = args.pretrained_weight
1704+
TestQNN.model_name = args.model_name
16461705
TestQNN.online_prepare = args.online_prepare
16471706
TestQNN.enable_profile = args.enable_profile
16481707
TestQNN.error_only = args.error_only

backends/qualcomm/utils/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
ConvertBinaryOpsWithScalar,
2020
)
2121
from executorch.backends.qualcomm.passes.convert_bmm_to_matmul import ConvertBmmToMatmul
22+
from executorch.backends.qualcomm.passes.convert_hardsigmoid import ConvertHardsigmoid
2223
from executorch.backends.qualcomm.passes.convert_interpolate_with_upsample2d import (
2324
ConvertInterpolateWithUpsample2D,
2425
)
@@ -104,6 +105,7 @@ def _transform(edge_program: ExportedProgram) -> None:
104105
graph_module = edge_program.graph_module
105106
RemoveClone()(graph_module)
106107
ConvertToLinear()(graph_module)
108+
ConvertHardsigmoid()(graph_module)
107109
ConvertBmmToMatmul()(graph_module)
108110
ConvertInterpolateWithUpsample2D()(graph_module)
109111
I64toI32(edge_program)(graph_module)

examples/qualcomm/scripts/export_example.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,14 @@
4040
help="Generate ETRecord metadata to link with runtime results (used for profiling)",
4141
)
4242

43+
parser.add_argument(
44+
"-f",
45+
"--output_folder",
46+
type=str,
47+
default="",
48+
help="The folder to store the exported program",
49+
)
50+
4351
args = parser.parse_args()
4452

4553
if args.model_name not in MODEL_NAME_TO_MODEL:
@@ -92,7 +100,7 @@
92100
)
93101

94102
if args.generate_etrecord:
95-
etrecord_path = "etrecord.bin"
103+
etrecord_path = args.output_folder + "etrecord.bin"
96104
generate_etrecord(etrecord_path, edge_copy, executorch_program)
97105

98-
save_pte_program(executorch_program, args.model_name)
106+
save_pte_program(executorch_program, args.model_name, args.output_folder)

0 commit comments

Comments
 (0)