Skip to content

Rename SymShapeEval pass to HintBasedSymShapeEval pass and add warning for it. #377

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions exir/passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from executorch.exir.passes.replace_sym_size_op_pass import ReplaceSymSizeOpPass
from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
from executorch.exir.passes.spec_prop_pass import SpecPropPass
from executorch.exir.passes.sym_shape_eval_pass import SymShapeEvalPass
from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
from executorch.exir.passes.sym_to_tensor_pass import SymToTensorPass
from torch import fx
from torch._subclasses import FakeTensor
Expand All @@ -65,7 +65,7 @@
"OpReplacePass",
"EdgeToBackendOpsPass",
"MemoryFormatOpsPass",
"SymShapeEvalPass",
"HintBasedSymShapeEvalPass",
]

Argument = Optional[
Expand Down Expand Up @@ -510,5 +510,5 @@ def propagate_dynamic_shape(
"""
return [
SpecPropPass(),
SymShapeEvalPass(),
HintBasedSymShapeEvalPass(),
]
13 changes: 12 additions & 1 deletion exir/passes/sym_shape_eval_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,23 @@ def nonzero(args, kwargs) -> List[Optional[int]]:
return [eval_expr(args[0].shape[0]), len(args[0].shape)]


class SymShapeEvalPass(PassBase):
class HintBasedSymShapeEvalPass(PassBase):
"""
If we enable dynamic shape tracing, a tensor's shape may become a symbolic
formula. We should convert those symbolic formula to concrete value for
static/upperbound tensors so we can properly do memory planning for them.
HintBasedSymShapeEvalPass evalutes the symbolic expression of shapes based
on its hint, which is a concrete integer that backs the sym expression. The original
hint comes from the sizes of the inputs that user uses for tracing and hints of
symbolic expressions are propagated via meta tensor computation.
For example, when export f(x), we use x = torch.ones(3, 4) as an exmaple input to f and
suppose we constrain both dimensions of x as dynamic. We'll have two symbols s0, s1 created
and they are backed up with hints 3 and 4 respectively. If there is a y = x[0] operation in f,
the shape of y is inferred to be s1, which is backed up with hint 4.
Warning: if you're using torch.export with constrain API, this method doesn't respect the input constraints.
Not inherit from ExportPass since we simply need a way to iterate thru
every node's output. PassBase is easier for that purpose.
"""
Expand Down
4 changes: 2 additions & 2 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from executorch.exir.passes import (
aten_to_edge_passes,
EdgeToBackendOpsPass,
HintBasedSymShapeEvalPass,
OpReplacePass,
SymShapeEvalPass,
)
from executorch.exir.passes.remove_assert_async_pass import RemoveAssertAsyncPass
from executorch.exir.passes.spec_prop_pass import SpecPropPass
Expand Down Expand Up @@ -309,7 +309,7 @@ def edge_to_executorch_passes(config: ExecutorchBackendConfig) -> List[PassType]
SpecPropPass(),
EdgeToBackendOpsPass(),
RemoveAssertAsyncPass(),
SymShapeEvalPass(),
HintBasedSymShapeEvalPass(),
config.to_out_var_pass,
config.memory_planning_pass,
]
Expand Down
4 changes: 2 additions & 2 deletions exir/tests/test_dynamic_shape_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from unittest import TestCase

from executorch import exir
from executorch.exir.passes import DebugPass, SpecPropPass, SymShapeEvalPass
from executorch.exir.passes import DebugPass, HintBasedSymShapeEvalPass, SpecPropPass
from executorch.exir.tests.models import Repeat


Expand All @@ -23,7 +23,7 @@ def test_repeat(self):
exir.CaptureConfig(enable_dynamic_shape=True),
).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))

new_prog = prog.transform(SpecPropPass(), SymShapeEvalPass())
new_prog = prog.transform(SpecPropPass(), HintBasedSymShapeEvalPass())

gm = new_prog.exported_program.graph_module

Expand Down
4 changes: 2 additions & 2 deletions exir/tests/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
from executorch.exir.passes import (
dead_code_elimination_pass,
DebugPass,
HintBasedSymShapeEvalPass,
MemoryPlanningPass,
propagate_dynamic_shape,
RemoveNoopPass,
ReplaceSymSizeOpPass,
SymShapeEvalPass,
ToOutVarPass,
)
from executorch.exir.passes.const_prop_pass import ConstPropPass
Expand Down Expand Up @@ -586,7 +586,7 @@ def test_alloc_node_spec(self) -> None:
).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
passes = [
SpecPropPass(),
SymShapeEvalPass(),
HintBasedSymShapeEvalPass(),
ToOutVarPass(),
MemoryPlanningPass("greedy"),
]
Expand Down