Commit 57e192b

haowhsu-quic authored and facebook-github-bot committed
Qualcomm AI Engine Direct - support embedding op (#2057)
Summary:
- support embedding op with int32 index input
- make mobilebert / llama2 fully delegated
- add requantize passes for mixed precision
- bug fixes

Pull Request resolved: #2057
Reviewed By: dbort
Differential Revision: D54348816
Pulled By: cccclai
fbshipit-source-id: ec3c8e87cc879d6f642859231255d5094d78349f
1 parent 75352ad commit 57e192b
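As a rough sketch of what the new embedding support covers, a lookup like the one below can now be delegated instead of falling back to CPU; the module, sizes, and inputs are invented for illustration, not taken from this PR:

import torch

# Illustrative module: an embedding lookup driven by int32 indices,
# the pattern the QNN delegate can now consume.
class TokenEmbedding(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(num_embeddings=1000, embedding_dim=64)

    def forward(self, indices):
        return self.emb(indices)

model = TokenEmbedding().eval()
# int32 indices (rather than the default int64) match the supported input
sample_indices = torch.randint(0, 1000, (1, 32), dtype=torch.int32)
out = model(sample_indices)
assert out.shape == (1, 32, 64)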

File tree

16 files changed: +279, -30 lines


backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -18,6 +18,7 @@
     op_depth_to_space,
     op_dequantize,
     op_div,
+    op_embedding,
     op_expand,
     op_gelu,
     op_hardswish,
@@ -62,6 +63,7 @@
     op_depth_to_space,
     op_dequantize,
     op_div,
+    op_embedding,
     op_expand,
     op_gelu,
     op_hardswish,

backends/qualcomm/builders/node_visitor.py

Lines changed: 6 additions & 1 deletion

@@ -90,8 +90,13 @@ def get_quant_encoding_conf(self, node: torch.fx.Node) -> Tuple[Any, Dict]:
                 {},
             )
 
-        quant_attrs = node.meta["quant_attrs"]
+        quant_attrs = (
+            node.meta["requantize"]["dq_attrs"]
+            if "requantize" in node.meta
+            else node.meta["quant_attrs"]
+        )
         encoding = quant_attrs["encoding"]
+
         quant_config = {}
         if encoding in PER_CHANNEL_ENCODING_MAPPING:
             scales = quant_attrs["scales"]
backends/qualcomm/builders/op_embedding.py

Lines changed: 74 additions & 0 deletions

@@ -0,0 +1,74 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import numpy as np
+import torch
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpGather, QNN_OP_PACKAGE_NAME_QTI_AISW
+from .utils import get_parameter
+
+
+@register_node_visitor
+class Embedding(NodeVisitor):
+    target = "aten.embedding.default"
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        weight_node = node.args[0]
+        weight_tensor = get_parameter(weight_node, self.edge_program)
+        weight_tensor_wrapper = self.define_tensor(
+            weight_node,
+            weight_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
+            nodes_to_wrappers,
+        )
+
+        indices_node = node.args[1]
+        indices_tensor = self.get_tensor(indices_node, node)
+        indices_tensor_wrapper = self.define_scalar(
+            indices_node,
+            indices_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        gather_input_tensors = [weight_tensor_wrapper, indices_tensor_wrapper]
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        gather_output_tensors = [output_tensor_wrapper]
+
+        gather_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpGather.op_name,
+        )
+        gather_op.AddInputTensors(gather_input_tensors)
+        gather_op.AddOutputTensors(gather_output_tensors)
+
+        # For now, the default axis is zero.
+        gather_op.AddScalarParam(
+            OpGather.param_axis,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
+            {"data": np.int32(0)},
+        )
+
+        return gather_op
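The builder above maps aten.embedding.default onto QNN's Gather with the axis parameter fixed at zero. A quick sanity check of that equivalence in plain PyTorch (tensors invented for illustration):

import torch

weight = torch.randn(10, 4)               # the embedding table (static tensor above)
indices = torch.tensor([[1, 3], [2, 0]])  # lookup ids

# an embedding lookup selects rows of the table...
ref = torch.nn.functional.embedding(indices, weight)
# ...which is exactly a gather along axis 0, hence OpGather with np.int32(0)
gathered = weight[indices]
assert torch.equal(ref, gathered)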

backends/qualcomm/partition/common_defs.py

Lines changed: 0 additions & 1 deletion

@@ -12,7 +12,6 @@
     exir_ops.edge.aten.arange.start_step,
     exir_ops.edge.aten.index.Tensor,
     exir_ops.edge.aten.full.default,
-    exir_ops.edge.aten.embedding.default,
 ]
 
 allow_list_operator = [

backends/qualcomm/passes/annotate_and_quant_scalar.py

Lines changed: 5 additions & 1 deletion

@@ -98,7 +98,11 @@ def _traverse_binary_node(self, graph_module: torch.fx.GraphModule):
            q_node = dq_node.args[0]
            q_node_attrs = get_quant_attrs(graph_module, q_node)
 
-            scalar_node = [n for n in output.args if n != dq_node][0]
+            scalar_nodes = [n for n in output.args if n != dq_node]
+            if len(scalar_nodes) == 0:
+                continue
+
+            scalar_node = scalar_nodes[0]
            source_scalar_node = self._get_source_scalar_node(scalar_node)
            # we'll abandon cast op here, since the constant scalar will
            # be pre-loaded into QNN context binary
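The new guard covers binary nodes whose operands all resolve to the dequantized input, so the filtered list comes back empty. A toy fx graph that reproduces just that empty-list shape (no real q/dq nodes involved; the module is invented):

import operator

import torch

class SelfAdd(torch.nn.Module):
    def forward(self, x):
        return x + x  # both operands resolve to the same node

gm = torch.fx.symbolic_trace(SelfAdd())
add_node = next(n for n in gm.graph.nodes if n.target is operator.add)
first_arg = add_node.args[0]
scalar_nodes = [n for n in add_node.args if n != first_arg]
assert scalar_nodes == []  # indexing [0] here raised IndexError before the fix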

backends/qualcomm/passes/insert_io_qdq.py

Lines changed: 11 additions & 3 deletions

@@ -49,8 +49,13 @@ def _insert_node(
         graph_module: torch.fx.GraphModule,
         node: torch.fx.node,
         target: torch.fx.node.Target,
+        quant_attrs: Dict = None,
     ) -> torch.fx.node:
-        quant_attrs = node.meta.get("quant_attrs")
+        # use the specified quant_attrs if one is given;
+        # otherwise, fall back to the existing info on the current node
+        if quant_attrs is None:
+            quant_attrs = node.meta.get("quant_attrs")
+
         inserted_node = graph_module.graph.create_node(
             "call_function",
             target,
@@ -69,13 +74,16 @@ def _insert_quant_node(
         graph_module: torch.fx.GraphModule,
         node: torch.fx.node,
         target: torch.fx.node.Target,
-    ) -> None:
+        quant_attrs: Dict = None,
+    ) -> torch.fx.Node:
         with graph_module.graph.inserting_after(node):
             users = list(node.users.keys())
-            inserted_node = self._insert_node(graph_module, node, target)
+            inserted_node = self._insert_node(graph_module, node, target, quant_attrs)
             for user in users:
                 user.replace_input_with(node, inserted_node)
 
+        return inserted_node
+
     def _insert_dequant_node(
         self,
         graph_module: torch.fx.GraphModule,
backends/qualcomm/passes/insert_requantize.py

Lines changed: 57 additions & 0 deletions

@@ -0,0 +1,57 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from executorch.backends.qualcomm.passes.insert_io_qdq import InsertIOQDQ
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+class InsertRequantize(InsertIOQDQ):
+    """
+    This pass inserts dq/q nodes for non-arithmetic operators whose
+    input and activation have different quantization specs.
+    """
+
+    def __init__(
+        self,
+        edge_program: torch.export.ExportedProgram,
+        insert_requantize: bool = False,
+    ):
+        super().__init__(edge_program)
+        # add non-arithmetic operators here if the condition is met
+        self.op_map = {
+            exir_ops.edge.aten.permute_copy.default: self._single_io_annotation,
+        }
+        self.insert_requantize = insert_requantize
+
+    def _single_io_annotation(self, gm: torch.fx.GraphModule, n: torch.fx.node) -> None:
+        in_q_attr = n.args[0].meta.get("quant_attrs")
+        out_q_attr = n.meta["quant_attrs"]
+        if in_q_attr is not None and in_q_attr["dtype"] != out_q_attr["dtype"]:
+            if self.insert_requantize:
+                dq_attr = n.meta["requantize"]["dq_attrs"]
+                q_attr = n.meta["requantize"]["q_attrs"]
+                # insert dq with the quantization attributes of the input node
+                dq = self._insert_quant_node(gm, n, dq_attr["encoding"], dq_attr)
+                dq.meta["quant_attrs"] = dq_attr
+                # insert q with the quantization attributes of the current node
+                q = self._insert_quant_node(gm, dq, q_attr["encoding"], q_attr)
+                q.meta["quant_attrs"] = q_attr
+            else:
+                dq_attr = in_q_attr.copy()
+                dq_attr["encoding"] = self.q_dq_map[out_q_attr["encoding"]]
+                q_attr = out_q_attr.copy()
+                n.meta["requantize"] = {"dq_attrs": dq_attr, "q_attrs": q_attr}
+
+    def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        for n in graph_module.graph.nodes:
+            if (
+                n.op == "call_function"
+                and n.meta.get("quant_attrs")
+                and n.target in self.op_map
+            ):
+                self.op_map[n.target](graph_module, n)
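Numerically, the inserted dq/q pair re-expresses a tensor under the consumer's quantization spec when the dtypes differ. A minimal sketch with invented per-tensor affine parameters (not the attrs format the pass manipulates):

import torch

# invented specs: producer int8 with scale 0.05 and zero point 0,
# consumer uint8 with scale 0.1 and zero point 128
x_int8 = torch.tensor([-120, -4, 0, 57, 127], dtype=torch.int8)

x_fp = x_int8.to(torch.float32) * 0.05  # dq: back to float via the producer's spec
x_uint8 = torch.clamp(torch.round(x_fp / 0.1) + 128, 0, 255).to(torch.uint8)  # q: consumer's spec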

backends/qualcomm/passes/layout_transform.py

Lines changed: 7 additions & 0 deletions

@@ -13,6 +13,8 @@
 from executorch.exir.pass_base import ExportPass, PassResult
 from executorch.exir.sym_util import eval_shape
 
+from .utils import dq_ops, q_ops
+
 
 class LayoutTransform(ExportPass):
     """
@@ -50,6 +52,8 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.bmm.default,
         exir_ops.edge.aten.full.default,
         exir_ops.edge.aten.gelu.default,
+        *q_ops,
+        *dq_ops,
         _operator.getitem,
     }
 
@@ -77,6 +81,7 @@ def __init__(
         super(LayoutTransform, self).__init__()
         self.edge_program = edge_program
         self.insert_permute = insert_permute
+        self.qdq_opset = {*q_ops, *dq_ops}
 
     def mark_as_transformed(self, node: torch.fx.Node) -> None:
         if isinstance(node.meta["val"], (tuple, list)):
@@ -108,6 +113,8 @@ def is_layout_agnostic(self, node: torch.fx.Node) -> bool:
            # if the dimension is not kept, we'll have no clue how to do layout transform
            if len(node.args) < 3 or not node.args[2]:
                return False
+        if node.target in self.qdq_opset:
+            return "requantize" in node.meta
         return node.target in self.layout_agnostic_ops
 
     def is_edge_condition(self, node):
backends/qualcomm/passes/recompose_pixel_shuffle.py

Lines changed: 46 additions & 0 deletions

@@ -0,0 +1,46 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
+
+
+class RecomposePixelShuffle(ExportPass):
+    """
+    Merge decomposed operators back into one super node.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        graph = graph_module.graph
+        # decomposed core aten ops
+        partitions = get_source_partitions(graph, [torch.nn.PixelShuffle])
+        for _, src_partitions in partitions.items():
+            for src_partition in src_partitions:
+                input_node = src_partition.input_nodes[0]
+                output_node = src_partition.output_nodes[0]
+                with graph.inserting_after(input_node):
+                    h_in_shape = input_node.meta["val"].shape[2]
+                    h_out_shape = output_node.meta["val"].shape[2]
+                    upscale_factor = h_out_shape / h_in_shape
+
+                    pixel_shuffle_node = graph.create_node(
+                        "call_function",
+                        exir_ops.edge.aten.pixel_shuffle.default,
+                        (input_node, int(upscale_factor)),
+                    )
+                    users = output_node.users.copy()
+                    for user in users:
+                        user.replace_input_with(output_node, pixel_shuffle_node)
+                    # copy metadata
+                    pixel_shuffle_node.meta = output_node.meta
+
+        graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, True)
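The pass recovers the upscale factor from the height ratio between the partition's input and output; a quick check of that shape arithmetic:

import torch

# pixel_shuffle maps (N, C*r*r, H, W) -> (N, C, H*r, W*r);
# with H_in = 4 and H_out = 12, the recovered factor is 12 / 4 = 3
x = torch.randn(1, 9, 4, 4)
y = torch.nn.functional.pixel_shuffle(x, upscale_factor=3)
assert y.shape == (1, 1, 12, 12)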

backends/qualcomm/qnn_preprocess.py

Lines changed: 2 additions & 0 deletions

@@ -12,6 +12,7 @@
 
 from executorch.backends.qualcomm.passes.convert_to_linear import ConvertToLinear
 from executorch.backends.qualcomm.passes.insert_io_qdq import InsertIOQDQ
+from executorch.backends.qualcomm.passes.insert_requantize import InsertRequantize
 from executorch.backends.qualcomm.passes.layout_transform import LayoutTransform
 from executorch.backends.qualcomm.utils.utils import generate_qnn_executorch_option
 from executorch.exir.backend.backend_details import (
@@ -44,6 +45,7 @@ def preprocess(
         passes=[
             ConvertToLinear(),
             InsertIOQDQ(edge_program),
+            InsertRequantize(edge_program, insert_requantize=True),
             LayoutTransform(edge_program, insert_permute=True),
         ]
     )

backends/qualcomm/quantizer/utils.py

Lines changed: 15 additions & 1 deletion

@@ -104,7 +104,7 @@ def annotate_single_in_single_out(
     input_qspec_map[input_act] = quantization_config.input_activation
 
     node_tensor = node.meta.get("val")
-    if torch.is_tensor(node_tensor) and node_tensor.dtype == torch.int64:
+    if torch.is_tensor(node_tensor) and node_tensor.dtype != torch.float32:
         return
 
     node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
@@ -356,6 +356,20 @@ def annotate_transpose(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)
 
 
+@register_annotator([torch.ops.aten.embedding.default])
+def annotate_embedding(node: Node, quantization_config: QuantizationConfig) -> None:
+    weight = node.args[0]
+
+    input_qspec_map = {}
+    input_qspec_map[weight] = quantization_config.input_activation
+
+    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+        input_qspec_map=input_qspec_map,
+        output_qspec=SharedQuantizationSpec((weight, node)),
+        _annotated=True,
+    )
+
+
 @register_annotator([torch.ops.aten.expand.default])
 def annotate_expand(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_in_out_obs_sharing_op(node, quantization_config)
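The SharedQuantizationSpec((weight, node)) output spec leans on the fact that an embedding lookup only copies rows of the weight table, so the output's value range never exceeds the weight's. A small demonstration:

import torch

weight = torch.randn(100, 16)
indices = torch.randint(0, 100, (8,))
out = torch.nn.functional.embedding(indices, weight)

# every output value is literally a weight value, so reusing the
# weight's quantization parameters for the output loses no precision
assert out.min() >= weight.min() and out.max() <= weight.max()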
