
Commit 39eeaac

bdhirsh authored and pytorchmergebot committed
inductor: avoiding moving constructor to cuda when it would cause h2d sync in index_put_ fallback (pytorch#130338)
My attempt at a fix for pytorch#130335; see the issue for more details / internal xref. Any feedback from inductor folks is appreciated. I attempted to make the move-constructors-to-cuda pass a bit less aggressive by detecting when the movement would incur an H2D sync for `aten.index_put_`. I'm not sure if there are any other ops that inductor falls back to eager on that may or may not incur an H2D sync if we change any of their inputs from cpu to cuda.

Pull Request resolved: pytorch#130338
Approved by: https://github.com/eellison
1 parent 93a03ed commit 39eeaac
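
For context (not part of the commit), a minimal eager-mode sketch of the sync behavior described above, assuming a CUDA build of PyTorch; the shapes and variable names are illustrative only:

import math
import torch

x = torch.randn(8, 8, device="cuda")
mask = torch.zeros(8, dtype=torch.bool, device="cuda")
mask[::2] = True

# Value is a cpu scalar: per the explanation in this commit, index_put_ can
# short-circuit to masked_fill_, so the boolean mask is never passed through
# nonzero() and no sync is incurred.
x[:, mask] = -math.inf

# Value is a cuda tensor: the short-circuit no longer applies, and index_put_
# materializes the mask's indices via nonzero(), which synchronizes with the host.
x[:, mask] = torch.tensor(-math.inf, device="cuda")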

File tree

2 files changed, +52 -0 lines changed


test/inductor/test_torchinductor.py

Lines changed: 21 additions & 0 deletions
@@ -11436,6 +11436,27 @@ def fn():
 
         self.assertEqual(fn_opt(), fn())
 
+    # https://github.com/pytorch/pytorch/issues/130335
+    def test_ctr_not_moved_to_cuda_when_used_in_index_put(self):
+        @torch.compile
+        def f(x, mask):
+            x[:, mask] = -math.inf
+            return x
+
+        x_tmp = torch.randn(512, 19, device="cuda")
+        x = x_tmp.permute(1, 0).view(-1, 128, 4)[:, :, 1:]
+
+        mask_tmp = torch.ones(128, 3, dtype=torch.int32, device="cuda")
+        mask = mask_tmp == mask_tmp
+        f(x, mask)
+        code = run_and_get_triton_code(f, x, mask)
+        # What we are testing here:
+        # inductor has a pass to move tensor constructors on cpu to cuda
+        # (the -math.inf will become a scalar-tensor input to index_put_())
+        # we are asserting that when inductor allocates this tensor,
+        # it does not move the tensor constructor to cuda and keeps it on CPU.
+        self.assertFalse("empty_strided_cuda(()" in code)
+
     @config.patch("triton.use_block_ptr", False)
     def test_evict_last_non_coalesced_loads(self):
         @torch.compile

torch/_inductor/fx_passes/post_grad.py

Lines changed: 31 additions & 0 deletions
@@ -1015,6 +1015,35 @@ def fused_int_mm_mul(match: Match, mat1, mat2, mat3, out_dtype=None):
     return inductor.kernel.mm.tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype)
 
 
+def is_index_put_and_requires_h2d_sync_for_cuda_value(node):
+    from torch.fx.operator_schemas import normalize_function
+
+    if node.target not in [
+        torch.ops.aten.index_put.default,
+        torch.ops.aten.index_put_.default,
+    ]:
+        return False
+    # Inductor falls back to aten.index_put_.
+    # index_put_ will call nonzero() and perform an H2D sync if
+    # any of its indices are bool/byte tensors.
+    # However, it will short-circuit this H2D sync and run masked_fill_
+    # if the value we are putting is a cpu scalar.
+    # Therefore, when inductor sees an index_put_ with byte tensor indices,
+    # it should *not* convert the cpu scalar value into a cuda tensor.
+    args_, kwargs_ = normalize_function(node.target, node.args, node.kwargs)
+    any_byte_bool_indices = False
+    indices = args_[1]
+    for i in indices:
+        if i is not None and i.meta["val"].dtype in [torch.bool, torch.int8]:
+            any_byte_bool_indices = True
+
+    val = args_[2].meta["val"]
+    val_is_cpu_scalar = val.device.type == "cpu" and val.numel() == 1
+    # If both these conditions hold, then converting the val
+    # to a cuda tensor will incur an H2D sync when inductor calls aten.index_put_
+    return any_byte_bool_indices and val_is_cpu_scalar
+
+
 class ConstructorMoverPass:
     def __init__(self, target: str, allow_outputs: bool = False) -> None:
         """
@@ -1068,6 +1097,8 @@ def cannot_be_moved(self, node: fx.Node) -> bool:
             and node.target.namespace in ("prims", "aten")
         ):
             return True
+        if is_index_put_and_requires_h2d_sync_for_cuda_value(node):
+            return True
 
         return False
 
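
For reference (not part of the diff), a minimal sketch of how normalize_function lines arguments up with aten::index_put_'s schema, which is why the new helper can read the indices from args_[1] and the value from args_[2]; plain tensors are used here instead of the fx.Node inputs the pass actually sees:

import torch
from torch.fx.operator_schemas import normalize_function

x = torch.zeros(4)
mask = torch.tensor([True, False, True, False])
value = torch.tensor(-1.0)  # 0-dim cpu tensor, i.e. the "cpu scalar" case

# normalize_function binds the call against the op's schema, so the indices
# and the value land at fixed positions no matter how the call was spelled.
args_, kwargs_ = normalize_function(
    torch.ops.aten.index_put_.default, (x, [mask], value), {}
)
print(args_[1])  # the list of index tensors (here, a single bool mask)
print(args_[2].device, args_[2].numel())  # cpu 1 -> the value is a cpu scalar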
