
Commit 59df9ef

Qualcomm AI Engine Direct - Mimi Enablement Stage 1

1 parent 90f0843 · commit 59df9ef

27 files changed: +947 −235

backends/qualcomm/_passes/__init__.py

Lines changed: 14 additions & 2 deletions
@@ -1,10 +1,18 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
 from .annotate_decomposed import AnnotateDecomposed
 from .annotate_quant_attrs import AnnotateQuantAttrs
 from .constant_i64_to_i32 import ConstantI64toI32
 from .convert_bmm_to_matmul import ConvertBmmToMatmul
+from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d
 from .convert_to_linear import ConvertToLinear
 from .decompose_any import DecomposeAny
 from .decompose_einsum import DecomposeEinsum
+from .decompose_expm1 import DecomposeExpM1
 from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
 from .decompose_silu import DecomposeSilu
 from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
@@ -19,8 +27,9 @@
 from .recompose_rms_norm import RecomposeRmsNorm
 from .reduce_dynamic_range import ReduceDynamicRange
 from .remove_redundancy import RemoveRedundancy
+from .replace_arange_args import ReplaceArangeArgs
 from .replace_index_put_input import ReplaceIndexPutInput
-from .replace_inf_buffer import ReplaceInfBuffer
+from .replace_inf_values import ReplaceInfValues
 from .tensor_i64_to_i32 import TensorI64toI32


@@ -29,10 +38,12 @@
     AnnotateQuantAttrs,
     ConstantI64toI32,
     ConvertBmmToMatmul,
+    ConvertConv1dToConv2d,
     RecomposePReLU,
     ConvertToLinear,
     DecomposeAny,
     DecomposeEinsum,
+    DecomposeExpM1,
     DecomposeLinalgVectorNorm,
     DecomposeSilu,
     ExpandBroadcastTensorShape,
@@ -46,7 +57,8 @@
     RecomposeRmsNorm,
     ReduceDynamicRange,
     RemoveRedundancy,
+    ReplaceArangeArgs,
     ReplaceIndexPutInput,
-    ReplaceInfBuffer,
+    ReplaceInfValues,
     TensorI64toI32,
 ]
backends/qualcomm/_passes/convert_conv1d_to_conv2d.py

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+from .utils import copy_meta
+
+
+class ConvertConv1dToConv2d(ExportPass):
+    """
+    Conv1d is not supported by QNN.
+    Change it to input -> unsqueeze -> conv2d -> squeeze -> output
+    """
+
+    def __init__(self, edge_program: torch.export.ExportedProgram):
+        super(ConvertConv1dToConv2d, self).__init__()
+        self.edge_program = edge_program
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        graph = graph_module.graph
+        conv_op = exir_ops.edge.aten.convolution.default
+        for node in graph.nodes:
+            if node.target == conv_op and node.meta["val"].dim() == 3:
+
+                input_node = node.args[0]
+                with graph_module.graph.inserting_after(input_node):
+                    unsqueeze_op = exir_ops.edge.aten.unsqueeze_copy.default
+                    unsqueeze_node = graph.create_node(
+                        "call_function",
+                        unsqueeze_op,
+                        (
+                            input_node,
+                            2,
+                        ),
+                    )
+                    unsqueeze_node.meta = copy_meta(
+                        input_node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
+                    )
+                with graph_module.graph.inserting_after(unsqueeze_node):
+
+                    filter_node = node.args[1]
+                    filter_node.meta["val"] = (
+                        filter_node.meta["val"].unsqueeze(2).contiguous()
+                    )
+                    filter_tensor = get_parameter(filter_node, self.edge_program)
+                    # Ensure tensor is nn.Parameter type, so program does not fail during edge_program._validate()
+                    filter_tensor = nn.Parameter(filter_tensor.unsqueeze(2))
+                    set_parameter(filter_tensor, filter_node, self.edge_program)
+
+                    bias_node = node.args[2]
+                    stride = [1] + node.args[3]
+                    padding = [0] + node.args[4]
+                    dilation = [1] + node.args[5]
+                    transpose = node.args[6]
+                    output_padding = [0] + node.args[7]
+                    groups = node.args[8]
+
+                    conv2d_node = graph.create_node(
+                        "call_function",
+                        conv_op,
+                        (
+                            unsqueeze_node,
+                            filter_node,
+                            bias_node,
+                            stride,
+                            padding,
+                            dilation,
+                            transpose,
+                            output_padding,
+                            groups,
+                        ),
+                    )
+                    conv2d_node.meta = copy_meta(
+                        node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
+                    )
+
+                    with graph_module.graph.inserting_after(conv2d_node):
+                        squeeze_op = exir_ops.edge.aten.squeeze_copy.dims
+                        squeeze_node = graph.create_node(
+                            "call_function",
+                            squeeze_op,
+                            (
+                                conv2d_node,
+                                [2],
+                            ),
+                        )
+                        squeeze_node.meta = copy_meta(node.meta)
+                        for user in node.users.copy():
+                            user.replace_input_with(node, squeeze_node)
+        graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, True)
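
For context, the rewrite this pass performs is behaviorally a no-op: a 1D convolution equals a 2D convolution over a height-1 image. A minimal sketch (not part of the commit, assuming standard NCL conv1d semantics):

import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 16)  # (N, C, L)
w = torch.randn(8, 4, 3)   # (out_channels, in_channels, kernel)
b = torch.randn(8)

y1 = F.conv1d(x, w, b, stride=1, padding=1)

# Mirror the pass: add a height dim of 1, run conv2d, squeeze it back out.
y2 = F.conv2d(
    x.unsqueeze(2),    # (N, C, 1, L)
    w.unsqueeze(2),    # (O, I, 1, K)
    b,
    stride=(1, 1),     # [1] + stride, as in the pass
    padding=(0, 1),    # [0] + padding
    dilation=(1, 1),   # [1] + dilation
).squeeze(2)

assert torch.allclose(y1, y2, atol=1e-6)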
backends/qualcomm/_passes/decompose_expm1.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.exir.pass_base import ExportPass, PassResult
+
+from .utils import copy_meta
+
+
+class DecomposeExpM1(ExportPass):
+    """
+    Decompose expm1 into exp followed by a subtraction of 1.
+    """
+
+    def __init__(self, quantization_capture=False) -> None:
+        super().__init__()
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        graph = graph_module.graph
+        for node in graph.nodes:
+            if node.target == torch.ops.aten.special_expm1.default:
+                input_node = node.args[0]
+                with graph_module.graph.inserting_after(input_node):
+                    exp_op = torch.ops.aten.exp.default
+                    exp_node = graph.create_node("call_function", exp_op, (input_node,))
+                    exp_node.meta = copy_meta(node.meta)
+                with graph_module.graph.inserting_after(exp_node):
+                    sub_op = torch.ops.aten.sub.Tensor
+                    sub_node = graph.create_node(
+                        "call_function",
+                        sub_op,
+                        (
+                            exp_node,
+                            1,
+                        ),
+                    )
+                    sub_node.meta = copy_meta(node.meta)
+                    for user in node.users.copy():
+                        user.replace_input_with(node, sub_node)
+
+        graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, True)
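
The decomposition is mathematically exact, though expm1 exists precisely because exp(x) - 1 loses precision for x near zero; the rewrite trades that accuracy for QNN operator coverage. A quick sketch (not part of the commit):

import torch

x = torch.randn(8)
# expm1(x) and exp(x) - 1 agree to within float tolerance for typical inputs.
assert torch.allclose(torch.special.expm1(x), torch.exp(x) - 1, atol=1e-6)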

backends/qualcomm/_passes/decompose_silu.py

Lines changed: 4 additions & 9 deletions
@@ -3,22 +3,17 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Dict

 import torch
 from executorch.exir.pass_base import ExportPass, PassResult

+from .utils import copy_meta
+

 class DecomposeSilu(ExportPass):
     def __init__(self):
         super(DecomposeSilu, self).__init__()

-    def _copy_meta(self, meta: Dict):
-        copied = {}
-        for k, v in meta.items():
-            copied[k] = v
-        return copied
-
     def call(self, graph_module: torch.fx.GraphModule):
         graph = graph_module.graph
         for node in graph.nodes:
@@ -34,14 +29,14 @@ def call(self, graph_module: torch.fx.GraphModule):
                     torch.ops.aten.sigmoid.default,
                     (silu_node_input,),
                 )
-                sigmoid_node.meta = self._copy_meta(silu_node.meta)
+                sigmoid_node.meta = copy_meta(silu_node.meta)
                 with graph_module.graph.inserting_after(sigmoid_node):
                     mul_node = graph.create_node(
                         "call_function",
                         torch.ops.aten.mul.Tensor,
                         (silu_node_input, sigmoid_node),
                     )
-                    mul_node.meta = self._copy_meta(silu_node.meta)
+                    mul_node.meta = copy_meta(silu_node.meta)
                     for user in silu_node.users.copy():
                         user.replace_input_with(silu_node, mul_node)

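For reference, the identity DecomposeSilu relies on: silu(x) = x * sigmoid(x). A one-line check (not part of the commit):

import torch
import torch.nn.functional as F

x = torch.randn(8)
assert torch.allclose(F.silu(x), x * torch.sigmoid(x), atol=1e-6)
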
backends/qualcomm/_passes/layout_transform.py

Lines changed: 6 additions & 0 deletions
@@ -49,12 +49,15 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.add.Tensor,
         exir_ops.edge.aten.bitwise_or.Tensor,
         exir_ops.edge.aten.bmm.default,
+        exir_ops.edge.aten.bitwise_and.Tensor,
         exir_ops.edge.aten.cat.default,
         exir_ops.edge.aten.ceil.default,
         exir_ops.edge.aten.clamp.default,
         exir_ops.edge.aten.constant_pad_nd.default,
         exir_ops.edge.aten.div.Tensor,
+        exir_ops.edge.aten.elu.default,
         exir_ops.edge.aten.eq.Tensor,
+        exir_ops.edge.aten.exp.default,
         exir_ops.edge.aten.full.default,
         exir_ops.edge.aten.full_like.default,
         exir_ops.edge.aten.ge.Tensor,
@@ -87,10 +90,13 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.sqrt.default,
         exir_ops.edge.aten.sub.Tensor,
         exir_ops.edge.aten.sum.dim_IntList,
+        exir_ops.edge.aten.stack.default,
         exir_ops.edge.aten.topk.default,
         exir_ops.edge.aten._to_copy.default,
+        exir_ops.edge.aten.unbind.int,
         exir_ops.edge.aten.where.self,
         _operator.getitem,
+        torch.ops.aten.scalar_tensor.default,
     }

     layout_type = {

backends/qualcomm/_passes/lift_constant_scalar_operands.py

Lines changed: 20 additions & 14 deletions
@@ -28,24 +28,27 @@ class TensorConstant:
 class TensorOpInfo:
     target: torch._ops.OpOverload
     use_schema_args: bool
+    use_self_dtype: bool


 SCALAR_OPS = {
-    aten.eq.Scalar: TensorOpInfo(aten.eq.Tensor, False),
-    aten.ge.Scalar: TensorOpInfo(aten.ge.Tensor, False),
-    aten.gt.Scalar: TensorOpInfo(aten.gt.Tensor, False),
-    aten.le.Scalar: TensorOpInfo(aten.le.Tensor, False),
-    aten.lt.Scalar: TensorOpInfo(aten.lt.Tensor, False),
-    aten.ne.Scalar: TensorOpInfo(aten.ne.Tensor, False),
-    aten.add.Scalar: TensorOpInfo(aten.add.Tensor, False),
-    aten.add_.Scalar: TensorOpInfo(aten.add_.Tensor, False),
-    aten.div.Scalar: TensorOpInfo(aten.div.Tensor, False),
-    aten.mul.Scalar: TensorOpInfo(aten.mul.Tensor, False),
-    aten.rsub.Scalar: TensorOpInfo(aten.rsub.Tensor, False),
-    aten.sub.Scalar: TensorOpInfo(aten.sub.Tensor, False),
-    aten.pow.Tensor_Scalar: TensorOpInfo(aten.pow.Tensor_Tensor, False),
+    aten.eq.Scalar: TensorOpInfo(aten.eq.Tensor, False, False),
+    aten.ge.Scalar: TensorOpInfo(aten.ge.Tensor, False, False),
+    aten.gt.Scalar: TensorOpInfo(aten.gt.Tensor, False, False),
+    aten.le.Scalar: TensorOpInfo(aten.le.Tensor, False, False),
+    aten.lt.Scalar: TensorOpInfo(aten.lt.Tensor, False, False),
+    aten.ne.Scalar: TensorOpInfo(aten.ne.Tensor, False, False),
+    aten.add.Scalar: TensorOpInfo(aten.add.Tensor, False, False),
+    aten.add_.Scalar: TensorOpInfo(aten.add_.Tensor, False, False),
+    aten.div.Scalar: TensorOpInfo(aten.div.Tensor, False, False),
+    aten.mul.Scalar: TensorOpInfo(aten.mul.Tensor, False, False),
+    aten.rsub.Scalar: TensorOpInfo(aten.rsub.Tensor, False, False),
+    aten.sub.Scalar: TensorOpInfo(aten.sub.Tensor, False, False),
+    aten.pow.Tensor_Scalar: TensorOpInfo(aten.pow.Tensor_Tensor, False, False),
     # The scalar arg[1] is missing when using the default overload, resulting in a corner case to handle
-    aten.leaky_relu.default: TensorOpInfo(aten.prelu.default, True),
+    aten.leaky_relu.default: TensorOpInfo(aten.prelu.default, True, False),
+    aten.where.ScalarOther: TensorOpInfo(aten.where.self, False, True),
+    aten.where.Scalar: TensorOpInfo(aten.where.self, False, True),
 }


@@ -63,11 +66,14 @@ def __init__(self):
     def _build_tensor_constant(
         self, gm: torch.fx.GraphModule, node: fx.Node, const_val
     ) -> TensorConstant:
+        # In some cases node.args[0] cannot supply the scalar's dtype.
+        # Ex: for the where op, args[0] can be bool, while args[1] and args[2] should match node.meta["val"]'s dtype rather than be bool.
         tensor = torch.tensor(
             [const_val],
             dtype=(
                 node.args[0].meta["val"].dtype
                 if not is_float_tensor(node)
+                and not SCALAR_OPS.get(node.target).use_self_dtype
                 else node.meta["val"].dtype
             ),
             device=node.meta["val"].device,
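
The where overloads illustrate why use_self_dtype exists: args[0] of aten.where is the boolean condition, so its dtype is the wrong template for the lifted scalar constant. A small sketch (not part of the commit; assumes the Python-level call dispatches to the scalar overload):

import torch

cond = torch.tensor([True, False, True])   # args[0] of aten.where: bool
y = torch.tensor([2.0, 3.0, 4.0])
out = torch.where(cond, y, 1.0)            # scalar `other` operand

# Lifting 1.0 with args[0]'s dtype would yield a bool constant; using
# node.meta["val"].dtype gives the float32 the output actually needs.
assert cond.dtype == torch.bool and out.dtype == torch.float32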
backends/qualcomm/_passes/replace_arange_args.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.exir.pass_base import ExportPass, PassResult
+
+from .utils import copy_meta
+
+
+class ReplaceArangeArgs(ExportPass):
+    """
+    During annotation, kwargs for arange will be removed due to restrictions of the quantizer.
+    This leaves arange with no dtype, so FP nodes might be inferred as INT nodes during calibration.
+    This can cause calibration to fail, since QDQ can only be applied to FP nodes, not INT nodes.
+    To hint the dtype, we provide the step size as 1.0 instead of 1, which makes the node an FP node.
+    """
+
+    def __init__(self, quantization_capture=False) -> None:
+        super().__init__()
+        self.quantization_capture = quantization_capture
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        graph = graph_module.graph
+        for node in graph.nodes:
+            if node.target == torch.ops.aten.arange.default:
+                if torch.is_floating_point(node.meta["val"]) and len(node.args) == 1:
+                    with graph_module.graph.inserting_after(node):
+                        step_arange_op = torch.ops.aten.arange.start_step
+                        step_arange_node = graph.create_node(
+                            "call_function",
+                            step_arange_op,
+                            (
+                                0,
+                                node.args[0],
+                                1.0,
+                            ),
+                        )
+                        step_arange_node.meta = copy_meta(node.meta)
+
+                        for user in node.users.copy():
+                            user.replace_input_with(node, step_arange_node)
+
+        graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, True)
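
The dtype hint works because arange infers a floating-point result whenever any of its positional args is a float. A short sketch (not part of the commit) of the before/after behavior:

import torch

implicit = torch.arange(5)           # no dtype hint -> int64
rewritten = torch.arange(0, 5, 1.0)  # float step marks the node as FP

assert implicit.dtype == torch.int64
assert rewritten.dtype == torch.float32  # default float dtype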
