Commit 0f3d42f

ydwu4 authored and facebook-github-bot committed
Constraint based upper bound memory planning. (#264)
Summary: Pull Request resolved: #264

Before this PR, we treated the inputs passed in for tracing as the upper-bound tensors themselves and used their shapes as the upper bounds. However:

1. this approach doesn't respect user-provided constraints, and
2. it cannot handle certain data-dependent ops, such as tensor.item(), because we don't know the exact value at export time.

This PR adds a ConstraintBasedSymShapeEvalPass, which uses the range-analysis infrastructure in symbolic shapes. The change has one notable implication: users must provide a concrete integer upper bound for

1. dynamic input tensors, and
2. the outputs of data-dependent operations (e.g. nonzero, .item()).

Otherwise, upper-bound memory planning cannot proceed, because by default the upper bound is infinity.

Reviewed By: angelayi

Differential Revision: D49158290

fbshipit-source-id: bab88241e1a14a07903cab5b6e20e290edf2acc5
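As a quick illustration of the first requirement, here is a minimal sketch, mirroring the new test below and assuming the `torch.export.dynamic_dim` API that the test imports: a dynamic input dimension needs a concrete integer upper bound before upper-bound memory planning runs.

```python
import torch
from torch.export import dynamic_dim  # same API the new test below imports

k = torch.rand(2, 4)
# Constrain dim 0 to at most 3 so the planner has a finite upper bound;
# without such a constraint the default bound is infinity and planning fails.
constraints = [dynamic_dim(k, 0) <= 3]
```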
1 parent a7e9dcf commit 0f3d42f

File tree

6 files changed: +113 −4 lines changed

exir/capture/TARGETS

Lines changed: 1 addition & 0 deletions
```diff
@@ -37,5 +37,6 @@ python_library(
         "//executorch/exir:pass_manager",
         "//executorch/exir:tracer",
         "//executorch/exir/passes:lib",
+        "//executorch/exir/passes:sym_shape_eval_pass",
     ],
 )
```

exir/capture/_config.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -10,6 +10,7 @@
 from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode
 from executorch.exir.pass_manager import PassType
 from executorch.exir.passes import MemoryPlanningPass, ToOutVarPass
+from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
 from executorch.exir.tracer import ExirDynamoConfig
 from torch.fx._compatibility import compatibility
@@ -63,3 +64,4 @@ class ExecutorchBackendConfig:
     # If provided, the minimum alignment of delegate data in the program. Must
     # be a power of 2. If not provided, uses the value in the schema file.
     delegate_alignment: Optional[int] = None
+    sym_shape_eval_pass: PassType = HintBasedSymShapeEvalPass()
```
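A minimal usage sketch, assuming an executorch build containing this commit: the new field defaults to `HintBasedSymShapeEvalPass`, so existing flows are unchanged, and callers opt into the new behavior by overriding it.

```python
from executorch import exir
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass

# Leaving sym_shape_eval_pass unset keeps the hint-based default; overriding it
# switches to constraint-based upper-bound evaluation.
config = exir.ExecutorchBackendConfig(
    sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
)
```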

exir/emit/test/test_emit.py

Lines changed: 38 additions & 0 deletions
```diff
@@ -21,6 +21,7 @@
 from executorch.exir.emit import emit_program  # noqa
 from executorch.exir.error import InternalError
 from executorch.exir.passes.const_prop_pass import ConstPropPass
+from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
 from executorch.exir.print_program import pretty_print, print_program  # noqa
 from executorch.exir.schema import (
     Bool,
@@ -46,6 +47,8 @@
 from functorch.experimental import control_flow
 from torch import nn
 
+from torch.export import dynamic_dim
+
 
 class TestEmit(unittest.TestCase):
     @classmethod
@@ -1050,6 +1053,41 @@ def make_program(
             merged_program2.execution_plan[2], merged_program.execution_plan[2]
         )
 
+    def test_upper_bound_memory_planning_respect_input_constraints(self) -> None:
+        def func(k: torch.Tensor) -> torch.Tensor:
+            k = torch.cat((k, torch.ones(1, 4)))
+            return k
+
+        k = torch.rand(2, 4)
+        constraints = [
+            dynamic_dim(k, 0) <= 3,
+        ]
+        captured = exir.capture(
+            func,
+            (k,),
+            exir.CaptureConfig(pt2_mode=True, enable_aot=True),
+            constraints=constraints,  # enable_aot=False works
+        )
+        edge = captured.to_edge()
+        from executorch.exir.passes import MemoryPlanningPass
+
+        config = exir.ExecutorchBackendConfig(
+            sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
+            memory_planning_pass=MemoryPlanningPass(
+                memory_planning_algo="greedy",
+                # allow_lifetime_and_storage_overlap: bool = False,
+                alloc_graph_input=True,
+                alloc_graph_output=False,
+            ),
+        )
+
+        exe_prog = edge.to_executorch(config)
+        program = exe_prog.program
+        exir.print_program.pretty_print(exe_prog.program.execution_plan)
+        execution_plan = program.execution_plan[0]
+        self.check_tensor_buffer_loc(0, execution_plan.values, 0, 1, 0)
+        self.check_tensor_buffer_loc(1, execution_plan.values, 0, 1, 48)
+
     def test_emit_prims(self) -> None:
         class Simple(torch.nn.Module):
             def __init__(self) -> None:
```
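The test exercises the dynamic-input case; for the data-dependent case the summary mentions, a hypothetical sketch follows. The `constrain_as_value` import path has moved between PyTorch releases, so treat it as an assumption here.

```python
import torch
from torch._export import constrain_as_value  # import path varies across torch versions

def data_dependent(scores: torch.Tensor) -> torch.Tensor:
    n = scores[0].item()  # value unknown at export time; upper bound defaults to infinity
    constrain_as_value(n, min=1, max=8)  # concrete bound lets upper-bound planning proceed
    return torch.ones(n, 4)  # planned with upper-bound shape (8, 4)
```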

exir/passes/sym_shape_eval_pass.py

Lines changed: 36 additions & 1 deletion
```diff
@@ -10,7 +10,7 @@
 import torch.utils._pytree as pytree
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import PassBase, PassResult
-from executorch.exir.sym_util import eval_expr, eval_shape
+from executorch.exir.sym_util import eval_expr, eval_shape, eval_upper_bound
 from executorch.exir.tensor import TensorSpec
 from sympy import Integer
 from torch.fx import GraphModule
@@ -96,3 +96,38 @@ def get_val(arg):
             spec.shape = concrete_shape
             spec.stride = concrete_spec
         return PassResult(graph_module, True)
+
+
+class ConstraintBasedSymShapeEvalPass(PassBase):
+    """
+    If we enable dynamic shape tracing, a tensor's shape may become a symbolic
+    formula. We should convert those symbolic formulas to concrete values for
+    static/upper-bound tensors so we can properly do memory planning for them.
+
+    We don't inherit from ExportPass since we simply need a way to iterate through
+    every node's output. PassBase is easier for that purpose.
+    """
+
+    def call(self, graph_module: GraphModule):
+        for subgm in graph_module.modules():
+            if not isinstance(subgm, GraphModule):
+                continue
+            for node in subgm.graph.nodes:
+                for spec in pytree.tree_flatten(node.meta.get("spec", []))[0]:
+                    # Nodes for functions like aten.sym_size do not have a spec.
+                    if isinstance(spec, TensorSpec):
+                        concrete_shape = [eval_upper_bound(s) for s in spec.shape]
+                        concrete_stride = [eval_upper_bound(s) for s in spec.stride]
+                        if any(not isinstance(s, int) for s in concrete_shape) or any(
+                            not isinstance(s, int) for s in concrete_stride
+                        ):
+                            raise RuntimeError(
+                                f"Cannot evaluate the shape upper bound of a dynamic-shaped tensor to a concrete bounded integer. Got tensor spec: {spec}. "
+                                f"The upper-bound shape we got is {concrete_shape}; the upper-bound stride we got is {concrete_stride}. "
+                                "This tensor could come from either 1. a data-dependent operation such as nonzero, or 2. an input that doesn't have a constraint on its upper bound. "
+                                "Please use export's constrain_as_size() or constrain_as_value() APIs to set a concrete upper bound to resolve this."
+                            )
+
+                        spec.shape = concrete_shape  # pyre-ignore[8]: declared Tuple[int], assigned List[Optional[int]]
+                        spec.stride = concrete_stride  # pyre-ignore[8]: declared Tuple[int], assigned List[Optional[int]]
+        return PassResult(graph_module, True)
```
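For orientation, a hypothetical direct invocation; in practice `edge_to_executorch_passes` wires the pass in (see `_program.py` below), and `edge_gm` here is assumed to be an edge-dialect `GraphModule` whose nodes already carry `TensorSpec` metadata (e.g. from `SpecPropPass`).

```python
# Hypothetical: edge_gm already carries TensorSpec metadata on node.meta["spec"].
result = ConstraintBasedSymShapeEvalPass()(edge_gm)
# On success, every spec's shape and stride entries are concrete Python ints;
# an unbounded symbol raises the RuntimeError shown above instead.
planned_gm = result.graph_module
```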

exir/program/_program.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -19,7 +19,6 @@
 from executorch.exir.passes import (
     aten_to_edge_passes,
     EdgeToBackendOpsPass,
-    HintBasedSymShapeEvalPass,
     OpReplacePass,
 )
 from executorch.exir.passes.remove_assert_async_pass import RemoveAssertAsyncPass
@@ -309,7 +308,7 @@ def edge_to_executorch_passes(config: ExecutorchBackendConfig) -> List[PassType]
         SpecPropPass(),
         EdgeToBackendOpsPass(),
         RemoveAssertAsyncPass(),
-        HintBasedSymShapeEvalPass(),
+        config.sym_shape_eval_pass,
         config.to_out_var_pass,
         config.memory_planning_pass,
     ]
```

exir/sym_util.py

Lines changed: 35 additions & 1 deletion
```diff
@@ -4,11 +4,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Optional, Set, Union
+from typing import List, Optional, Set, Union
 
 import sympy
 
 import torch
+from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
 
 
 def eval_expr(symint: Union[int, torch.SymInt]) -> Optional[int]:
@@ -28,6 +29,32 @@ def eval_expr(symint: Union[int, torch.SymInt]) -> Optional[int]:
     return int(output)
 
 
+def eval_upper_bound(maybe_symint: Union[int, torch.SymInt]) -> Optional[int]:
+    """
+    Evaluate a symint to its upper-bound value. Returns None if the symint's symbolic
+    expression's upper bound cannot be evaluated to a valid integer per the constraints in shape_env.
+    """
+    if isinstance(maybe_symint, int):
+        return maybe_symint
+    node = maybe_symint.node
+    shape_env = node.shape_env
+    expr = node.expr
+    var_range: ValueRanges = bound_sympy(expr, shape_env.var_to_range)
+    upper_bound = var_range.upper
+    if isinstance(upper_bound, sympy.Integer):
+        concrete_upper = int(var_range.upper)  # pyre-ignore
+        assert isinstance(
+            concrete_upper, int
+        ), f"Expect upper bound to be a concrete int but got {concrete_upper}"
+        return concrete_upper
+    elif upper_bound == sympy.oo:  # sympy.oo is a singleton object, not a class
+        return None
+    else:
+        raise RuntimeError(
+            f"Expect upper bound to be sympy.Integer or sympy.oo, but got {upper_bound}"
+        )
+
+
 def eval_shape(shape):
     """
     Shape maybe immutable so we return a new shape. Return None for
@@ -39,6 +66,13 @@ def eval_shape(shape):
     return new_shape
 
 
+def eval_shape_upper_bound(shape) -> List[Optional[int]]:
+    new_shape = []
+    for s in shape:
+        new_shape.append(eval_upper_bound(s))
+    return new_shape
+
+
 def collect_free_symbols(shape) -> Set[sympy.Symbol]:
     symset = set()
     for sz in shape:
```
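A small self-contained sketch of the range analysis `eval_upper_bound` delegates to; the numbers mirror the new test, where dim 0 is constrained to at most 3 and the concatenated output dim is `s0 + 1`.

```python
import sympy
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges

s0 = sympy.Symbol("s0", positive=True, integer=True)
ranges = {s0: ValueRanges(2, 3)}  # constraint recorded by dynamic_dim(k, 0) <= 3

print(bound_sympy(s0 + 1, ranges).upper)  # 4: concrete bound, planning can proceed
print(bound_sympy(s0 + 1, {}).upper)      # oo: unconstrained, eval_upper_bound returns None
```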
