
Commit f707590

Joey Tsai authored and facebook-github-bot committed
Qualcomm AI Engine Direct - LLAMA2 Infrastructure (#2020)
Summary:
1. Ops
- Add pow_tensor_scalar op
- Add rsqrt op
- Add sigmoid op
- Refine axis handling of cat op
- Refine parameter-related functions
2. Passes
- Add AnnotateDecomposed for unbind and stack ops
- Add DecomposeSilu for the quantizer
- Add ReplaceInfBuffer for the quantizer
- Rename pass ConvertAddmmmmWithLinear to ConvertToLinear
- Rename pass ConvertScaledDotProductAttention to DecomposeScaledDotProductAttention
- Support more args for the sdpa op in DecomposeScaledDotProductAttention
- Support the mm case in ConvertToLinear
- Move q_ops and dq_ops to passes/utils.py
3. Tests
- Add dummy llama2 test script
- Add single-op test cases
4. Others
- Fix error when popping a missing buffer
- Reorder test models
- Reorder ops in qnn_constants

Pull Request resolved: #2020
Reviewed By: kirklandsign
Differential Revision: D54010593
Pulled By: cccclai
fbshipit-source-id: 657994dc223cb9bd88a263bfc2479295384fcb4d
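A note on the ConvertToLinear rename: the pass now covers both the addmm and mm patterns, and both are expressible as aten.linear. A quick sanity check of that equivalence (illustrative sketch, not code from this commit):

import torch

# Illustrative only: the algebraic identity a ConvertToLinear-style pass relies on.
# addmm(b, x, w) == linear(x, w.T, b) and mm(x, w) == linear(x, w.T).
x, w, b = torch.randn(2, 3), torch.randn(3, 4), torch.randn(4)
assert torch.allclose(torch.addmm(b, x, w), torch.nn.functional.linear(x, w.t(), b))
assert torch.allclose(torch.mm(x, w), torch.nn.functional.linear(x, w.t()))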
1 parent 8fed60b commit f707590

28 files changed (+1163, −479 lines)

backends/qualcomm/builders/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -30,10 +30,13 @@
     op_mean_dim,
     op_mul,
     op_pad,
+    op_pow,
     op_quantize,
     op_relu,
     op_reshape,
+    op_rsqrt,
     op_select_copy,
+    op_sigmoid,
     op_skip_ops,
     op_slice_copy,
     op_softmax,
@@ -70,10 +73,13 @@
     op_mean_dim,
     op_mul,
     op_pad,
+    op_pow,
     op_quantize,
     op_relu,
     op_reshape,
+    op_rsqrt,
     op_select_copy,
+    op_sigmoid,
     op_skip_ops,
     op_slice_copy,
     op_softmax,

backends/qualcomm/builders/op_cat.py

Lines changed: 8 additions & 1 deletion
@@ -54,7 +54,14 @@ def define_node
             nodes_to_wrappers,
         )
 
-        axis = cast(int, node.args[1])
+        # node args[1] might not exist
+        axis = 0
+        if len(node.args) == 2:
+            axis = cast(int, node.args[1])
+
+        if axis < 0:
+            axis += node.meta["val"].dim()
+
         if "axis_order" in node.meta:
             axis = node.meta["axis_order"].index(axis)
 
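The guard above handles call sites like torch.cat(tensors) where dim is left at its default of 0, and the follow-up normalizes negative dims. A hedged illustration (module and shapes invented for this note; whether the traced node records the default dim depends on the tracer version):

import torch

class Cat(torch.nn.Module):
    def forward(self, x, y):
        # dim omitted: the captured aten.cat node may carry only the tensor list
        return torch.cat((x, y))

ep = torch.export.export(Cat(), (torch.ones(2, 3), torch.ones(2, 3)))
print(ep.graph_module.graph)  # inspect the cat node's args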

backends/qualcomm/builders/op_pow.py

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch
from executorch.exir.dialects._ops import ops as exir_ops

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpElementWisePower, QNN_OP_PACKAGE_NAME_QTI_AISW


# TODO: add more classes like PowTensorTensor if needed
@register_node_visitor
class PowTensorScalar(NodeVisitor):
    target = "aten.pow.Tensor_Scalar"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        out_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            out_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )
        pow_output_tensors = [output_tensor_wrapper]

        # tensor input
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)

        tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE

        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            tensor_type,
            nodes_to_wrappers,
        )

        # scalar input
        scalar = node.args[1]
        scalar_tensor = torch.full(input_tensor.size(), scalar).to(torch.float32)

        # 'graph', 'name', 'op', 'target', 'args', and 'kwargs'
        scalar_node = torch.fx.Node(
            node.graph,
            node.name + "_runtime_scalar",
            "call_function",
            exir_ops.edge.aten.full.default,
            (),  # args
            {},  # kwargs
        )

        if pow_quant_attrs := node.meta.get("quant_attrs"):
            quant_attrs = pow_quant_attrs.copy()
            quant_range = quant_attrs["quant_max"] - quant_attrs["quant_min"]
            quant_attrs["zero_point"] = 0 if scalar >= 0 else quant_attrs["quant_max"]
            quant_attrs["scale"] = (
                scalar / quant_range if scalar >= 0 else -scalar / quant_range
            )
            scalar_node.meta["quant_attrs"] = quant_attrs

        scalar_tensor_wrapper = self.define_tensor(
            scalar_node,
            scalar_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
            nodes_to_wrappers,
        )

        pow_input_tensors = [input_tensor_wrapper, scalar_tensor_wrapper]

        pow_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpElementWisePower.op_name,
        )
        pow_op.AddInputTensors(pow_input_tensors)
        pow_op.AddOutputTensors(pow_output_tensors)

        return pow_op
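A note on the quant_attrs arithmetic above: the scalar operand becomes a static tensor whose quantization encodes the scalar exactly. A hypothetical sanity check (helper written for this note, not part of the commit):

# Hypothetical helper mirroring the scale/zero-point choice in op_pow.py.
def scalar_quant_params(scalar: float, quant_min: int, quant_max: int):
    quant_range = quant_max - quant_min
    zero_point = 0 if scalar >= 0 else quant_max
    scale = scalar / quant_range if scalar >= 0 else -scalar / quant_range
    return scale, zero_point

# Positive scalar: q = quant_max dequantizes to (255 - 0) * (0.5 / 255) == 0.5
scale, zp = scalar_quant_params(0.5, 0, 255)
assert abs((255 - zp) * scale - 0.5) < 1e-9

# Negative scalar: q = 0 dequantizes to (0 - 255) * (0.5 / 255) == -0.5
scale, zp = scalar_quant_params(-0.5, 0, 255)
assert abs((0 - zp) * scale + 0.5) < 1e-9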
backends/qualcomm/builders/op_rsqrt.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpElementWiseRsqrt, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class Rsqrt(NodeVisitor):
    target = "aten.rsqrt.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        rsqrt_inp_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )
        rsqrt_input_tensors = [rsqrt_inp_tensor_wrapper]

        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )
        rsqrt_output_tensors = [output_tensor_wrapper]

        rsqrt_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpElementWiseRsqrt.op_name,
        )
        rsqrt_op.AddInputTensors(rsqrt_input_tensors)
        rsqrt_op.AddOutputTensors(rsqrt_output_tensors)

        return rsqrt_op
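For context on why rsqrt (and pow above) matter for the LLAMA2 work: llama-style models normalize with RMSNorm, which is built from exactly these primitives. A hedged sketch (illustrative, not code from this commit):

import torch

# Illustrative llama-style RMSNorm: tracing it produces aten.pow.Tensor_Scalar,
# mean, rsqrt, and mul nodes, the kinds of ops this commit adds builders for.
def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight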
backends/qualcomm/builders/op_sigmoid.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpSigmoid, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class Sigmoid(NodeVisitor):
    target = "aten.sigmoid.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        sigmoid_inp_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )
        sigmoid_input_tensors = [sigmoid_inp_tensor_wrapper]

        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )
        sigmoid_output_tensors = [output_tensor_wrapper]

        sigmoid_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpSigmoid.op_name,
        )
        sigmoid_op.AddInputTensors(sigmoid_input_tensors)
        sigmoid_op.AddOutputTensors(sigmoid_output_tensors)

        return sigmoid_op
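The Sigmoid builder also backs the DecomposeSilu pass from the summary, since silu factors into ops the backend can already lower. A quick check of the identity (illustrative only):

import torch

# silu(x) == x * sigmoid(x): the rewrite a DecomposeSilu-style pass performs.
x = torch.randn(4)
assert torch.allclose(torch.nn.functional.silu(x), x * torch.sigmoid(x))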
