Skip to content

Commit 24f2b8e

Browse files
ydwu4facebook-github-bot
authored andcommitted
Rename SymShapeEval pass to HintBasedSymShapeEval pass and add warning for it. (#377)
Summary: As titled. The original SymShapeEvalPass doesn't respect user constraints and is based on hints of the symbolics. This diff changes its name to HintBasedSymShapeEval pass to make it clearer. Differential Revision: D49335965
1 parent bfa89be commit 24f2b8e

File tree

5 files changed

+21
-10
lines changed

5 files changed

+21
-10
lines changed

exir/passes/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
from executorch.exir.passes.replace_sym_size_op_pass import ReplaceSymSizeOpPass
5252
from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
5353
from executorch.exir.passes.spec_prop_pass import SpecPropPass
54-
from executorch.exir.passes.sym_shape_eval_pass import SymShapeEvalPass
54+
from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
5555
from executorch.exir.passes.sym_to_tensor_pass import SymToTensorPass
5656
from torch import fx
5757
from torch._subclasses import FakeTensor
@@ -65,7 +65,7 @@
6565
"OpReplacePass",
6666
"EdgeToBackendOpsPass",
6767
"MemoryFormatOpsPass",
68-
"SymShapeEvalPass",
68+
"HintBasedSymShapeEvalPass",
6969
]
7070

7171
Argument = Optional[
@@ -510,5 +510,5 @@ def propagate_dynamic_shape(
510510
"""
511511
return [
512512
SpecPropPass(),
513-
SymShapeEvalPass(),
513+
HintBasedSymShapeEvalPass(),
514514
]

exir/passes/sym_shape_eval_pass.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,23 @@ def nonzero(args, kwargs) -> List[Optional[int]]:
3232
return [eval_expr(args[0].shape[0]), len(args[0].shape)]
3333

3434

35-
class SymShapeEvalPass(PassBase):
35+
class HintBasedSymShapeEvalPass(PassBase):
3636
"""
3737
If we enable dynamic shape tracing, a tensor's shape may become a symbolic
3838
formula. We should convert those symbolic formula to concrete value for
3939
static/upperbound tensors so we can properly do memory planning for them.
4040
41+
HintBasedSymShapeEvalPass evalutes the symbolic expression of shapes based
42+
on its hint, which is a concrete integer that backs the sym expression. The original
43+
hint comes from the sizes of the inputs that user uses for tracing and hints of
44+
symbolic expressions are propagated via meta tensor computation.
45+
For example, when export f(x), we use x = torch.ones(3, 4) as an exmaple input to f and
46+
suppose we constrain both dimensions of x as dynamic. We'll have two symbols s0, s1 created
47+
and they are backed up with hints 3 and 4 respectively. If there is a y = x[0] operation in f,
48+
the shape of y is inferred to be s1, which is backed up with hint 4.
49+
50+
Warning: if you're using torch.export with constrain API, this method doesn't respect the input constraints.
51+
4152
Not inherit from ExportPass since we simply need a way to iterate thru
4253
every node's output. PassBase is easier for that purpose.
4354
"""

exir/program/_program.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from executorch.exir.passes import (
2020
aten_to_edge_passes,
2121
EdgeToBackendOpsPass,
22+
HintBasedSymShapeEvalPass,
2223
OpReplacePass,
23-
SymShapeEvalPass,
2424
)
2525
from executorch.exir.passes.remove_assert_async_pass import RemoveAssertAsyncPass
2626
from executorch.exir.passes.spec_prop_pass import SpecPropPass
@@ -309,7 +309,7 @@ def edge_to_executorch_passes(config: ExecutorchBackendConfig) -> List[PassType]
309309
SpecPropPass(),
310310
EdgeToBackendOpsPass(),
311311
RemoveAssertAsyncPass(),
312-
SymShapeEvalPass(),
312+
HintBasedSymShapeEvalPass(),
313313
config.to_out_var_pass,
314314
config.memory_planning_pass,
315315
]

exir/tests/test_dynamic_shape_propagation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from unittest import TestCase
88

99
from executorch import exir
10-
from executorch.exir.passes import DebugPass, SpecPropPass, SymShapeEvalPass
10+
from executorch.exir.passes import DebugPass, HintBasedSymShapeEvalPass, SpecPropPass
1111
from executorch.exir.tests.models import Repeat
1212

1313

@@ -23,7 +23,7 @@ def test_repeat(self):
2323
exir.CaptureConfig(enable_dynamic_shape=True),
2424
).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
2525

26-
new_prog = prog.transform(SpecPropPass(), SymShapeEvalPass())
26+
new_prog = prog.transform(SpecPropPass(), HintBasedSymShapeEvalPass())
2727

2828
gm = new_prog.exported_program.graph_module
2929

exir/tests/test_passes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@
2424
from executorch.exir.passes import (
2525
dead_code_elimination_pass,
2626
DebugPass,
27+
HintBasedSymShapeEvalPass,
2728
MemoryPlanningPass,
2829
propagate_dynamic_shape,
2930
RemoveNoopPass,
3031
ReplaceSymSizeOpPass,
31-
SymShapeEvalPass,
3232
ToOutVarPass,
3333
)
3434
from executorch.exir.passes.const_prop_pass import ConstPropPass
@@ -586,7 +586,7 @@ def test_alloc_node_spec(self) -> None:
586586
).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
587587
passes = [
588588
SpecPropPass(),
589-
SymShapeEvalPass(),
589+
HintBasedSymShapeEvalPass(),
590590
ToOutVarPass(),
591591
MemoryPlanningPass("greedy"),
592592
]

0 commit comments

Comments
 (0)