
Commit b66e850

Convert scalar to tensor before quantizer annotate
Differential Revision: [D55946527](https://our.internmc.facebook.com/intern/diff/D55946527/)
ghstack-source-id: 221922774
Pull Request resolved: #2958
1 parent 744d49f commit b66e850

File tree

1 file changed: +47 -1 lines changed

backends/qualcomm/quantizer/quantizer.py

Lines changed: 47 additions & 1 deletion
@@ -18,6 +18,7 @@
 from torch import Tensor
 from torch._ops import OpOverload
+from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
 from torch.ao.quantization.observer import (
     HistogramObserver,
     MinMaxObserver,
@@ -371,14 +372,59 @@ def set_per_channel_weight_dtype(
     def set_per_channel_quant(self, enable: bool) -> None:
         self.enable_per_channel_conv_quant = enable

+    def _lift_constant_scalar_operands(self, gm: torch.fx.GraphModule) -> None:
+        """
+        For a case like mul(x, 2), convert the scalar operand to a tensor.
+        """
+        for n in gm.graph.nodes:
+            if n.op != "call_function" or n.target not in (
+                torch.ops.aten.add.Tensor,
+                torch.ops.aten.sub.Tensor,
+                torch.ops.aten.mul.Tensor,
+                torch.ops.aten.mul.Scalar,
+                torch.ops.aten.rsub.Scalar,
+            ):
+                continue
+
+            const_arg = None
+            non_const_arg = None
+            for arg in n.args:
+                if isinstance(arg, torch.fx.Node):
+                    non_const_arg = arg
+                else:
+                    const_arg = arg
+
+            if non_const_arg is None or const_arg is None:
+                continue
+
+            # Materialize the scalar as a float32 buffer on the module.
+            tensor_constant = torch.tensor([const_arg], dtype=torch.float32)
+            tensor_constant_name = get_new_attr_name_with_prefix("_tensor_constant_")(
+                gm
+            )
+            gm.register_buffer(tensor_constant_name, tensor_constant)
+
+            # Feed the buffer to the op through a get_attr node so observers
+            # can be attached to it during annotation.
+            fake_mode = n.meta["val"].fake_mode
+            with gm.graph.inserting_before(n):
+                get_attr_node = gm.graph.get_attr(tensor_constant_name)
+                get_attr_node.meta["val"] = fake_mode.from_tensor(tensor_constant)
+
+            if n.target == torch.ops.aten.rsub.Scalar:
+                # rsub(x, c) computes c - x, so place the lifted tensor first
+                # and retarget to the Tensor-Tensor subtraction overload.
+                n.args = (get_attr_node, non_const_arg) + n.args[2:]
+                n.target = torch.ops.aten.sub.Tensor
+            else:
+                n.args = (non_const_arg, get_attr_node) + n.args[2:]
+
+        gm.recompile()
+
     def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         model = RemoveClone()(model).graph_module
         model = ReduceDynamicRange()(model).graph_module
         model = ConvertHardsigmoid(quantization_capture=True)(model).graph_module
         model = DecomposeScaledDotProductAttention()(model).graph_module
         model = DecomposeSilu()(model).graph_module
         model = ReplaceInfBuffer()(model).graph_module
-
+        self._lift_constant_scalar_operands(model)
         return model

     def validate(self, model: GraphModule) -> None:
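For illustration, here is a minimal sketch of what the new pass does end to end. It assumes the QnnQuantizer class defined in this file and the pre-autograd capture API available around the time of this commit (torch._export.capture_pre_autograd_graph); the Toy module and its inputs are made up for the example and are not part of the commit.

```python
import torch
from torch._export import capture_pre_autograd_graph

from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer


class Toy(torch.nn.Module):
    def forward(self, x):
        # Exports as aten.mul.Tensor(x, 2): the 2 is a bare Python scalar,
        # so the annotator has no node on which to place an observer.
        return x * 2


example_inputs = (torch.randn(1, 4),)
# Capture API of this era; an assumption, not part of the commit.
gm = capture_pre_autograd_graph(Toy(), example_inputs)

quantizer = QnnQuantizer()
# transform_for_annotation now ends with _lift_constant_scalar_operands, so
# the scalar 2 is lifted into a registered buffer read through get_attr:
#   %_tensor_constant_0 = get_attr[target=_tensor_constant_0]
#   %mul = call_function[target=aten.mul.Tensor](args = (%x, %_tensor_constant_0))
gm = quantizer.transform_for_annotation(gm)
print(gm.graph)
```

The rsub.Scalar branch is the one asymmetric case: rsub(x, c) computes c - x, so the lifted tensor takes the first argument slot and the node is retargeted to sub.Tensor, leaving only Tensor-Tensor overloads for the annotator to handle.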
