Skip to content

Commit 0e032c5

Browse files
lucylqfacebook-github-bot
authored andcommitted
Add support for unbacked symints (#4326)
Summary: Pull Request resolved: #4326 Use ranges to get the upper bound for unbacked symints. This will be used to resolve unbacked symints in preprocess. Reviewed By: larryliu0820, ydwu4, angelayi Differential Revision: D60027561 fbshipit-source-id: bf3c29151e30e7056049ccb1908483dd44993e10
1 parent 0e2b205 commit 0e032c5

File tree

5 files changed

+78
-26
lines changed

5 files changed

+78
-26
lines changed

exir/passes/TARGETS

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,6 @@ python_library(
190190
"sym_shape_eval_pass.py",
191191
],
192192
deps = [
193-
"fbsource//third-party/pypi/sympy:sympy",
194193
"//caffe2:torch",
195194
"//executorch/exir:pass_base",
196195
"//executorch/exir:sym_util",

exir/passes/sym_shape_eval_pass.py

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# pyre-unsafe
8+
79
from typing import Callable, List, Optional
810

911
import torch
@@ -12,7 +14,6 @@
1214
from executorch.exir.pass_base import PassBase, PassResult
1315
from executorch.exir.sym_util import eval_expr, eval_shape, eval_upper_bound
1416
from executorch.exir.tensor import TensorSpec
15-
from sympy import Integer
1617
from torch.fx import GraphModule
1718

1819
upper_bound_shape_inference_table = {}
@@ -197,29 +198,38 @@ def call(self, graph_module: GraphModule):
197198
if any(s is None for s in concrete_shape) or any(
198199
s is None for s in concrete_spec
199200
):
200-
201-
def get_val(arg):
202-
assert "val" in arg.meta and isinstance(
203-
arg.meta["val"], torch.Tensor
204-
)
205-
return arg.meta["val"]
206-
207-
# TODO (yidi): Replace with range based shape inference using var_to_range.
208-
concrete_shape = upper_bound_shape_inference_table[
209-
node.target
210-
](*pytree.tree_map(get_val, (node.args, node.kwargs)))
211-
212-
for sym_int, i in zip(spec.shape, concrete_shape):
213-
if isinstance(sym_int, torch.SymInt):
214-
# We cache the symbolic ints' value as the concrete interger upper bounds.
215-
# So that future use of the unbacked symbols will get a concrete value.
216-
sym_int.node.shape_env.var_to_val[
217-
sym_int.node._expr
218-
] = Integer(i)
219-
220-
# spec.stride is guaranteed to use a subset of symbols in spec.shape, since
221-
# we cached the map between symbols and the concrete upper bounds. Can directly eval here.
222-
concrete_spec = eval_shape(spec.stride)
201+
# None indicates unbacked symints, see: https://fburl.com/code/v7hj5zv6
202+
# Use value range to get the upper bounds of unbacked symints.
203+
from torch._guards import detect_fake_mode
204+
205+
fake_mode = detect_fake_mode(node.meta.get("val"))
206+
if fake_mode is not None:
207+
from torch.utils._sympy.numbers import int_oo
208+
209+
shape_env = fake_mode.shape_env
210+
for i, v in enumerate(spec.shape):
211+
if concrete_shape[i] is None:
212+
# get updated shape from var_to_range
213+
_value_range = shape_env.var_to_range[
214+
v._sympy_() # pyre-fixme[16] Undefined attribute: `int` has no attribute `_sympy_`.
215+
]
216+
# cannot handle unbounded, unbacked symints; add a range to bound it.
217+
assert _value_range.upper is not int_oo
218+
concrete_shape[i] = int(_value_range.upper)
219+
for i, v in enumerate(spec.stride):
220+
if concrete_spec[i] is None:
221+
_expr = (
222+
v.node.expr # pyre-fixme[16] Undefined attribute: `int` has no attribute `node`.
223+
)
224+
_value_range = v.node.shape_env.var_to_range
225+
from torch.utils._sympy.value_ranges import (
226+
bound_sympy,
227+
)
228+
229+
_bound_sympy = bound_sympy(_expr, _value_range)
230+
# cannot handle unbounded, unbacked symints; add a range to bound it.
231+
assert _bound_sympy.upper is not int_oo
232+
concrete_spec[i] = int(_bound_sympy.upper)
223233

224234
assert all(isinstance(s, int) for s in concrete_shape) and all(
225235
isinstance(s, int) for s in concrete_spec

exir/tests/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,7 @@ python_library(
388388
"//caffe2:torch",
389389
"//executorch/exir:dim_order_utils",
390390
"//executorch/exir:lib",
391+
"//executorch/exir/capture:config",
391392
],
392393
)
393394

exir/tests/models.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,23 @@
2121
# TODO: add one more test for data dependent op plus repeat
2222

2323

24+
class TensorItem(nn.Module):
25+
def __init__(self) -> None:
26+
super().__init__()
27+
28+
def forward(self, arg1: torch.Tensor, arg2: torch.Tensor) -> torch.Tensor:
29+
h = arg1.item()
30+
w = arg2.item()
31+
torch._check(h >= 2)
32+
torch._check(h <= 100)
33+
torch._check(w >= 2)
34+
torch._check(w <= 100)
35+
return torch.ones(int(h), int(w))
36+
37+
def get_random_inputs(self) -> Tuple[torch.Tensor, torch.Tensor]:
38+
return (torch.tensor(10), torch.tensor(20))
39+
40+
2441
class Repeat(nn.Module):
2542
def __init__(self) -> None:
2643
super().__init__()

exir/tests/test_dynamic_shape_propagation.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# pyre-unsafe
8+
79
from unittest import TestCase
810

911
from executorch import exir
1012
from executorch.exir import to_edge
1113
from executorch.exir.passes import DebugPass, HintBasedSymShapeEvalPass, SpecPropPass
12-
from executorch.exir.tests.models import Repeat
14+
from executorch.exir.tests.models import Repeat, TensorItem
1315
from torch.export import export
1416

1517

@@ -37,3 +39,26 @@ def test_repeat(self):
3739
self.assertTrue(first_spec.is_upper_bound_tensor)
3840
self.assertTrue(second_spec.is_upper_bound_tensor)
3941
self.assertEqual(first_spec.shape, [4, 5])
42+
43+
44+
class TestUnbackedSymInt(TestCase):
45+
def test_unbacked_symint(self):
46+
eager_model = TensorItem()
47+
inputs = eager_model.get_random_inputs()
48+
inputs = inputs[0], inputs[1]
49+
50+
prog = to_edge(
51+
export(eager_model, inputs, dynamic_shapes=None),
52+
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
53+
)
54+
new_prog = prog.transform([SpecPropPass(), HintBasedSymShapeEvalPass()])
55+
gm = new_prog.exported_program().graph_module
56+
57+
DebugPass(show_spec=True)(gm)
58+
*_, return_node = gm.graph.nodes
59+
speclist = return_node.meta["spec"]
60+
self.assertEqual(len(speclist), 1)
61+
self.assertTrue(speclist[0].is_upper_bound_tensor)
62+
self.assertEqual(
63+
speclist[0].shape, [100, 100]
64+
) # upper bound of TensorItem model

0 commit comments

Comments
 (0)