Skip to content

Add support for unbacked symints #4326

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion exir/passes/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ python_library(
"sym_shape_eval_pass.py",
],
deps = [
"fbsource//third-party/pypi/sympy:sympy",
"//caffe2:torch",
"//executorch/exir:pass_base",
"//executorch/exir:sym_util",
Expand Down
58 changes: 34 additions & 24 deletions exir/passes/sym_shape_eval_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import Callable, List, Optional

import torch
Expand All @@ -12,7 +14,6 @@
from executorch.exir.pass_base import PassBase, PassResult
from executorch.exir.sym_util import eval_expr, eval_shape, eval_upper_bound
from executorch.exir.tensor import TensorSpec
from sympy import Integer
from torch.fx import GraphModule

upper_bound_shape_inference_table = {}
Expand Down Expand Up @@ -197,29 +198,38 @@ def call(self, graph_module: GraphModule):
if any(s is None for s in concrete_shape) or any(
s is None for s in concrete_spec
):

def get_val(arg):
assert "val" in arg.meta and isinstance(
arg.meta["val"], torch.Tensor
)
return arg.meta["val"]

# TODO (yidi): Replace with range based shape inference using var_to_range.
concrete_shape = upper_bound_shape_inference_table[
node.target
](*pytree.tree_map(get_val, (node.args, node.kwargs)))

for sym_int, i in zip(spec.shape, concrete_shape):
if isinstance(sym_int, torch.SymInt):
# We cache the symbolic ints' value as the concrete interger upper bounds.
# So that future use of the unbacked symbols will get a concrete value.
sym_int.node.shape_env.var_to_val[
sym_int.node._expr
] = Integer(i)

# spec.stride is guaranteed to use a subset of symbols in spec.shape, since
# we cached the map between symbols and the concrete upper bounds. Can directly eval here.
concrete_spec = eval_shape(spec.stride)
# None indicates unbacked symints, see: https://fburl.com/code/v7hj5zv6
# Use value range to get the upper bounds of unbacked symints.
from torch._guards import detect_fake_mode

fake_mode = detect_fake_mode(node.meta.get("val"))
if fake_mode is not None:
from torch.utils._sympy.numbers import int_oo

shape_env = fake_mode.shape_env
for i, v in enumerate(spec.shape):
if concrete_shape[i] is None:
# get updated shape from var_to_range
_value_range = shape_env.var_to_range[
v._sympy_() # pyre-fixme[16] Undefined attribute: `int` has no attribute `_sympy_`.
]
# cannot handle unbounded, unbacked symints; add a range to bound it.
assert _value_range.upper is not int_oo
concrete_shape[i] = int(_value_range.upper)
for i, v in enumerate(spec.stride):
if concrete_spec[i] is None:
_expr = (
v.node.expr # pyre-fixme[16] Undefined attribute: `int` has no attribute `node`.
)
_value_range = v.node.shape_env.var_to_range
from torch.utils._sympy.value_ranges import (
bound_sympy,
)

_bound_sympy = bound_sympy(_expr, _value_range)
# cannot handle unbounded, unbacked symints; add a range to bound it.
assert _bound_sympy.upper is not int_oo
concrete_spec[i] = int(_bound_sympy.upper)

assert all(isinstance(s, int) for s in concrete_shape) and all(
isinstance(s, int) for s in concrete_spec
Expand Down
1 change: 1 addition & 0 deletions exir/tests/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ python_library(
"//caffe2:torch",
"//executorch/exir:dim_order_utils",
"//executorch/exir:lib",
"//executorch/exir/capture:config",
],
)

Expand Down
17 changes: 17 additions & 0 deletions exir/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,23 @@
# TODO: add one more test for data dependent op plus repeat


class TensorItem(nn.Module):
def __init__(self) -> None:
super().__init__()

def forward(self, arg1: torch.Tensor, arg2: torch.Tensor) -> torch.Tensor:
h = arg1.item()
w = arg2.item()
torch._check(h >= 2)
torch._check(h <= 100)
torch._check(w >= 2)
torch._check(w <= 100)
return torch.ones(int(h), int(w))

def get_random_inputs(self) -> Tuple[torch.Tensor, torch.Tensor]:
return (torch.tensor(10), torch.tensor(20))


class Repeat(nn.Module):
def __init__(self) -> None:
super().__init__()
Expand Down
27 changes: 26 additions & 1 deletion exir/tests/test_dynamic_shape_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from unittest import TestCase

from executorch import exir
from executorch.exir import to_edge
from executorch.exir.passes import DebugPass, HintBasedSymShapeEvalPass, SpecPropPass
from executorch.exir.tests.models import Repeat
from executorch.exir.tests.models import Repeat, TensorItem
from torch.export import export


Expand Down Expand Up @@ -37,3 +39,26 @@ def test_repeat(self):
self.assertTrue(first_spec.is_upper_bound_tensor)
self.assertTrue(second_spec.is_upper_bound_tensor)
self.assertEqual(first_spec.shape, [4, 5])


class TestUnbackedSymInt(TestCase):
    def test_unbacked_symint(self):
        """Unbacked symints in output specs resolve to their declared upper bounds."""
        model = TensorItem()
        example_inputs = model.get_random_inputs()

        edge_prog = to_edge(
            export(model, example_inputs, dynamic_shapes=None),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        transformed = edge_prog.transform(
            [SpecPropPass(), HintBasedSymShapeEvalPass()]
        )
        graph_module = transformed.exported_program().graph_module

        DebugPass(show_spec=True)(graph_module)
        output_node = list(graph_module.graph.nodes)[-1]
        specs = output_node.meta["spec"]
        self.assertEqual(len(specs), 1)
        spec = specs[0]
        self.assertTrue(spec.is_upper_bound_tensor)
        # [100, 100] is the torch._check upper bound declared in TensorItem.forward.
        self.assertEqual(spec.shape, [100, 100])
Loading