Commit 569373a

ezyang authored and nWEIdia committed
Output of nonzero is transposed, fix fake tensor (pytorch#144695)
Needs this companion executorch PR: pytorch/executorch#7657

Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: pytorch#144695
Approved by: https://github.com/bobrenjc93, https://github.com/albanD
1 parent 530e4e0 commit 569373a
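
For context, a quick sketch of the behavior this commit targets, mirroring the new test_nonzero_stride test added below (torch.ones(5) is just an arbitrary input): the commit's premise is that the eager nonzero result comes back in a transposed layout, and the fake-tensor implementation is updated to report the same layout predicate.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

with FakeTensorMode(shape_env=ShapeEnv()):
    fake_r = torch.ones(5).nonzero()   # fake result: shape (u0, 1), strides (1, u0)

r = torch.ones(5).nonzero()            # eager result

# After this fix, both sides agree on whether the transposed result is contiguous.
assert fake_r.T.is_contiguous() == r.T.is_contiguous()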

File tree

8 files changed: +68 -41 lines changed

benchmarks/dynamo/pr_time_benchmarks/expected_results.csv
test/inductor/test_auto_functionalize.py
test/inductor/test_unbacked_symints.py
test/test_fake_tensor.py
torch/_dynamo/debug_utils.py
torch/_functorch/_aot_autograd/runtime_wrappers.py
torch/_prims_common/__init__.py
torch/_subclasses/fake_impls.py

benchmarks/dynamo/pr_time_benchmarks/expected_results.csv

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,27360000000,0.015
-basic_modules_ListOfLinears_eager,compile_time_instruction_count,928600000,0.015
+basic_modules_ListOfLinears_eager,compile_time_instruction_count,943000000,0.015

test/inductor/test_auto_functionalize.py

Lines changed: 16 additions & 16 deletions
@@ -1406,17 +1406,17 @@ def f(x):
 """\
 def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
 clone: "f32[s0][1]cpu" = torch.ops.aten.clone.default(arg1_1)
-nonzero: "i64[u0, 1][1, 1]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
+nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
 sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
 ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
 _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
-_to_copy: "f32[u0, 1][1, 1]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
+_to_copy: "f32[u0, 1][1, u0]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
 auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_alias = True, _y_base_index = 1, _y_alias = True, _all_bases = [arg1_1, _to_copy]); _to_copy = None
 getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]
-getitem_2: "f32[u0, 1][1, 1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
+getitem_2: "f32[u0, 1][1, u0]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
 copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None
 alias_1: "f32[s0][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None
-slice_2: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None
+slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None
 return (alias_1, slice_2)""", # noqa: B950
 ignore_comments=True,
 ignore_empty_lines=True,

@@ -1427,19 +1427,19 @@ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
 """\
 def forward(self, arg0_1: "f32[2][1]cpu"):
 clone: "f32[2][1]cpu" = torch.ops.aten.clone.default(arg0_1)
-nonzero: "i64[u0, 1][1, 1]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
+nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
 sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
 ge_1: "Sym(u0 >= 0)" = sym_size_int >= 0
 _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
 le: "Sym(u0 <= 2)" = sym_size_int <= 2; sym_size_int = None
 _assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 2 on node 'le'"); le = _assert_scalar_1 = None
-_to_copy: "f32[u0, 1][1, 1]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
+_to_copy: "f32[u0, 1][1, u0]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
 auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_alias = True, _y_base_index = 1, _y_alias = True, _all_bases = [arg0_1, _to_copy]); _to_copy = None
 getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]
-getitem_2: "f32[u0, 1][1, 1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
+getitem_2: "f32[u0, 1][1, u0]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
 copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = copy_ = None
 alias_1: "f32[2][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None
-slice_2: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None
+slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None
 return (alias_1, slice_2)""", # noqa: B950
 ignore_comments=True,
 ignore_empty_lines=True,

@@ -1452,16 +1452,16 @@ def forward(self, arg0_1: "f32[2][1]cpu"):
 graph_inductor,
 """\
 def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
-nonzero: "i64[u0, 1][1, 1]cpu" = torch.ops.aten.nonzero.default(arg1_1)
+nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(arg1_1)
 sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
 ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
 _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
-convert_element_type: "f32[u0, 1][1, 1]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None
+convert_element_type: "f32[u0, 1][1, u0]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None
 alias_default: "f32[s0][1]cpu" = torch.ops.aten.alias.default(arg1_1)
-alias_default_1: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.alias.default(convert_element_type)
+alias_default_1: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.alias.default(convert_element_type)
 foo_default = torch.ops.mylib.foo.default(alias_default, alias_default_1); alias_default = alias_default_1 = foo_default = None
 copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None
-slice_2: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None
+slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None
 return (arg1_1, slice_2)""", # noqa: B950
 ignore_comments=True,
 ignore_empty_lines=True,

@@ -1471,18 +1471,18 @@ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
 graph_inductor,
 """\
 def forward(self, arg0_1: "f32[2][1]cpu"):
-nonzero: "i64[u0, 1][1, 1]cpu" = torch.ops.aten.nonzero.default(arg0_1)
+nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(arg0_1)
 sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
 ge_1: "Sym(u0 >= 0)" = sym_size_int >= 0
 _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
 le: "Sym(u0 <= 2)" = sym_size_int <= 2; sym_size_int = None
 _assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 2 on node 'le'"); le = _assert_scalar_1 = None
-convert_element_type: "f32[u0, 1][1, 1]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None
+convert_element_type: "f32[u0, 1][1, u0]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None
 alias_default: "f32[2][1]cpu" = torch.ops.aten.alias.default(arg0_1)
-alias_default_1: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.alias.default(convert_element_type)
+alias_default_1: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.alias.default(convert_element_type)
 foo_default = torch.ops.mylib.foo.default(alias_default, alias_default_1); alias_default = alias_default_1 = foo_default = None
 copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); copy_ = None
-slice_2: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None
+slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None
 return (arg0_1, slice_2)""", # noqa: B950
 ignore_comments=True,
 ignore_empty_lines=True,
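
The only change in these expected graphs is the recorded stride of the nonzero output and everything derived from it. The annotations read as dtype[sizes][strides]device, so "i64[u0, 1][1, u0]cpu" is a (u0, 1) int64 CPU tensor whose second dimension is strided by the unbacked nonzero count u0, i.e. the transpose of a contiguous buffer. A small illustration with a concrete size (n = 4 is an arbitrary stand-in for u0):

import torch

n = 4                                                   # stand-in for the unbacked u0
buf = torch.arange(n, dtype=torch.int64).reshape(1, n)  # contiguous (1, n) buffer
col = buf.t()                                           # shape (n, 1), strides (1, n)
print(col.shape, col.stride())                          # torch.Size([4, 1]) (1, 4)
print(col.T.is_contiguous())                            # True: the "transposed" layout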

test/inductor/test_unbacked_symints.py

Lines changed: 15 additions & 15 deletions
@@ -7,19 +7,14 @@
 from torch._dynamo import config as dynamo_config
 from torch._inductor import config as inductor_config
 from torch._inductor.test_case import TestCase as InductorTestCase
-from torch._inductor.utils import is_big_gpu
 from torch.testing import make_tensor
 from torch.testing._internal.common_device_type import (
 instantiate_device_type_tests,
+skipCPUIf,
 skipGPUIf,
 )
-from torch.testing._internal.common_utils import IS_LINUX, parametrize
-from torch.testing._internal.inductor_utils import (
-GPU_TYPE,
-HAS_CUDA,
-HAS_GPU,
-requires_gpu,
-)
+from torch.testing._internal.common_utils import parametrize
+from torch.testing._internal.inductor_utils import HAS_GPU


 class TestUnbackedSymints(InductorTestCase):

@@ -127,7 +122,7 @@ def fn(x):
 expected = fn(*example_inputs)
 torch.testing.assert_close(actual, expected)

-@requires_gpu()
+@skipGPUIf(not HAS_GPU, "requires gpu and triton")
 @dynamo_config.patch({"capture_scalar_outputs": True})
 def test_triton_kernel_grid(self, device):
 if device == "cpu":

@@ -163,6 +158,7 @@ def fn(x):

 torch.testing.assert_close(actual, expected)

+@skipGPUIf(not HAS_GPU, "requires gpu and triton")
 @inductor_config.patch({"max_autotune": True})
 @dynamo_config.patch({"capture_scalar_outputs": True})
 def test_equivalent_backed_unbacked(self, device):

@@ -195,7 +191,8 @@ def fn(x, w, a, b):
 expected = fn(*example_inputs)
 torch.testing.assert_close(actual, expected)

-@requires_gpu()
+@skipCPUIf(True, "precision not good enough on CPU")
+@skipGPUIf(not HAS_GPU, "requires gpu and triton")
 @dynamo_config.patch({"capture_scalar_outputs": True})
 def test_vertical_pointwise_reduction_fusion(self, device):
 # reset in case we run both cpu and cuda tests

@@ -214,16 +211,17 @@ def fn(x, y, repeats):
 return pointwise, reduction

 example_inputs = (
-torch.randn(32, 16).to(GPU_TYPE),
-torch.randn(1, 16).to(GPU_TYPE),
-torch.tensor(32).to(GPU_TYPE),
+torch.randn(32, 16, device=device),
+torch.randn(1, 16, device=device),
+torch.tensor(32, device=device),
 )

 actual = torch.compile(fn, fullgraph=True)(*example_inputs)
 expected = fn(*example_inputs)
 torch.testing.assert_close(actual, expected)
 self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)

+@skipGPUIf(not HAS_GPU, "requires gpu and triton")
 @dynamo_config.patch({"capture_scalar_outputs": True})
 @parametrize(
 "torch_fn", [torch.mm, torch.bmm, torch.addmm], name_fn=lambda fn: fn.__name__

@@ -262,6 +260,7 @@ def fn(x, w, repeats, is_bmm):
 expected = fn(*example_inputs)
 torch.testing.assert_close(actual, expected)

+@skipGPUIf(not HAS_GPU, "requires gpu and triton")
 @torch._dynamo.config.patch(capture_scalar_outputs=True)
 def test_unbacked_range_tree_divisor(self, device):
 def fn(x, num):

@@ -279,6 +278,7 @@ def fn(x, num):
 expected = fn(*example_inputs)
 torch.testing.assert_close(actual, expected)

+@skipGPUIf(not HAS_GPU, "requires gpu and triton")
 @dynamo_config.patch({"capture_scalar_outputs": True})
 def test_unbacked_masked_scatter(self, device):
 def fn(value, mask):

@@ -294,6 +294,7 @@ def fn(value, mask):
 expected = fn(*example_inputs)
 torch.testing.assert_close(actual, expected)

+@skipGPUIf(not HAS_GPU, "requires gpu and triton")
 @dynamo_config.patch({"capture_scalar_outputs": True})
 @parametrize("dynamic", [False, True, None])
 def test_unbacked_slice_on_subclass(self, device, dynamic):

@@ -388,5 +389,4 @@ def fn(t, start, length):
 if __name__ == "__main__":
 from torch._inductor.test_case import run_tests

-if IS_LINUX and HAS_GPU and (not HAS_CUDA or is_big_gpu()):
-run_tests()
+run_tests()
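
These edits make the file device-generic: blanket @requires_gpu() gates and the GPU-only run_tests() guard become per-test skipGPUIf/skipCPUIf decorators, and tensors are allocated on the parametrized device instead of a hard-coded GPU_TYPE. A minimal sketch of that pattern (the test class and body here are placeholders, not code from the repository):

import torch
from torch._inductor.test_case import TestCase as InductorTestCase, run_tests
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    skipGPUIf,
)
from torch.testing._internal.inductor_utils import HAS_GPU


class TestDeviceGenericExample(InductorTestCase):
    @skipGPUIf(not HAS_GPU, "requires gpu and triton")
    def test_add(self, device):
        # allocate on the parametrized device rather than a hard-coded GPU_TYPE
        x = torch.randn(8, device=device)
        torch.testing.assert_close(x + x, 2 * x)


instantiate_device_type_tests(TestDeviceGenericExample, globals())

if __name__ == "__main__":
    run_tests()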

test/test_fake_tensor.py

Lines changed: 11 additions & 0 deletions
@@ -1597,6 +1597,17 @@ def f(x):
 self.assertIsNot(u0, u1)
 self.assertTrue(statically_known_true(u0 == u1))

+def test_nonzero_stride(self):
+shape_env = ShapeEnv()
+fake_mode = FakeTensorMode(shape_env=shape_env)
+with fake_mode:
+value = torch.ones(5)
+fake_r = value.nonzero()
+
+r = torch.ones(5).nonzero()
+
+self.assertEqual(fake_r.T.is_contiguous(), r.T.is_contiguous())
+
 def test_torch_load_with_fake_mode(self):
 class TheModelClass(torch.nn.Module):
 def __init__(self) -> None:

torch/_dynamo/debug_utils.py

Lines changed: 4 additions & 2 deletions
@@ -678,14 +678,16 @@ def storage(self, untyped_storage, *, dtype_hint=None, device_hint=None) -> str:
 return v

 def tensor(self, name, t) -> None:
-from torch.fx.experimental.symbolic_shapes import statically_known_true
+from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq

 storage = self.storage(
 t.untyped_storage(), dtype_hint=t.dtype, device_hint=t.device
 )
 args = []
 # NB: this is positional, must come first
-if _stride_or_default(None, shape=t.shape) != t.stride():
+if not statically_known_true(
+sym_eq(_stride_or_default(None, shape=t.shape), t.stride())
+):
 args.append(str(tuple(t.stride())))
 if _dtype_or_default(None) != t.dtype:
 args.append(f"dtype={t.dtype!r}")
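
The point of the sym_eq/statically_known_true guard: when t has an unbacked symbolic stride (such as the (1, u0) nonzero output above), a plain != comparison would try to convert a data-dependent SymBool to bool and can fail, while statically_known_true just answers "not provably equal" without guarding, so the stride gets serialized explicitly. A minimal sketch of the same check outside debug_utils (the contiguous-stride tuple is written out by hand here instead of calling _stride_or_default):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv, statically_known_true, sym_eq

with FakeTensorMode(shape_env=ShapeEnv()):
    t = torch.ones(5).nonzero()    # shape (u0, 1), strides (1, u0)
    contiguous_stride = (1, 1)     # what a contiguous (u0, 1) tensor would have
    # Not provably equal, decided without guarding on the unbacked u0.
    if not statically_known_true(sym_eq(tuple(t.stride()), contiguous_stride)):
        print("record stride explicitly:", tuple(t.stride()))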

torch/_functorch/_aot_autograd/runtime_wrappers.py

Lines changed: 15 additions & 3 deletions
@@ -520,7 +520,8 @@ class FakifiedOutWrapper(CompilerWrapper):
 out_metas: list[torch.Tensor] = field(default_factory=list)
 # TracingContext.fwd_output_strides
 # Generated from actually doing compile
-fwd_output_strides: Optional[list[list[int]]] = None
+# NB: an entry is None if it's not a Tensor
+fwd_output_strides: Optional[list[Optional[list[int]]]] = None
 needs_post_compile: bool = True

 def pre_compile(

@@ -551,12 +552,23 @@ def _compute_output_meta_with_inductor_strides(self):
 for i in range(len(out)):
 if not isinstance(out[i], Tensor):
 continue
+strides = fwd_output_strides[i]
+# fwd_output_strides is best effort by Inductor. When an output
+# Tensor has unbacked SymInts, Inductor may sometimes be unable
+# to compute what the output stride would be. If Inductor doesn't
+# have any clear direction on the layout, we don't have to run
+# as_strided. To repro without this, run:
+#
+# python test/distributed/test_dynamo_distributed.py
+# TestFakeDistributedSingleProc.test_unbacked_symbol_splitting_no_binding
+if strides is None:
+continue
 if all(
 statically_known_true(s1 == s2)
-for s1, s2 in zip(out[i].stride(), fwd_output_strides[i])
+for s1, s2 in zip(out[i].stride(), strides)
 ):
 continue
-out[i] = out[i].as_strided(out[i].shape, fwd_output_strides[i])
+out[i] = out[i].as_strided(out[i].shape, strides)
 return out

 # To be called post compile
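
A standalone sketch of the guarded restriding logic above (the function name and the plain-list inputs are illustrative; the real code lives inside FakifiedOutWrapper._compute_output_meta_with_inductor_strides):

from typing import Optional

import torch
from torch.fx.experimental.symbolic_shapes import statically_known_true


def restride_outputs(outs: list, fwd_output_strides: Optional[list]) -> list:
    """Apply Inductor's chosen output strides where they are known and differ."""
    if fwd_output_strides is None:
        return outs
    for i, out in enumerate(outs):
        if not isinstance(out, torch.Tensor):
            continue
        strides = fwd_output_strides[i]
        if strides is None:
            # Inductor could not pin down a layout (e.g. unbacked SymInts);
            # leave the output as produced.
            continue
        if all(
            statically_known_true(s1 == s2)
            for s1, s2 in zip(out.stride(), strides)
        ):
            continue
        outs[i] = out.as_strided(out.shape, strides)
    return outs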

torch/_prims_common/__init__.py

Lines changed: 5 additions & 3 deletions
@@ -315,14 +315,16 @@ def is_channels_last_contiguous_3d(a: Tensor) -> bool:
 if a.ndim != 5:
 return False

+from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
+
 expected_stride = 1
 for idx in (1, 4, 3, 2, 0):
 length = a.shape[idx]
-if length == 1:
+if guard_size_oblivious(length == 1):
 continue

 stride = a.stride()[idx]
-if stride != expected_stride:
+if guard_size_oblivious(stride != expected_stride):
 return False

 expected_stride *= length

@@ -436,7 +438,7 @@ def __eq__(self, other):
 if guard_size_oblivious(length == 1):
 continue

-if stride != expected_stride:
+if guard_size_oblivious(stride != expected_stride):
 return False

 expected_stride *= length
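
guard_size_oblivious lets these stride walks run on tensors whose sizes are unbacked SymInts (like the nonzero output): instead of guarding on the data-dependent value, it evaluates the comparison as if the unbacked size were a generic size (in particular, not 0 or 1). A small sketch of the difference:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv, guard_size_oblivious

with FakeTensorMode(shape_env=ShapeEnv()):
    nz = torch.ones(5).nonzero()   # shape (u0, 1), strides (1, u0)
    u0 = nz.shape[0]               # unbacked SymInt
    # bool(u0 == 1) would raise a data-dependent guard error;
    # the size-oblivious query resolves it (to False) without adding a guard.
    print(guard_size_oblivious(u0 == 1))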

torch/_subclasses/fake_impls.py

Lines changed: 1 addition & 1 deletion
@@ -463,7 +463,7 @@ def nonzero(fake_mode, func, arg):

 arg.nonzero_memo = nnz

-return arg.new_empty((nnz, arg.dim()), dtype=torch.int64)
+return arg.new_empty_strided((nnz, arg.dim()), (1, nnz), dtype=torch.int64)


 @register_op_impl(torch.ops.aten._padded_dense_to_jagged_forward.default)
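
This one-line change is the core of the fix: the fake nonzero result is now allocated with strides (1, nnz) instead of the default contiguous layout, matching the transposed layout the commit title describes for the real output. A concrete illustration with plain tensors (nnz = 3 and a 2-D input are arbitrary; the real code runs under fake mode with an unbacked nnz):

import torch

x = torch.ones(4, 4)   # x.dim() == 2
nnz = 3                # stand-in for the unbacked nonzero count

old = x.new_empty((nnz, x.dim()), dtype=torch.int64)                    # previous fake layout
new = x.new_empty_strided((nnz, x.dim()), (1, nnz), dtype=torch.int64)  # layout after this fix

print(old.stride(), old.T.is_contiguous())   # (2, 1) False
print(new.stride(), new.T.is_contiguous())   # (1, 3) True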
