Skip to content

Commit 1a4c77c

Browse files
committed
Qualcomm AI Engine Direct - GA FocalNet
1 parent 4e38f4a commit 1a4c77c

19 files changed

+284
-39
lines changed

backends/qualcomm/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
253253

254254
pybind11_extension(PyQnnManagerAdaptor)
255255
pybind11_extension(PyQnnWrapperAdaptor)
256-
if(NOT MSVC AND NOT ${CMAKE_BUILD_TYPE} MATCHES Debug|RelWithDebInfo)
256+
if(NOT MSVC AND NOT ${CMAKE_BUILD_TYPE} MATCHES RelWithDebInfo)
257257
# Strip unnecessary sections of the binary
258258
pybind11_strip(PyQnnManagerAdaptor)
259259
pybind11_strip(PyQnnWrapperAdaptor)

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from .annotate_adaptive_avg_pool1d import AnnotateAdaptiveAvgPool1D
78
from .annotate_quant_attrs import AnnotateQuantAttrs
89
from .annotate_stack import AnnotateStack
910
from .annotate_unbind import AnnotateUnbind
@@ -38,6 +39,7 @@
3839

3940

4041
__all__ = [
42+
AnnotateAdaptiveAvgPool1D,
4143
AnnotateQuantAttrs,
4244
AnnotateStack,
4345
AnnotateUnbind,
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
import torch
7+
from executorch.backends.qualcomm.builders.node_visitor import q_ops
8+
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
9+
from executorch.exir.pass_base import ExportPass, PassResult
10+
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
11+
12+
from .utils import get_quant_attrs
13+
14+
15+
class AnnotateAdaptiveAvgPool1D(ExportPass):
    """
    Add "quant_attrs" to graph nodes' meta from the QDQ information
    generated after quantization process.

    adaptive_avg_pool1d gets decomposed to
    unsqueeze -> adaptive_avg_pool2d -> squeeze, so every node of the
    matched source partition receives its own copy of the quant attrs
    read from the quantize node that consumes the partition's output.
    """

    def __init__(self, edge_program: torch.export.ExportedProgram):
        super().__init__()
        self.edge_program = edge_program

    def _annotate_adaptive_avg_pool1d(self, graph_module: torch.fx.GraphModule):
        """Annotate each adaptive_avg_pool1d source partition in `graph_module`.

        A partition is annotated only when the first user of its first
        output node is one of `q_ops`; all other partitions are left as-is.
        """
        partitions = get_source_partitions(
            graph_module.graph, [torch.ops.aten.adaptive_avg_pool1d.default]
        )
        for src_partitions in partitions.values():
            for src_partition in src_partitions:
                # A partition without output nodes has nothing to read the
                # quantization info from; skip it instead of raising.
                if not src_partition.output_nodes:
                    continue
                output = src_partition.output_nodes[0]
                # `users` may be empty; look at the first user only when
                # one exists.
                first_user = next(iter(output.users), None)
                if first_user is None or first_user.target not in q_ops:
                    continue
                quant_attrs = get_quant_attrs(self.edge_program, first_user)
                for n in src_partition.nodes:
                    # Copy so that nodes do not share one mutable mapping.
                    n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy()

    def call(self, graph_module: torch.fx.GraphModule):
        """Run the annotation, recompile the module and report it as modified."""
        self._annotate_adaptive_avg_pool1d(graph_module)
        graph_module.recompile()
        return PassResult(graph_module, True)

backends/qualcomm/_passes/annotate_quant_attrs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any, Dict
88

99
import torch
10+
from executorch.backends.qualcomm.builders.node_visitor import dq_ops, q_ops
1011
from executorch.backends.qualcomm.builders.utils import get_parameter
1112
from executorch.backends.qualcomm.utils.constants import (
1213
QCOM_DTYPE,
@@ -20,7 +21,7 @@
2021
)
2122
from executorch.exir.pass_base import ExportPass, PassResult
2223

23-
from .utils import dq_ops, get_quant_attrs, q_ops
24+
from .utils import get_quant_attrs
2425

2526

2627
class AnnotateQuantAttrs(ExportPass):

backends/qualcomm/_passes/annotate_stack.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66
import torch
7+
from executorch.backends.qualcomm.builders.node_visitor import q_ops
78
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
89
from executorch.exir.pass_base import ExportPass, PassResult
910
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
1011

11-
from .utils import get_quant_attrs, q_ops
12+
from .utils import get_quant_attrs
1213

1314

1415
class AnnotateStack(ExportPass):
@@ -27,7 +28,7 @@ def _annotate_stack(self, graph_module: torch.fx.GraphModule):
2728
partitions = get_source_partitions(
2829
graph_module.graph, [torch.stack, torch.ops.aten.stack.default, "stack"]
2930
)
30-
for _, src_partitions in partitions.items():
31+
for src_partitions in partitions.values():
3132
for src_partition in src_partitions:
3233
output = src_partition.output_nodes[0]
3334
if (list(output.users)[0].target) in q_ops:

backends/qualcomm/_passes/annotate_unbind.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66
import torch
7+
8+
from executorch.backends.qualcomm.builders.node_visitor import dq_ops
79
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
810
from executorch.exir.pass_base import ExportPass, PassResult
911
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
1012

11-
from .utils import dq_ops, get_quant_attrs
13+
from .utils import get_quant_attrs
1214

1315

1416
class AnnotateUnbind(ExportPass):
@@ -27,7 +29,7 @@ def _annotate_unbind(self, graph_module: torch.fx.GraphModule):
2729
partitions = get_source_partitions(
2830
graph_module.graph, [torch.unbind, torch.ops.aten.unbind.int, "unbind"]
2931
)
30-
for _, src_partitions in partitions.items():
32+
for src_partitions in partitions.values():
3133
for src_partition in src_partitions:
3234
if src_partition.input_nodes[0].target in dq_ops:
3335
q_node = src_partition.input_nodes[0].args[0]

backends/qualcomm/_passes/expand_broadcast_tensor_shape.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import torch
8+
9+
from executorch.backends.qualcomm.builders.node_visitor import dq_ops
810
from executorch.exir.dialects._ops import ops as exir_ops
911
from executorch.exir.pass_base import ExportPass, PassResult
1012
from executorch.exir.passes import dead_code_elimination_pass
1113

12-
from .utils import dq_ops
13-
1414

1515
class ExpandBroadcastTensorShape(ExportPass):
1616
"""

backends/qualcomm/_passes/fold_qdq.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66
import torch
7+
8+
from executorch.backends.qualcomm.builders.node_visitor import dq_ops, q_ops
79
from executorch.backends.qualcomm.builders.utils import is_parameter
810
from executorch.backends.qualcomm.utils.constants import QCOM_BYPASS_NODE
911
from executorch.exir.dialects._ops import ops as exir_ops
1012
from executorch.exir.pass_base import ExportPass, PassResult
1113
from executorch.exir.passes import dead_code_elimination_pass
1214

13-
from .utils import dq_ops, q_ops
14-
1515

1616
class FoldQDQ(ExportPass):
1717
"""

backends/qualcomm/_passes/insert_io_qdq.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
import torch
99

10+
from executorch.backends.qualcomm.builders.node_visitor import q_ops
11+
1012
from executorch.backends.qualcomm.builders.utils import is_parameter
1113
from executorch.backends.qualcomm.utils.constants import (
1214
QCOM_ENCODING,
@@ -16,8 +18,6 @@
1618
from executorch.exir.dialects._ops import ops as exir_ops
1719
from executorch.exir.pass_base import ExportPass, PassResult
1820

19-
from .utils import q_ops
20-
2121

2222
class InsertIOQDQ(ExportPass):
2323
"""

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Dict
1010

1111
from executorch.backends.qualcomm._passes import (
12+
AnnotateAdaptiveAvgPool1D,
1213
AnnotateQuantAttrs,
1314
AnnotateStack,
1415
AnnotateUnbind,
@@ -73,6 +74,7 @@ def get_capture_program_passes():
7374
# The second value in each tuple in `default_passes_and_setting` indicates whether the corresponding pass is activated by default.
7475
# If a pass is activated, it will be executed by default.
7576
default_passes_and_setting = [
77+
(AnnotateAdaptiveAvgPool1D, True),
7678
(AnnotateQuantAttrs, True),
7779
(AnnotateStack, True),
7880
(AnnotateUnbind, True),
@@ -128,11 +130,11 @@ def get_to_edge_transform_passes(
128130
dep_table: Dict = None,
129131
):
130132
# TODO: remove this workaround when target could be correctly detected
131-
from executorch.backends.qualcomm._passes import utils
133+
from executorch.backends.qualcomm.builders import node_visitor
132134
from executorch.exir.dialects._ops import ops as exir_ops
133135

134-
utils.q_ops.add(exir_ops.edge.pt2e_quant.quantize_affine.default)
135-
utils.dq_ops.add(exir_ops.edge.pt2e_quant.dequantize_affine.default)
136+
node_visitor.q_ops.add(exir_ops.edge.pt2e_quant.quantize_affine.default)
137+
node_visitor.dq_ops.add(exir_ops.edge.pt2e_quant.dequantize_affine.default)
136138

137139
passes_job = (
138140
passes_job if passes_job is not None else get_capture_program_passes()

backends/qualcomm/_passes/recompose_rms_norm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66
import torch
7+
8+
from executorch.backends.qualcomm.builders.node_visitor import dq_ops
79
from executorch.backends.qualcomm.builders.utils import get_parameter, is_parameter
810
from executorch.exir.dialects._ops import ops as exir_ops
911
from executorch.exir.pass_base import ExportPass, PassResult
1012
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
1113

12-
from .utils import dq_ops
13-
1414

1515
class RecomposeRmsNorm(ExportPass):
1616
"""

backends/qualcomm/_passes/utils.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,6 @@
1313
from torch._subclasses import FakeTensor
1414

1515

16-
q_ops = {
17-
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
18-
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
19-
exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
20-
}
21-
22-
dq_ops = {
23-
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
24-
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
25-
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
26-
}
27-
28-
2916
def copy_meta(meta: Dict, callback=None):
3017
copied = {}
3118
for k, v in meta.items():
@@ -73,6 +60,7 @@ def get_passes_dependency_for_capture_program():
7360
dict: A dictionary mapping each pass to its corresponding list of dependencies.
7461
"""
7562
from executorch.backends.qualcomm._passes import (
63+
AnnotateAdaptiveAvgPool1D,
7664
AnnotateQuantAttrs,
7765
AnnotateStack,
7866
AnnotateUnbind,
@@ -94,6 +82,7 @@ def get_passes_dependency_for_capture_program():
9482
)
9583

9684
return {
85+
AnnotateAdaptiveAvgPool1D: [RemoveRedundancy],
9786
AnnotateQuantAttrs: [
9887
RecomposePixelUnshuffle,
9988
ConvertBmmToMatmul,

backends/qualcomm/builders/node_visitor.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
import numpy as np
1313
import torch
14-
from executorch.backends.qualcomm._passes.utils import dq_ops
1514
from executorch.backends.qualcomm.utils.constants import (
1615
QCOM_AXIS,
1716
QCOM_AXIS_ORDER,
@@ -79,6 +78,18 @@
7978
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
8079
}
8180

81+
q_ops = {
82+
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
83+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
84+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
85+
}
86+
87+
dq_ops = {
88+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
89+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
90+
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
91+
}
92+
8293

8394
class NodeVisitor:
8495
"""

backends/qualcomm/quantizer/annotators.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -462,8 +462,13 @@ def annotate_neg(node: Node, quantization_config: QuantizationConfig) -> None:
462462
annotate_single_in_single_out(node, quantization_config)
463463

464464

465-
@register_annotator([torch.ops.aten.adaptive_avg_pool2d.default])
466-
def annotate_adaptive_avgpool2d(
465+
@register_annotator(
466+
[
467+
torch.ops.aten.adaptive_avg_pool1d.default,
468+
torch.ops.aten.adaptive_avg_pool2d.default,
469+
]
470+
)
471+
def annotate_adaptive_avg_pool(
467472
node: Node, quantization_config: QuantizationConfig
468473
) -> None:
469474
annotate_single_in_single_out(node, quantization_config)
@@ -1170,7 +1175,13 @@ def annotate_unbind(node: Node, quantization_config: QuantizationConfig) -> None
11701175
)
11711176

11721177

1173-
@register_annotator([torch.ops.aten.split.Tensor, torch.ops.aten.chunk.default])
1178+
@register_annotator(
1179+
[
1180+
torch.ops.aten.split_with_sizes.default,
1181+
torch.ops.aten.split.Tensor,
1182+
torch.ops.aten.chunk.default,
1183+
]
1184+
)
11741185
def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None:
11751186
if _is_annotated([node]):
11761187
return

backends/qualcomm/scripts/build.sh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ CMAKE_X86_64="build-x86"
3030
BUILD_AARCH64="true"
3131
CMAKE_AARCH64="build-android"
3232
CLEAN="true"
33-
BUILD_TYPE="Debug"
33+
BUILD_TYPE="RelWithDebInfo"
3434
BUILD_JOB_NUMBER="16"
3535

3636
if [ -z PYTHON_EXECUTABLE ]; then
@@ -71,7 +71,7 @@ if [ "$BUILD_AARCH64" = true ]; then
7171
rm -rf $BUILD_ROOT && mkdir $BUILD_ROOT
7272
else
7373
# Force rebuild flatccrt for the correct platform
74-
cd $BUILD_ROOT/devtools && make clean
74+
cd $BUILD_ROOT/third-party/flatcc && make clean
7575
fi
7676

7777
cd $BUILD_ROOT
@@ -116,7 +116,7 @@ if [ "$BUILD_X86_64" = true ]; then
116116
rm -rf $BUILD_ROOT && mkdir $BUILD_ROOT
117117
else
118118
# Force rebuild flatccrt for the correct platform
119-
cd $BUILD_ROOT/devtools && make clean
119+
cd $BUILD_ROOT/third-party/flatcc && make clean
120120
fi
121121

122122
cd $BUILD_ROOT

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3868,6 +3868,44 @@ def test_fbnet(self):
38683868
self.assertGreaterEqual(msg["top_1"], 60)
38693869
self.assertGreaterEqual(msg["top_5"], 90)
38703870

3871+
def test_focalnet(self):
    """Run the FocalNet example script and check the accuracy it reports.

    The script is launched as a subprocess; its result arrives as a JSON
    message on a listener bound to (self.ip, self.port). The test fails
    with the reported error, or requires top_1 >= 55 and top_5 >= 80.
    """
    if not self.required_envs([self.image_dataset]):
        self.skipTest("missing required envs")

    script = f"{self.executorch_root}/examples/qualcomm/oss_scripts/focalnet.py"
    command = ["python", script]
    command += ["--dataset", self.image_dataset]
    command += ["--artifact", self.artifact_dir]
    command += ["--build_folder", self.build_folder]
    command += ["--device", self.device]
    command += ["--model", self.model]
    command += ["--ip", self.ip]
    command += ["--port", str(self.port)]
    # Optional flags are appended only when configured on the test case.
    if self.host:
        command += ["--host", self.host]
    if self.shared_buffer:
        command.append("--shared_buffer")

    process = subprocess.Popen(command, stdout=subprocess.DEVNULL)
    with Listener((self.ip, self.port)) as listener:
        connection = listener.accept()
        # Wait for the script to finish before reading its result message.
        process.communicate()
        result = json.loads(connection.recv())
        if "Error" in result:
            self.fail(result["Error"])
        self.assertGreaterEqual(result["top_1"], 55)
        self.assertGreaterEqual(result["top_5"], 80)
3908+
38713909
def test_gMLP(self):
38723910
if not self.required_envs([self.image_dataset]):
38733911
self.skipTest("missing required envs")

0 commit comments

Comments
 (0)