# Copyright (c) Meta Platforms, Inc. and affiliates.

import torch
from torch import fx
from torch._subclasses import FakeTensor
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix

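# Map each comparison op's .Scalar overload to its .Tensor overload. Once the
# scalar operand has been lifted into a tensor constant, the node is
# retargeted to the tensor variant so that both inputs are tensors.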
COMPARE_SCALAR_OPS = {
    torch.ops.aten.gt.Scalar: torch.ops.aten.gt.Tensor,
    torch.ops.aten.lt.Scalar: torch.ops.aten.lt.Tensor,
    torch.ops.aten.ge.Scalar: torch.ops.aten.ge.Tensor,
    torch.ops.aten.le.Scalar: torch.ops.aten.le.Tensor,
    torch.ops.aten.eq.Scalar: torch.ops.aten.eq.Tensor,
    torch.ops.aten.ne.Scalar: torch.ops.aten.ne.Tensor,
}


def _not_float_tensor(node: fx.Node) -> bool:
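    """Return True unless the node carries a float32 FakeTensor in meta["val"]."""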
    if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
        return True
    return node.meta["val"].dtype != torch.float32


def _not_bool_tensor(node: fx.Node) -> bool:
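    """Return True unless the node carries a bool FakeTensor in meta["val"]."""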
    if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
        return True
    return node.meta["val"].dtype != torch.bool


def lift_constant_scalar_operands(gm: torch.fx.GraphModule) -> None:
    # For nodes such as add(input, constant) where the constant is a Python
    # scalar, lift the constant into a registered tensor buffer so that the
    # quantization annotation pass can insert q/dq nodes around it.
    for n in gm.graph.nodes:
        if n.op != "call_function" or n.target not in (
            torch.ops.aten.add.Tensor,
            torch.ops.aten.sub.Tensor,
            torch.ops.aten.mul.Tensor,
            torch.ops.aten.div.Tensor,
            torch.ops.aten.rsub.Scalar,
            torch.ops.aten.add_.Scalar,
        ) + tuple(COMPARE_SCALAR_OPS.keys()):
            continue
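        # Split the binary op's arguments into the tensor input and the scalar
        # constant; skip nodes that do not have exactly one of each.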
        const_arg = None
        non_const_arg = None
        for arg in n.args:
            if isinstance(arg, torch.fx.Node):
                non_const_arg = arg
            else:
                const_arg = arg
        if non_const_arg is None or const_arg is None:
            continue

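        # Only float32 outputs (the arithmetic ops) and bool outputs (the
        # comparisons) are handled; everything else is left untouched.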
        if _not_float_tensor(n) and _not_bool_tensor(n):
            continue

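        # For float outputs the constant adopts the node's own dtype; for bool
        # outputs (comparisons) it adopts the dtype of the tensor input, since
        # the node's own output dtype is bool.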
        if not _not_float_tensor(n):
            tensor_constant = torch.tensor(
                [const_arg],
                dtype=n.meta["val"].dtype,
                device=n.meta["val"].device,
            )
        else:
            tensor_constant = torch.tensor(
                [const_arg],
                dtype=n.args[0].meta["val"].dtype,
                device=n.meta["val"].device,
            )
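        # Register the lifted constant as a buffer under a fresh attribute name
        # and create a get_attr node for it, attaching a matching FakeTensor to
        # its "val" meta so later passes can reason about shape and dtype.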
        tensor_constant_name = get_new_attr_name_with_prefix("_tensor_constant_")(gm)
        gm.register_buffer(tensor_constant_name, tensor_constant)

        fake_mode = n.meta["val"].fake_mode
        with gm.graph.inserting_before(n):
            get_attr_node = gm.graph.get_attr(tensor_constant_name)
            get_attr_node.meta["val"] = fake_mode.from_tensor(tensor_constant)

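        # rsub.Scalar computes const - input, so the operands are swapped and
        # the node is retargeted to sub.Tensor; all other ops keep their
        # (input, const) argument order.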
        if n.target == torch.ops.aten.rsub.Scalar:
            n.args = (get_attr_node, non_const_arg) + n.args[2:]
            n.target = torch.ops.aten.sub.Tensor
        else:
            n.args = (non_const_arg, get_attr_node) + n.args[2:]

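        # The in-place add_.Scalar is retargeted to its functional tensor
        # variant, add.Tensor.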
        if n.target == torch.ops.aten.add_.Scalar:
            n.args = (non_const_arg, get_attr_node) + n.args[2:]
            n.target = torch.ops.aten.add.Tensor

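        # Comparison ops move from the .Scalar overload to the matching
        # .Tensor overload now that the constant is a tensor.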
        if n.target in COMPARE_SCALAR_OPS:
            n.args = (non_const_arg, get_attr_node) + n.args[2:]
            n.target = COMPARE_SCALAR_OPS[n.target]

    gm.recompile()
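

# A minimal usage sketch (illustrative, not part of the pass itself). The toy
# module `_Demo` and the input shape below are assumptions for demonstration:
# exporting a module that adds a Python scalar yields an aten.add.Tensor node
# with a constant scalar operand, which the pass then lifts into a buffer.
if __name__ == "__main__":
    import torch.export  # available in torch >= 2.1

    class _Demo(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # `0.5` stays a Python scalar argument of aten.add.Tensor in the
            # exported graph until the pass lifts it.
            return x + 0.5

    demo_gm = torch.export.export(_Demo(), (torch.randn(2, 3),)).module()
    lift_constant_scalar_operands(demo_gm)
    # The graph should now feed a lifted get_attr constant (for example
    # `_tensor_constant_0`) into aten.add.Tensor instead of the raw scalar.
    print(demo_gm.graph)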