
Commit 002accf

eellison authored and pytorchmergebot committed
Check meta strides for expanded dims in effn_attn_bias (pytorch#146054)
With `_scaled_dot_product_efficient_attention.default`, we have lowering logic that realizes the bias to specific alignment constraints. Some of its dims can be expanded, and we need to keep the stride of those dims at 0 to avoid materializing a larger tensor than we need. Previously we checked the strides of the tensor, but if it is not yet realized that does not work, so we should check the strides of the meta value as well.

Note: getting the exact sequence of realizing, slicing, and require_exact_strides right was a little tricky. I commented to @exclamaforte with an example of the unable-to-fuse message you get if you do it incorrectly.

Fix for pytorch#145760

Pull Request resolved: pytorch#146054
Approved by: https://github.com/shunting314
1 parent 71e8a2b commit 002accf
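For context, the memory concern is easy to see in eager PyTorch: a bias or mask that is broadcast over a dimension is just a zero-stride view, and realizing it without preserving that zero stride would materialize one copy per broadcast element. A minimal sketch (shapes borrowed from the regression test below; this is plain PyTorch, not the Inductor lowering):

    import torch

    bsz, heads, seqlen = 32, 16, 1009
    mask = torch.ones(bsz, 1, seqlen, seqlen, dtype=torch.bool)

    # Broadcasting over the head dim yields a view with stride 0 in that dim,
    # so the logically larger bias costs no extra memory.
    bias = mask.expand(bsz, heads, seqlen, seqlen)
    print(bias.stride())  # (1018081, 0, 1009, 1)

    # If the lowering loses track of the zero stride while padding the bias to
    # the kernel's alignment, the realized buffer becomes `heads` times larger
    # than necessary.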

5 files changed: 65 additions & 8 deletions

test/inductor/test_cuda_repro.py

Lines changed: 31 additions & 0 deletions
@@ -151,6 +151,37 @@ def fn(
         # dont check rng state
         self.assertEqual(out[:2], fn(query, key, value, input_tensor2)[:2])

+    def test_effn_attn_bias_padding_misaligned(self):
+        seqlen_start = 1008
+
+        for offset in range(-1, 2):
+            seqlen = seqlen_start + offset
+            torch._dynamo.reset()
+
+            bsz = 32
+            q = torch.randn(bsz, 16, seqlen, 64, dtype=torch.bfloat16, device="cuda")
+            k = torch.randn(bsz, 16, seqlen, 64, dtype=torch.bfloat16, device="cuda")
+            v = torch.randn(bsz, 16, seqlen, 64, dtype=torch.bfloat16, device="cuda")
+            mask = torch.ones([bsz, 1, seqlen, seqlen], dtype=torch.bool, device="cuda")
+            inputs = [q, k, v, mask]
+
+            def f(q, k, v, mask):
+                return F.scaled_dot_product_attention(
+                    q, k, v, attn_mask=mask, dropout_p=0.0
+                )
+
+            f_compiled = torch.compile(f)
+
+            out, code = run_and_get_code(f_compiled, *inputs)
+            # padded bias should have an expanded dim
+            FileCheck().check("buf0 =").check_same(", 0, ").run(code[0])
+            # single fused padded kernel
+            FileCheck().check("def call").check_count(
+                "empty_strided_cuda", 1, exactly=True
+            ).check("return").run(code[0])
+
+            self.assertEqual(out, f(*inputs))
+
     @skipIfRocm
     def test_input_channels_last(self):
         m = torch.nn.Sequential(

test/inductor_expected_failures/TestCommonCPU.test_out__refs_bitwise_not_cpu_int64

Whitespace-only changes.

test/inductor_expected_failures/TestCommonCUDA.test_out__refs_bitwise_not_cuda_int64

Whitespace-only changes.

torch/_inductor/ir.py

Lines changed: 2 additions & 1 deletion
@@ -5145,7 +5145,8 @@ def require_strides(  # type: ignore[no-untyped-def]
         allow_padding=False,
     ):
         assert order is not None or exact_strides is not None
-        if x.get_numel() in (0, 1):  # Layout doesn't matter
+        # Layout generally doesn't matter, but some consuming external ops might have requirements
+        if x.get_numel() in (0, 1) and not exact_strides:
             return x

         # require x to have the layout
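The `and not exact_strides` guard matters because layout is still meaningful for tensors with 0 or 1 elements once a consuming extern kernel demands exact strides; returning such a tensor unchanged (the old early exit) would silently ignore that request. A small illustration in plain PyTorch, not the Inductor code path:

    import torch

    # One stored element, yet the recorded strides are whatever the consumer
    # requires; exact-stride requests need to be honored even for numel-1 tensors.
    t = torch.empty_strided((1, 1, 1, 1), (8, 8, 8, 1), dtype=torch.bfloat16)
    print(t.numel(), t.stride())  # 1 (8, 8, 8, 1)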

torch/_inductor/lowering.py

Lines changed: 32 additions & 7 deletions
@@ -2490,12 +2490,15 @@ def apply_constraint(idx, arg, fx_arg):
             out_size = list(arg.get_size())

             expanded_dims = []
-            if arg.maybe_get_stride() is not None:
-                # We require a dense last dimension, but the other strides
-                # can be expanded, which results in a smaller tensor
-                for i, s in enumerate(arg.get_stride()[0:-1]):
-                    if V.graph.sizevars.statically_known_equals(s, 0):
-                        expanded_dims.append(i)
+            # We require a dense last dimension, but the other strides
+            # can be expanded, which results in a smaller tensor
+            maybe_stride = arg.maybe_get_stride()
+            for i in range(len(arg.get_size()) - 1):
+                if V.graph.sizevars.statically_known_equals(meta_stride_expr[i], 0) or (
+                    maybe_stride is not None
+                    and V.graph.sizevars.statically_known_equals(maybe_stride[i], 0)
+                ):
+                    expanded_dims.append(i)

             # Now, pad strides to alignment
             out_strides = [-1] * len(out_size)
@@ -2518,7 +2521,29 @@ def apply_constraint(idx, arg, fx_arg):
                     stride = ceildiv(stride, ALIGNMENT) * ALIGNMENT

                 out_strides[i] = stride
-            return ir.ExternKernel.require_exact_strides(arg, out_strides)
+
+            for dim in expanded_dims:
+                arg = slice_(arg, dim, 0, 1)
+
+            # TODO this is too subtle to get right in lowering, should be handled in match_exact_strides
+            out = ir.ExternKernel.require_exact_strides(arg, out_strides)
+            out = expand(TensorBox(out), out_size)
+            out = ir.try_match_insignificant_strides(out, out_strides)
+            return out
+
+        if ir.is_aligned_realized_tensor(arg, ALIGNMENT):
+            return ir.try_match_insignificant_strides(
+                ir.ExternKernel.realize_input(arg), meta_stride_expr
+            )
+
+        if (
+            isinstance(arg, IRNode)
+            and arg.maybe_get_stride() is not None
+            and ir.is_aligned_realized_tensor(arg, ALIGNMENT)
+        ):
+            return ir.try_match_insignificant_strides(
+                ir.ExternKernel.realize_input(arg), meta_stride_expr
+            )

         def is_aligned(x):
             return (V.graph.sizevars.size_hint(x.get_size()[-1]) % ALIGNMENT) == 0
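The slice / require_exact_strides / expand / try_match_insignificant_strides sequence is easier to follow in an eager-mode analogue. The sketch below is illustrative only: `ALIGNMENT = 8` and the helper name are assumptions, and the real lowering works on Inductor IR rather than eager tensors.

    import torch

    ALIGNMENT = 8  # assumed alignment for illustration

    def ceildiv(a, b):
        return -(a // -b)

    def pad_bias_sketch(mask, heads):
        # mask is [bsz, 1, s_q, s_k]; dim 1 is the expanded (stride-0) dim.
        bsz, one, s_q, s_k = mask.shape
        padded_k = ceildiv(s_k, ALIGNMENT) * ALIGNMENT

        # "slice": keep the expanded dim at size 1 so only the small tensor is
        # realized, and copy it into a buffer whose row stride meets the alignment.
        buf = mask.new_zeros(bsz, one, s_q, padded_k)
        buf[..., :s_k] = mask

        # "expand" back to the logical size: the head dim keeps stride 0, so the
        # padded buffer is all that ever gets materialized.
        return buf[..., :s_k].expand(bsz, heads, s_q, s_k)

    bias = pad_bias_sketch(torch.ones(2, 1, 1009, 1009, dtype=torch.bool), heads=16)
    print(bias.stride())  # head dim stride is 0; row stride 1016 is a multiple of 8

Getting that ordering wrong in the lowering is what produces the unable-to-fuse message mentioned in the commit description.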

0 commit comments