 from torch import Tensor
 from torch._ops import OpOverload
+from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
 from torch.ao.quantization.observer import (
     HistogramObserver,
     MinMaxObserver,
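For reference, the FX helper imported above is a name generator: `get_new_attr_name_with_prefix(prefix)` returns a function that, given a module, yields the first attribute name with that prefix not already taken on the module. A quick sketch of the assumed behavior (the `ReLU` module is just a stand-in):

    import torch
    from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix

    gm = torch.fx.symbolic_trace(torch.nn.ReLU())
    get_name = get_new_attr_name_with_prefix("_tensor_constant_")
    # First unused name with the prefix, e.g. "_tensor_constant_0";
    # calling again after registering it would yield "_tensor_constant_1".
    print(get_name(gm))
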
@@ -371,14 +372,58 @@ def set_per_channel_weight_dtype(
     def set_per_channel_quant(self, enable: bool) -> None:
         self.enable_per_channel_conv_quant = enable

+    def _lift_constant_scalar_operands(self, gm: torch.fx.GraphModule) -> None:
+        """
+        For cases like mul(x, 2), convert the scalar operand to a tensor.
+        """
+        for n in gm.graph.nodes:
+            if n.op != "call_function" or n.target not in (
+                torch.ops.aten.add.Tensor,
+                torch.ops.aten.sub.Tensor,
+                torch.ops.aten.mul.Tensor,
+                torch.ops.aten.mul.Scalar,
+                torch.ops.aten.rsub.Scalar,
+            ):
+                continue
+
+            const_arg = None
+            non_const_arg = None
+            for arg in n.args:
+                if isinstance(arg, torch.fx.Node):
+                    non_const_arg = arg
+                else:
+                    const_arg = arg
+
+            if non_const_arg is None or const_arg is None:
+                continue
+
+            tensor_constant = torch.tensor([const_arg], dtype=torch.float32)
+            tensor_constant_name = get_new_attr_name_with_prefix("_tensor_constant_")(
+                gm
+            )
+            gm.register_buffer(tensor_constant_name, tensor_constant)
+
+            fake_mode = n.meta["val"].fake_mode
+            with gm.graph.inserting_before(n):
+                get_attr_node = gm.graph.get_attr(tensor_constant_name)
+                get_attr_node.meta["val"] = fake_mode.from_tensor(tensor_constant)
+
+            if n.target == torch.ops.aten.rsub.Scalar:
+                n.args = (get_attr_node, non_const_arg) + n.args[2:]
+                n.target = torch.ops.aten.sub.Tensor
+            else:
+                n.args = (non_const_arg, get_attr_node) + n.args[2:]
+
+        gm.recompile()
+
     def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         model = RemoveClone()(model).graph_module
         model = ReduceDynamicRange()(model).graph_module
         model = ConvertHardsigmoid(quantization_capture=True)(model).graph_module
         model = DecomposeScaledDotProductAttention()(model).graph_module
         model = DecomposeSilu()(model).graph_module
         model = ReplaceInfBuffer()(model).graph_module
-
+        self._lift_constant_scalar_operands(model)
         return model

     def validate(self, model: GraphModule) -> None:
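Scalar operands like the `2` in `mul(x, 2)` arrive in the exported graph as plain Python numbers, so annotation has no node to attach an observer to; lifting each scalar into a registered buffer gives it a real `get_attr` node that can be observed and quantized. A minimal before/after sketch of what the new pass does (the `Mul2` module and the `torch.export` usage below are illustrative assumptions, not part of this patch):

    import torch

    class Mul2(torch.nn.Module):
        def forward(self, x):
            # Exports as aten.mul.Tensor(x, 2.0) with a Python float argument.
            return x * 2.0

    gm = torch.export.export(Mul2(), (torch.randn(4),)).module()
    gm.graph.print_tabular()
    # Before the pass, the mul node's args are (x, 2.0).  After
    # _lift_constant_scalar_operands(gm) the graph reads roughly:
    #     _tensor_constant_0 = get_attr[_tensor_constant_0]
    #     mul = aten.mul.Tensor(x, _tensor_constant_0)
    # with the buffer gm._tensor_constant_0 holding tensor([2.0]).
    # rsub is the one asymmetric case: rsub(x, c) computes c - x, so the
    # pass places the lifted constant first and retargets the node to
    # aten.sub.Tensor.

Rewriting `rsub.Scalar` into `sub.Tensor` with swapped operands keeps every lifted node on a plain two-tensor overload, so downstream annotation only has to handle the tensor variants.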