Commit 71eb924

cccclai authored and facebook-github-bot committed
Add lift scalar to constant tensor pass
Differential Revision: D69318149
1 parent 77f18b2 commit 71eb924

File tree

2 files changed: +94 -0 lines changed

backends/qualcomm/_passes/lift_constant_scalar_operands.py
backends/qualcomm/quantizer/quantizer.py
backends/qualcomm/_passes/lift_constant_scalar_operands.py (new file)

Lines changed: 90 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.


import torch
from torch import fx
from torch._subclasses import FakeTensor
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix

# Scalar-overload comparison ops mapped to their Tensor-overload counterparts.
COMPARE_SCALAR_OPS = {
    torch.ops.aten.gt.Scalar: torch.ops.aten.gt.Tensor,
    torch.ops.aten.lt.Scalar: torch.ops.aten.lt.Tensor,
    torch.ops.aten.ge.Scalar: torch.ops.aten.ge.Tensor,
    torch.ops.aten.le.Scalar: torch.ops.aten.le.Tensor,
    torch.ops.aten.eq.Scalar: torch.ops.aten.eq.Tensor,
    torch.ops.aten.ne.Scalar: torch.ops.aten.ne.Tensor,
}


def _not_float_tensor(node: fx.Node) -> bool:
    if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
        return True
    return node.meta["val"].dtype != torch.float32


def _not_bool_tensor(node: fx.Node) -> bool:
    if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
        return True
    return node.meta["val"].dtype != torch.bool


def lift_constant_scalar_operands(gm: torch.fx.GraphModule) -> None:
    # If the node is add(input, constant) and constant is a scalar, we can lift
    # the constant, and the annotation for quantization will insert q/dq nodes
    # around the lifted constant.
    for n in gm.graph.nodes:
        if n.op != "call_function" or n.target not in (
            torch.ops.aten.add.Tensor,
            torch.ops.aten.sub.Tensor,
            torch.ops.aten.mul.Tensor,
            torch.ops.aten.div.Tensor,
            torch.ops.aten.rsub.Scalar,
            torch.ops.aten.add_.Scalar,
        ) + tuple(COMPARE_SCALAR_OPS.keys()):
            continue

        # Split the args into the tensor input (an fx.Node) and the scalar constant.
        const_arg = None
        non_const_arg = None
        for arg in n.args:
            if isinstance(arg, torch.fx.Node):
                non_const_arg = arg
            else:
                const_arg = arg
        if non_const_arg is None or const_arg is None:
            continue

        if _not_float_tensor(n) and _not_bool_tensor(n):
            continue

        # Float outputs take the node's own dtype; bool outputs (comparisons)
        # take the dtype of the tensor input instead.
        if not _not_float_tensor(n):
            tensor_constant = torch.tensor(
                [const_arg],
                dtype=n.meta["val"].dtype,
                device=n.meta["val"].device,
            )
        else:
            tensor_constant = torch.tensor(
                [const_arg],
                dtype=n.args[0].meta["val"].dtype,
                device=n.meta["val"].device,
            )
        tensor_constant_name = get_new_attr_name_with_prefix("_tensor_constant_")(gm)
        gm.register_buffer(tensor_constant_name, tensor_constant)

        fake_mode = n.meta["val"].fake_mode
        with gm.graph.inserting_before(n):
            get_attr_node = gm.graph.get_attr(tensor_constant_name)
            get_attr_node.meta["val"] = fake_mode.from_tensor(tensor_constant)

        # rsub(x, c) is rewritten as sub(c, x): the lifted constant comes first.
        if n.target == torch.ops.aten.rsub.Scalar:
            n.args = (get_attr_node, non_const_arg) + n.args[2:]
            n.target = torch.ops.aten.sub.Tensor
        else:
            n.args = (non_const_arg, get_attr_node) + n.args[2:]

        # In-place add_ is replaced by its out-of-place Tensor overload.
        if n.target == torch.ops.aten.add_.Scalar:
            n.args = (non_const_arg, get_attr_node) + n.args[2:]
            n.target = torch.ops.aten.add.Tensor

        # Scalar comparisons are swapped for their Tensor-overload counterparts.
        if n.target in tuple(COMPARE_SCALAR_OPS.keys()):
            n.args = (non_const_arg, get_attr_node) + n.args[2:]
            n.target = COMPARE_SCALAR_OPS[n.target]

    gm.recompile()
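
To illustrate the effect of the pass, here is a minimal sketch that is not part of the commit; the toy module, its inputs, and the `torch.export.export` call are assumptions for illustration (in the quantizer flow the GraphModule comes from the pt2e capture step). Before the pass, scalars ride along inline in the node args; after it, each scalar is a registered buffer read through a `get_attr` node, which the quantization annotator can wrap in q/dq pairs.

```python
# Minimal sketch (assumed toy module and inputs) of what the pass does.
import torch
from executorch.backends.qualcomm._passes.lift_constant_scalar_operands import (
    lift_constant_scalar_operands,
)


class Toy(torch.nn.Module):
    def forward(self, x):
        # Traces to aten.add.Tensor(x, 2.0), then aten.gt.Scalar(..., 0.5)
        return (x + 2.0) > 0.5


gm = torch.export.export(Toy(), (torch.rand(4),)).module()
print(gm.graph)  # scalars 2.0 and 0.5 appear inline in the node args

lift_constant_scalar_operands(gm)
print(gm.graph)
# Each scalar is now a get_attr on a "_tensor_constant_*" buffer, and
# gt.Scalar has been rewritten to gt.Tensor.
```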

backends/qualcomm/quantizer/quantizer.py

Lines changed: 4 additions & 0 deletions
@@ -10,6 +10,9 @@
 import torch
 from executorch.backends.qualcomm._passes.decompose_einsum import DecomposeEinsum
 from executorch.backends.qualcomm._passes.decompose_silu import DecomposeSilu
+from executorch.backends.qualcomm._passes.lift_constant_scalar_operands import (
+    lift_constant_scalar_operands,
+)
 from executorch.backends.qualcomm._passes.recompose_pixel_unshuffle import (
     RecomposePixelUnshuffle,
 )
@@ -224,6 +227,7 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         model = DecomposeSilu()(model).graph_module
         model = DecomposeEinsum()(model).graph_module
         model = ReplaceInfBuffer()(model).graph_module
+        lift_constant_scalar_operands(model)  # Turn scalar into tensor, such that we can annotate it for quantization
         return model

     def validate(self, model: GraphModule) -> None:
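
For context, `transform_for_annotation` is the hook that `torch.ao.quantization.quantize_pt2e.prepare_pt2e` invokes before annotating the graph, so the new pass runs automatically in the pt2e flow. Below is a hedged sketch of that flow; the model, example inputs, and default quantizer settings are illustrative assumptions, not part of the commit.

```python
# Hedged sketch of the pt2e flow that reaches transform_for_annotation.
import torch
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 0.5  # scalar operand that the pass will lift


gm = torch.export.export(Toy(), (torch.rand(4),)).module()

quantizer = QnnQuantizer()
# prepare_pt2e calls quantizer.transform_for_annotation(gm) first, which now
# ends with lift_constant_scalar_operands(gm), so the lifted 0.5 constant is
# visible to the annotator and gets q/dq nodes like any other tensor input.
prepared = prepare_pt2e(gm, quantizer)
prepared(torch.rand(4))  # calibration pass
quantized = convert_pt2e(prepared)
```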
