
Commit 88daa31

chunit-quic (Chun-I Tsai) authored and committed
Lift scalar arguments
- Add a pass to lift scalar operands before the quantizer runs
- Add a preprocess function for the fp model
- Delete the annotate-and-quant-scalar pass
- Delete the binary-op-with-scalar pass
- Delete the scalar-handling code blocks in the op builder
- Mark expected-failure cases in CI
1 parent 75d4abc · commit 88daa31

24 files changed: +320 −546 lines changed
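
The core of this change is lifting constant scalar operands: Python-number arguments to aten ops (e.g. `aten.add.Scalar(x, 2)`) are replaced with lifted constant tensors, so only the `.Tensor` overload variants reach the quantizer and op builders. Below is a minimal sketch of the idea; the `SCALAR_TO_TENSOR` mapping and helper names are hypothetical and this is not the actual `LiftConstantScalarOperands` implementation:

```python
import torch
from torch.fx import GraphModule

# Hypothetical mapping from Scalar overloads to their Tensor counterparts.
SCALAR_TO_TENSOR = {
    torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor,
    torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor,
    torch.ops.aten.gt.Scalar: torch.ops.aten.gt.Tensor,
}


def lift_constant_scalars(gm: GraphModule) -> GraphModule:
    """Rewrite scalar operands as lifted constant tensors (sketch only)."""
    for i, node in enumerate(list(gm.graph.nodes)):
        tensor_op = SCALAR_TO_TENSOR.get(node.target)
        if (
            tensor_op is None
            or len(node.args) < 2
            or not isinstance(node.args[1], (int, float))
        ):
            continue
        # Register the scalar as a buffer so it becomes a graph constant.
        attr_name = f"_lifted_scalar_{i}"
        gm.register_buffer(attr_name, torch.tensor(node.args[1]))
        with gm.graph.inserting_before(node):
            const = gm.graph.get_attr(attr_name)
        # Retarget the node to the Tensor overload with the lifted constant.
        node.target = tensor_op
        node.args = (node.args[0], const)
    gm.graph.lint()
    gm.recompile()
    return gm
```

A production pass would also need to match the lifted tensor's dtype to the other operand and handle ops whose scalar sits at a position other than `args[1]`.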

backends/qualcomm/_passes/__init__.py

Lines changed: 4 additions & 6 deletions
```diff
@@ -1,11 +1,8 @@
-from .annotate_and_quant_scalar import AnnotateAndQuantScalar
 from .annotate_decomposed import AnnotateDecomposed
 from .annotate_quant_attrs import AnnotateQuantAttrs
 from .constant_i64_to_i32 import ConstantI64toI32
-from .convert_binary_op_with_scalar import ConvertBinaryOpsWithScalar
 from .convert_bmm_to_matmul import ConvertBmmToMatmul
 from .convert_interpolate_with_upsample2d import ConvertInterpolateWithUpsample2D
-from .convert_prelu import ConvertPReLU
 from .convert_to_linear import ConvertToLinear
 from .decompose_any import DecomposeAny
 from .decompose_einsum import DecomposeEinsum
@@ -17,7 +14,9 @@
 from .insert_io_qdq import InsertIOQDQ
 from .insert_requantize import InsertRequantize
 from .layout_transform import LayoutTransform
+from .lift_constant_scalar_operands import LiftConstantScalarOperands
 from .recompose_pixel_unshuffle import RecomposePixelUnshuffle
+from .recompose_prelu import RecomposePReLU
 from .recompose_rms_norm import RecomposeRmsNorm
 from .reduce_dynamic_range import ReduceDynamicRange
 from .remove_redundancy import RemoveRedundancy
@@ -27,14 +26,12 @@
 
 
 __all__ = [
-    AnnotateAndQuantScalar,
     AnnotateDecomposed,
     AnnotateQuantAttrs,
     ConstantI64toI32,
     ConvertBmmToMatmul,
-    ConvertBinaryOpsWithScalar,
     ConvertInterpolateWithUpsample2D,
-    ConvertPReLU,
+    RecomposePReLU,
     ConvertToLinear,
     DecomposeAny,
     DecomposeEinsum,
@@ -46,6 +43,7 @@
     InsertIOQDQ,
     InsertRequantize,
     LayoutTransform,
+    LiftConstantScalarOperands,
     RecomposePixelUnshuffle,
     RecomposeRmsNorm,
     ReduceDynamicRange,
```

backends/qualcomm/_passes/annotate_and_quant_scalar.py

Lines changed: 0 additions & 137 deletions
This file was deleted.

backends/qualcomm/_passes/convert_binary_op_with_scalar.py

Lines changed: 0 additions & 41 deletions
This file was deleted.

backends/qualcomm/_passes/decompose_linalg_vector_norm.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -32,9 +32,9 @@ class DecomposeLinalgVectorNorm(ExportPass):
     Decompose for math equivalent op.
     """
 
-    def __init__(self, quantization_capture=False) -> None:
+    def __init__(self, aten_dialect_capture=False) -> None:
         super().__init__()
-        self.quantization_capture = quantization_capture
+        self.aten_dialect_capture = aten_dialect_capture
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         graph = graph_module.graph
@@ -44,7 +44,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
             dim = node.args[2] if len(node.args) > 2 else None
             keepdim = node.args[3] if len(node.args) > 3 else False
             model = LinalgVectorNorm(ord, dim, keepdim)
-            if self.quantization_capture:
+            if self.aten_dialect_capture:
                 decomposed_module = torch.export.export(
                     model, (node.args[0].meta["val"],)
                 ).module()
```
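
The rename from `quantization_capture` to `aten_dialect_capture` reflects what the flag actually selects: whether the math-equivalent replacement module is captured through `torch.export` into an aten-dialect graph before being spliced in. A self-contained illustration of that capture path, using a hypothetical stand-in for the pass's internal `LinalgVectorNorm` module:

```python
import torch


class VectorNorm(torch.nn.Module):
    """Math-equivalent decomposition of linalg.vector_norm (ord=2 case)."""

    def forward(self, x):
        return torch.sqrt(torch.sum(torch.pow(x, 2)))


# Capturing through torch.export yields an aten-dialect GraphModule,
# analogous to what the pass splices in when aten_dialect_capture=True.
decomposed = torch.export.export(VectorNorm(), (torch.randn(4),)).module()
print(decomposed(torch.randn(4)))
```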

backends/qualcomm/_passes/decompose_silu.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -30,13 +30,15 @@ def call(self, graph_module: torch.fx.GraphModule):
             silu_node_input = node.args[0]
             with graph_module.graph.inserting_after(silu_node_input):
                 sigmoid_node = graph.create_node(
-                    "call_function", torch.ops.aten.sigmoid, (silu_node_input,)
+                    "call_function",
+                    torch.ops.aten.sigmoid.default,
+                    (silu_node_input,),
                 )
             sigmoid_node.meta = self._copy_meta(silu_node.meta)
             with graph_module.graph.inserting_after(sigmoid_node):
                 mul_node = graph.create_node(
                     "call_function",
-                    torch.ops.aten.mul,
+                    torch.ops.aten.mul.Tensor,
                     (silu_node_input, sigmoid_node),
                 )
             mul_node.meta = self._copy_meta(silu_node.meta)
```
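
The targets here change from overload packets (`torch.ops.aten.sigmoid`, `torch.ops.aten.mul`) to concrete overloads (`.default`, `.Tensor`). A packet resolves which overload to call at runtime, while export-style FX graphs expect a single fixed schema as each `call_function` target. The distinction is easy to inspect:

```python
import torch

packet = torch.ops.aten.mul            # OpOverloadPacket: resolved at call time
overload = torch.ops.aten.mul.Tensor   # OpOverload: one fixed schema

print(type(packet).__name__)    # OpOverloadPacket
print(type(overload).__name__)  # OpOverload
print(overload._schema)         # aten::mul.Tensor(Tensor self, Tensor other) -> Tensor
```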

backends/qualcomm/_passes/layout_transform.py

Lines changed: 0 additions & 5 deletions
```diff
@@ -53,20 +53,15 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.clamp.default,
         exir_ops.edge.aten.constant_pad_nd.default,
         exir_ops.edge.aten.div.Tensor,
-        exir_ops.edge.aten.eq.Scalar,
         exir_ops.edge.aten.eq.Tensor,
         exir_ops.edge.aten.full.default,
         exir_ops.edge.aten.full_like.default,
-        exir_ops.edge.aten.ge.Scalar,
         exir_ops.edge.aten.ge.Tensor,
         exir_ops.edge.aten.gelu.default,
-        exir_ops.edge.aten.gt.Scalar,
         exir_ops.edge.aten.gt.Tensor,
         exir_ops.edge.aten.hardswish.default,
         exir_ops.edge.aten.hardsigmoid.default,
         exir_ops.edge.aten.hardtanh.default,
-        exir_ops.edge.aten.leaky_relu.default,
-        exir_ops.edge.aten.le.Scalar,
         exir_ops.edge.aten.le.Tensor,
         exir_ops.edge.aten.linear.default,
         exir_ops.edge.aten.log.default,
```