Skip to content

Commit d0cb851

Browse files
angelayifacebook-github-bot
authored andcommitted
Add SymIntToTensorPass in to_edge
Summary: JacobSzwejbka ran into some issues when lowering to executorch where we have the following graph: ``` %symfloat_1 = executorch_prim::mul(%symfloat_a, %symfloat_b) %mul_out = aten::mul.out(%tensor, %symfloat_1, %mul_out, %mul_out) ``` Which errors because mul.out is being passed a symfloat, but mul.out only accepts 2 tensors. So if we add this pass, the symfloats will be converted to tensors. Apparently I wrote this pass back in April but never put it somewhere where it gets run by default. Reviewed By: JacobSzwejbka Differential Revision: D48415403 fbshipit-source-id: 1ce9db145feb8b54a666e38e4f7aefdb1e314a48
1 parent e05bf2a commit d0cb851

File tree

5 files changed

+20
-6
lines changed

5 files changed

+20
-6
lines changed

exir/passes/TARGETS

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ python_library(
2525
":scalar_to_tensor_pass",
2626
":spec_prop_pass",
2727
":sym_shape_eval_pass",
28+
":sym_to_tensor_pass",
2829
"//caffe2:torch",
2930
"//executorch/exir:common",
3031
"//executorch/exir:control_flow",
@@ -87,9 +88,9 @@ python_library(
8788
)
8889

8990
python_library(
90-
name = "symint_to_tensor_pass",
91+
name = "sym_to_tensor_pass",
9192
srcs = [
92-
"symint_to_tensor_pass.py",
93+
"sym_to_tensor_pass.py",
9394
],
9495
deps = [
9596
"//caffe2:torch",

exir/passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
5353
from executorch.exir.passes.spec_prop_pass import SpecPropPass
5454
from executorch.exir.passes.sym_shape_eval_pass import SymShapeEvalPass
55+
from executorch.exir.passes.sym_to_tensor_pass import SymToTensorPass
5556
from torch import fx
5657
from torch._subclasses import FakeTensor
5758
from torch.fx.passes.infra.pass_base import PassBase, PassResult
@@ -475,6 +476,7 @@ def dead_code_elimination_pass(graph_module: torch.fx.GraphModule) -> PassResult
475476
NormalizeTransposePass(),
476477
ReplaceBrokenOpsWithFunctionalOpsPass(),
477478
ScalarToTensorPass(),
479+
SymToTensorPass(),
478480
RemoveMixedTypeOperators(),
479481
RemoveNoopPass(),
480482
dead_code_elimination_pass,

exir/passes/symint_to_tensor_pass.py renamed to exir/passes/sym_to_tensor_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from torch.utils._pytree import PyTree
1515

1616

17-
class SymIntToTensorPass(ExportPass):
17+
class SymToTensorPass(ExportPass):
1818
"""
1919
The dispatcher implicitly converts SymInt/SymFloats to tensors, but
2020
sometimes this doesn't comply with the operator's schema which Executorch

exir/tests/TARGETS

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ python_unittest(
233233
"//executorch/exir/passes:replace_edge_with_backend_pass",
234234
"//executorch/exir/passes:scalar_to_tensor_pass",
235235
"//executorch/exir/passes:spec_prop_pass",
236-
"//executorch/exir/passes:symint_to_tensor_pass",
236+
"//executorch/exir/passes:sym_to_tensor_pass",
237237
"//executorch/extension/pybindings:portable", # @manual
238238
],
239239
)

exir/tests/test_passes.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from executorch.exir.passes.replace_edge_with_backend_pass import EdgeToBackendOpsPass
4141
from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
4242
from executorch.exir.passes.spec_prop_pass import SpecPropPass
43-
from executorch.exir.passes.symint_to_tensor_pass import SymIntToTensorPass
43+
from executorch.exir.passes.sym_to_tensor_pass import SymToTensorPass
4444
from executorch.exir.tensor import TensorSpec
4545
from executorch.exir.tests.common import register_additional_test_aten_ops
4646
from executorch.exir.tests.control_flow_models import FTCondDeadCode, FTMapBasic
@@ -834,11 +834,22 @@ def f(x: torch.Tensor) -> torch.Tensor:
834834
prog = exir.capture(
835835
f, inputs, exir.CaptureConfig(enable_dynamic_shape=True)
836836
).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
837-
prog = prog.transform(SymIntToTensorPass())
837+
prog = prog.transform(SymToTensorPass())
838838

839+
FileCheck().check(
840+
"executorch_exir_dialects_edge__ops_aten_scalar_tensor_default"
841+
).run(prog.exported_program.graph_module.code)
839842
self.assertTrue(torch.allclose(f(torch.ones(6)), prog(torch.ones(6))))
840843
self.assertTrue(torch.allclose(f(torch.zeros(6)), prog(torch.zeros(6))))
841844

845+
# This pass should also be part of to_edge, so checking again after to_edge
846+
prog = exir.capture(
847+
f, inputs, exir.CaptureConfig(enable_dynamic_shape=True)
848+
).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
849+
FileCheck().check(
850+
"executorch_exir_dialects_edge__ops_aten_scalar_tensor_default"
851+
).run(prog.exported_program.graph_module.code)
852+
842853
def test_replace_edge_with_backend_pass(self) -> None:
843854
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
844855
z = x + y

0 commit comments

Comments
 (0)