Skip to content

Commit c364a2c

Browse files
dulinriley authored and facebook-github-bot committed
Add small repro test for unsigned -> signed ET loss error (#8506)
Summary: There was a difference in behavior from `quantized_decomposed.quantize_per_tensor` and `cadence.quantize_per_tensor`, specifically how rounding half values worked. The former rounds towards even (based on `torch.round` which does that). The latter rounds away from zero. Make sure the python implementation matches the Executorch implementation in this regard. Reviewed By: sabarishsnk Differential Revision: D69668881
1 parent cb2b174 commit c364a2c

File tree

5 files changed

+20
-5
lines changed

5 files changed

+20
-5
lines changed

backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ python_library(
180180
typing = True,
181181
deps = [
182182
"//caffe2:torch",
183+
":ops_registrations",
183184
":compiler_utils",
184185
"//executorch/backends/cadence/aot:pass_utils",
185186
"//executorch/backends/cadence/aot:utils",

backends/cadence/aot/fuse_ops.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
import torch
2020
import torch.fx
21+
# Import these for the cadence function signatures.
22+
import executorch.backends.cadence.aot.ops_registrations # noqa: F401
2123
from executorch.backends.cadence.aot.compiler_utils import (
2224
broadcastable,
2325
get_cascaded_ops,
@@ -849,7 +851,7 @@ def attempt_fusion(
849851
if isinstance(arg, torch.fx.Node)
850852
and isinstance(arg.target, EdgeOpOverload)
851853
and get_edge_overload_packet(arg.target)
852-
== exir_ops.edge.quantized_decomposed.dequantize_per_tensor
854+
in (exir_ops.edge.quantized_decomposed.dequantize_per_tensor, exir_ops.edge.cadence.dequantize_per_tensor)
853855
]
854856
multiplier_nodes = [
855857
arg

backends/cadence/aot/remove_ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,8 @@ class Subgraph:
569569
exir_ops.edge.aten.hardtanh.default,
570570
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
571571
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
572+
exir_ops.edge.cadence.quantize_per_tensor.default,
573+
exir_ops.edge.cadence.dequantize_per_tensor.default,
572574
}
573575

574576
# must be initialized in the constructor

backends/cadence/aot/reorder_ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ def get_descendent_quant_ops(self, node: torch.fx.Node) -> List[torch.fx.Node]:
118118
if user_target in {
119119
torch.ops.quantized_decomposed.quantize_per_tensor,
120120
exir_ops.edge.quantized_decomposed.quantize_per_tensor,
121+
torch.ops.cadence.quantize_per_tensor,
122+
exir_ops.edge.cadence.quantize_per_tensor,
121123
}:
122124
descendent_quant_ops.append(user)
123125
# If the successor is a trivially quantizable op, consider its users
@@ -300,6 +302,8 @@ def advance_quantize_op(self, graph_module: torch.fx.GraphModule):
300302
if get_overload_packet(node.target) not in (
301303
exir_ops.edge.quantized_decomposed.quantize_per_tensor,
302304
torch.ops.quantized_decomposed.quantize_per_tensor,
305+
exir_ops.edge.cadence.quantize_per_tensor,
306+
torch.ops.cadence.quantize_per_tensor,
303307
):
304308
continue
305309

@@ -413,6 +417,7 @@ def postponing_feasible(self, dequant_node: torch.fx.Node):
413417
in {
414418
exir_ops.edge.quantized_decomposed.quantize_per_tensor,
415419
exir_ops.edge.quantized_decomposed.quantize_per_channel,
420+
exir_ops.edge.cadence.quantize_per_tensor,
416421
}
417422
for x in users
418423
)
@@ -422,6 +427,7 @@ def postpone_dequantize_op(self, graph_module: torch.fx.GraphModule) -> bool:
422427
packet_to_overload_map = {
423428
exir_ops.edge.quantized_decomposed.dequantize_per_tensor: "default",
424429
exir_ops.edge.quantized_decomposed.dequantize_per_channel: "default",
430+
exir_ops.edge.cadence.dequantize_per_tensor: "default",
425431
}
426432
graph = graph_module.graph
427433
modified = False
@@ -500,6 +506,7 @@ class SinkOpsCloserToUsePass(ExportPass):
500506
exir_ops.edge.aten.dequantize,
501507
exir_ops.edge.quantized_decomposed.dequantize_per_tensor,
502508
exir_ops.edge.quantized_decomposed.dequantize_per_channel,
509+
exir_ops.edge.cadence.dequantize_per_tensor,
503510
}
504511

505512
def sink_ops_closer_to_use(self, graph_module: torch.fx.GraphModule):
@@ -558,6 +565,7 @@ class HoistOpsCloserToDefPass(ExportPass):
558565

559566
hoistable_ops: Set[EdgeOpOverload] = {
560567
exir_ops.edge.quantized_decomposed.quantize_per_tensor,
568+
exir_ops.edge.cadence.quantize_per_tensor,
561569
exir_ops.edge.aten.slice_copy,
562570
exir_ops.edge.aten.select_copy,
563571
}

backends/cadence/aot/replace_ops.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,12 @@ def call_operator(
162162
kwargs: Dict[str, Argument],
163163
meta: NodeMetadata,
164164
) -> ProxyValue:
165-
if op not in {exir_ops.edge.quantized_decomposed.quantize_per_tensor.default}:
165+
ns = exir_ops.edge if isinstance(op, EdgeOpOverload) else torch.ops
166+
if op != ns.quantized_decomposed.quantize_per_tensor.default:
166167
return super().call_operator(op, args, kwargs, meta)
167168

168169
return super().call_operator(
169-
exir_ops.edge.cadence.quantize_per_tensor.default,
170+
ns.cadence.quantize_per_tensor.default,
170171
args,
171172
kwargs,
172173
meta,
@@ -188,11 +189,12 @@ def call_operator(
188189
kwargs: Dict[str, Argument],
189190
meta: NodeMetadata,
190191
) -> ProxyValue:
191-
if op not in {exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default}:
192+
ns = exir_ops.edge if isinstance(op, EdgeOpOverload) else torch.ops
193+
if op != ns.quantized_decomposed.dequantize_per_tensor.default:
192194
return super().call_operator(op, args, kwargs, meta)
193195

194196
return super().call_operator(
195-
exir_ops.edge.cadence.dequantize_per_tensor.default,
197+
ns.cadence.dequantize_per_tensor.default,
196198
args,
197199
kwargs,
198200
meta,

0 commit comments

Comments (0)