Skip to content

Commit a7e9dcf

Browse files
ydwu4facebook-github-bot
authored andcommitted
Rename SymShapeEval pass to HintBasedSymShapeEval pass and add warning for it. (#377)
Summary: Pull Request resolved: #377 As titled. The original SymShapeEvalPass doesn't respect user constraints and is based on hints of the symbolics. This diff changes its name to HintBasedSymShapeEval pass to make it clearer. Reviewed By: angelayi Differential Revision: D49335965 fbshipit-source-id: 8dbacc666276c46653dbaa2b8d23c81bee70c507
1 parent c52000a commit a7e9dcf

File tree

5 files changed

+21
-10
lines changed

5 files changed

+21
-10
lines changed

exir/passes/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
from executorch.exir.passes.replace_sym_size_op_pass import ReplaceSymSizeOpPass
5252
from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
5353
from executorch.exir.passes.spec_prop_pass import SpecPropPass
54-
from executorch.exir.passes.sym_shape_eval_pass import SymShapeEvalPass
54+
from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
5555
from executorch.exir.passes.sym_to_tensor_pass import SymToTensorPass
5656
from torch import fx
5757
from torch._subclasses import FakeTensor
@@ -65,7 +65,7 @@
6565
"OpReplacePass",
6666
"EdgeToBackendOpsPass",
6767
"MemoryFormatOpsPass",
68-
"SymShapeEvalPass",
68+
"HintBasedSymShapeEvalPass",
6969
]
7070

7171
Argument = Optional[
@@ -510,5 +510,5 @@ def propagate_dynamic_shape(
510510
"""
511511
return [
512512
SpecPropPass(),
513-
SymShapeEvalPass(),
513+
HintBasedSymShapeEvalPass(),
514514
]

exir/passes/sym_shape_eval_pass.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,23 @@ def nonzero(args, kwargs) -> List[Optional[int]]:
3232
return [eval_expr(args[0].shape[0]), len(args[0].shape)]
3333

3434

35-
class SymShapeEvalPass(PassBase):
35+
class HintBasedSymShapeEvalPass(PassBase):
3636
"""
3737
If we enable dynamic shape tracing, a tensor's shape may become a symbolic
3838
formula. We should convert those symbolic formula to concrete value for
3939
static/upperbound tensors so we can properly do memory planning for them.
4040
41+
HintBasedSymShapeEvalPass evalutes the symbolic expression of shapes based
42+
on its hint, which is a concrete integer that backs the sym expression. The original
43+
hint comes from the sizes of the inputs that user uses for tracing and hints of
44+
symbolic expressions are propagated via meta tensor computation.
45+
For example, when export f(x), we use x = torch.ones(3, 4) as an exmaple input to f and
46+
suppose we constrain both dimensions of x as dynamic. We'll have two symbols s0, s1 created
47+
and they are backed up with hints 3 and 4 respectively. If there is a y = x[0] operation in f,
48+
the shape of y is inferred to be s1, which is backed up with hint 4.
49+
50+
Warning: if you're using torch.export with constrain API, this method doesn't respect the input constraints.
51+
4152
Not inherit from ExportPass since we simply need a way to iterate thru
4253
every node's output. PassBase is easier for that purpose.
4354
"""

exir/program/_program.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from executorch.exir.passes import (
2020
aten_to_edge_passes,
2121
EdgeToBackendOpsPass,
22+
HintBasedSymShapeEvalPass,
2223
OpReplacePass,
23-
SymShapeEvalPass,
2424
)
2525
from executorch.exir.passes.remove_assert_async_pass import RemoveAssertAsyncPass
2626
from executorch.exir.passes.spec_prop_pass import SpecPropPass
@@ -309,7 +309,7 @@ def edge_to_executorch_passes(config: ExecutorchBackendConfig) -> List[PassType]
309309
SpecPropPass(),
310310
EdgeToBackendOpsPass(),
311311
RemoveAssertAsyncPass(),
312-
SymShapeEvalPass(),
312+
HintBasedSymShapeEvalPass(),
313313
config.to_out_var_pass,
314314
config.memory_planning_pass,
315315
]

exir/tests/test_dynamic_shape_propagation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from unittest import TestCase
88

99
from executorch import exir
10-
from executorch.exir.passes import DebugPass, SpecPropPass, SymShapeEvalPass
10+
from executorch.exir.passes import DebugPass, HintBasedSymShapeEvalPass, SpecPropPass
1111
from executorch.exir.tests.models import Repeat
1212

1313

@@ -23,7 +23,7 @@ def test_repeat(self):
2323
exir.CaptureConfig(enable_dynamic_shape=True),
2424
).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
2525

26-
new_prog = prog.transform(SpecPropPass(), SymShapeEvalPass())
26+
new_prog = prog.transform(SpecPropPass(), HintBasedSymShapeEvalPass())
2727

2828
gm = new_prog.exported_program.graph_module
2929

exir/tests/test_passes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@
2424
from executorch.exir.passes import (
2525
dead_code_elimination_pass,
2626
DebugPass,
27+
HintBasedSymShapeEvalPass,
2728
MemoryPlanningPass,
2829
propagate_dynamic_shape,
2930
RemoveNoopPass,
3031
ReplaceSymSizeOpPass,
31-
SymShapeEvalPass,
3232
ToOutVarPass,
3333
)
3434
from executorch.exir.passes.const_prop_pass import ConstPropPass
@@ -586,7 +586,7 @@ def test_alloc_node_spec(self) -> None:
586586
).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
587587
passes = [
588588
SpecPropPass(),
589-
SymShapeEvalPass(),
589+
HintBasedSymShapeEvalPass(),
590590
ToOutVarPass(),
591591
MemoryPlanningPass("greedy"),
592592
]

0 commit comments

Comments
 (0)