
Commit 3f6ca86

drisspg authored and xuhancn committed
Fix IMAs in FlexAttention + autotuning (pytorch#130352)
# Summary

Makes the error message better for non-divisible sequence lengths.

Update: this PR was blocked by two IMAs (illegal memory accesses).

- The first: when the kv indices end up being an 'arange' (i.e. there are no sparse blocks and every KV block is present), we end up loading off the end of the table at `kv_indices + indices_idx + 1` (a minimal sketch of this off-by-one follows the commit message).
- The second I don't really have a clear answer for. We were hitting an IMA here: https://github.com/pytorch/pytorch/blob/9f401187c708d0f11122f0409254ddfc76befaf9/torch/_inductor/kernel/flex_attention.py#L846
  I noticed that for our inputs (2048 with q_blocksize = 128) we were again at exactly 16 blocks. Something felt fishy. I suspect we launch one extra sparse_q block, but why only during autotuning...

### Repro
https://gist.github.com/drisspg/f312a66426f3440b7756c6c0cc037f4c

### After this change
```
========= COMPUTE-SANITIZER
AUTOTUNE flex_attention(1x1x2048x64, 1x1x2048x64, 1x1x2048x64, 1x1x2048, 1x1x16, 1x1x16x16)
  triton_flex_attention_0 2.1118 ms 100.0% BLOCK_DMODEL=64, BLOCK_M=128, BLOCK_N=128, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, ROWS_GUARANTEED_SAFE=False, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, num_stages=3, num_warps=4
  triton_flex_attention_3 2.4306 ms 86.9% BLOCK_DMODEL=64, BLOCK_M=64, BLOCK_N=128, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, ROWS_GUARANTEED_SAFE=False, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, num_stages=3, num_warps=4
  triton_flex_attention_1 2.5729 ms 82.1% BLOCK_DMODEL=64, BLOCK_M=128, BLOCK_N=64, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, ROWS_GUARANTEED_SAFE=False, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, num_stages=3, num_warps=4
  triton_flex_attention_4 2.8035 ms 75.3% BLOCK_DMODEL=64, BLOCK_M=64, BLOCK_N=64, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, ROWS_GUARANTEED_SAFE=False, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, num_stages=3, num_warps=4
  triton_flex_attention_2 2.8837 ms 73.2% BLOCK_DMODEL=64, BLOCK_M=128, BLOCK_N=128, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, ROWS_GUARANTEED_SAFE=False, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, num_stages=3, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.7225 seconds and 1.5218 seconds precompiling
AUTOTUNE flex_attention_backward(1x1x2048x64, 1x1x2048x64, 1x1x2048x64, 1x1x2048, 1x1x2048, 1x1x2048x64, 1x1x2048x64, 1x1x2048x64, 1x1x16, 1x1x16x16, 1x1x16, 1x1x16x16)
  triton_flex_attention_backward_30 2.7763 ms 100.0% BLOCK_DMODEL=64, BLOCK_M1=64, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=64, PRESCALE_QK=False, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, num_stages=1, num_warps=4
  triton_flex_attention_backward_15 3.1404 ms 88.4% BLOCK_DMODEL=64, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, PRESCALE_QK=False, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, num_stages=3, num_warps=4
  triton_flex_attention_backward_14 3.2604 ms 85.2% BLOCK_DMODEL=64, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, PRESCALE_QK=False, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, num_stages=1, num_warps=4
  triton_flex_attention_backward_7 3.4176 ms 81.2% BLOCK_DMODEL=64, BLOCK_M1=32, BLOCK_M2=32, BLOCK_N1=32, BLOCK_N2=32, PRESCALE_QK=False, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, num_stages=3, num_warps=4
  triton_flex_attention_backward_8 3.4182 ms 81.2% BLOCK_DMODEL=64, BLOCK_M1=32, BLOCK_M2=32, BLOCK_N1=32, BLOCK_N2=32, PRESCALE_QK=False, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, num_stages=4, num_warps=4
  triton_flex_attention_backward_34 3.4939 ms 79.5% BLOCK_DMODEL=64, BLOCK_M1=64, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=64, PRESCALE_QK=False, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, num_stages=1, num_warps=8
  triton_flex_attention_backward_6 3.6517 ms 76.0% BLOCK_DMODEL=64, BLOCK_M1=32, BLOCK_M2=32, BLOCK_N1=32, BLOCK_N2=32, PRESCALE_QK=False, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, num_stages=1, num_warps=4
  triton_flex_attention_backward_26 3.7000 ms 75.0% BLOCK_DMODEL=64, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, PRESCALE_QK=False, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, num_stages=1, num_warps=8
  triton_flex_attention_backward_22 4.0120 ms 69.2% BLOCK_DMODEL=64, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, PRESCALE_QK=False, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, num_stages=1, num_warps=4
  triton_flex_attention_backward_18 4.5052 ms 61.6% BLOCK_DMODEL=64, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, PRESCALE_QK=False, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, num_stages=1, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 6.6558 seconds and 6.3567 seconds precompiling
torch.Size([1, 1, 2048, 64])
Test completed successfully!
========= ERROR SUMMARY: 0 errors
```

Pull Request resolved: pytorch#130352
Approved by: https://github.com/Skylion007, https://github.com/Chillee
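To make the first IMA concrete, here is a minimal pure-Python sketch of the off-by-one (illustrative only; `seq_len`, `num_kv_blocks`, and the loop are stand-ins, not the Triton kernel): with a sequence length of 2048 and a sparse KV block size of 128 there are exactly 16 block-table entries, so a dense 'arange' table makes the `indices_idx + 1` lookup fall one past the end on the last iteration, which is what the masked load in the diff below guards against.

```python
# Illustrative sketch of the off-by-one, not the actual kernel code.
seq_len = 2048
sparse_kv_block_size = 128
num_kv_blocks = seq_len // sparse_kv_block_size   # exactly 16 entries
kv_indices = list(range(num_kv_blocks))           # dense "arange": no blocks skipped

for indices_idx in range(num_kv_blocks):
    cur_block = kv_indices[indices_idx]
    # Unguarded lookup: kv_indices[indices_idx + 1] is out of bounds when
    # indices_idx == num_kv_blocks - 1, the access compute-sanitizer flags.
    # Guarded lookup, mirroring `mask=indices_idx + 1 < sparse_kv_num_blocks`
    # in the fix below:
    in_bounds = indices_idx + 1 < num_kv_blocks
    next_block = kv_indices[indices_idx + 1] if in_bounds else cur_block
```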
1 parent 96aa07c commit 3f6ca86

3 files changed: 60 additions & 35 deletions

test/inductor/test_flex_attention.py

Lines changed: 34 additions & 11 deletions
@@ -1176,7 +1176,7 @@ def func(q, k, v, score_mod, block_mask):
     def test_aot_eager_gradcheck(self, score_mod):
         make_tensor = functools.partial(
             torch.randn,
-            (2, 2, 8, 4),
+            (2, 2, 128, 4),
             device="cuda",
             dtype=torch.float64,
             requires_grad=True,
@@ -1199,7 +1199,7 @@ def test_captured_score_mod_aot_eager_gradcheck(
     ):
         make_tensor = functools.partial(
             torch.randn,
-            (2, 2, 8, 4),
+            (2, 2, 128, 4),
             device="cuda",
             dtype=torch.float64,
             requires_grad=True,
@@ -1336,7 +1336,7 @@ def test_fw_bw_graph_correctness(self):
         cnt = CompileCounterWithBackend("aot_eager")
         make_tensor = functools.partial(
             torch.randn,
-            (2, 2, 8, 4),
+            (2, 2, 128, 4),
             device="cuda",
             dtype=torch.float64,
             requires_grad=True,
@@ -1355,7 +1355,7 @@ def test_fw_bw_graph_correctness(self):
             norm_graph,
             """\
 class GraphModule(torch.nn.Module):
-    def forward(self, L_args_0_: "f64[2, 2, 8, 4]", L_args_1_: "f64[2, 2, 8, 4]", L_args_2_: "f64[2, 2, 8, 4]"):
+    def forward(self, L_args_0_: "f64[2, 2, 128, 4]", L_args_1_: "f64[2, 2, 128, 4]", L_args_2_: "f64[2, 2, 128, 4]"):
         l_args_0_ = L_args_0_
         l_args_1_ = L_args_1_
         l_args_2_ = L_args_2_
@@ -1374,8 +1374,8 @@ def forward(self, L_args_0_: "f64[2, 2, 8, 4]", L_args_1_: "f64[2, 2, 8, 4]", L_
         child_3: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32)
         child_4: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32)
         flex_attention_0 = self.flex_attention_0
-        flex_attention = torch.ops.higher_order.flex_attention(l_args_0_, l_args_1_, l_args_2_, flex_attention_0, (ones, zeros, ones_1, zeros_1, 8, 8), 0.5); l_args_0_ = l_args_1_ = l_args_2_ = flex_attention_0 = ones = zeros = ones_1 = zeros_1 = None
-        out: "f64[2, 2, 8, 4]" = flex_attention[0]; flex_attention = None
+        flex_attention = torch.ops.higher_order.flex_attention(l_args_0_, l_args_1_, l_args_2_, flex_attention_0, (ones, zeros, ones_1, zeros_1, 128, 128), 0.5); l_args_0_ = l_args_1_ = l_args_2_ = flex_attention_0 = ones = zeros = ones_1 = zeros_1 = None
+        out: "f64[2, 2, 128, 4]" = flex_attention[0]; flex_attention = None
         return (out,)

 class GraphModule(torch.nn.Module):
@@ -1405,13 +1405,13 @@ def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs):
             joint_graph,
             """\
 class GraphModule(torch.nn.Module):
-    def forward(self, primals_1: "f64[2, 2, 8, 4]", primals_2: "f64[2, 2, 8, 4]", primals_3: "f64[2, 2, 8, 4]", full_default: "i32[1, 1, 1]", full_default_1: "i32[1, 1, 1, 1]", getitem: "f64[2, 2, 8, 4]", getitem_1: "f32[2, 2, 8]", tangents_1: "f64[2, 2, 8, 4]"):
+    def forward(self, primals_1: "f64[2, 2, 128, 4]", primals_2: "f64[2, 2, 128, 4]", primals_3: "f64[2, 2, 128, 4]", full_default: "i32[1, 1, 1]", full_default_1: "i32[1, 1, 1, 1]", getitem: "f64[2, 2, 128, 4]", getitem_1: "f32[2, 2, 128]", tangents_1: "f64[2, 2, 128, 4]"):
         fw_graph = self.fw_graph
         joint_graph = self.joint_graph
-        flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem, getitem_1, tangents_1, fw_graph, joint_graph, (full_default, full_default_1, full_default, full_default_1, 8, 8), 0.5); primals_1 = primals_2 = primals_3 = getitem = getitem_1 = tangents_1 = fw_graph = joint_graph = full_default = full_default_1 = None
-        getitem_2: "f64[2, 2, 8, 4]" = flex_attention_backward[0]
-        getitem_3: "f64[2, 2, 8, 4]" = flex_attention_backward[1]
-        getitem_4: "f64[2, 2, 8, 4]" = flex_attention_backward[2]; flex_attention_backward = None
+        flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem, getitem_1, tangents_1, fw_graph, joint_graph, (full_default, full_default_1, full_default, full_default_1, 128, 128), 0.5); primals_1 = primals_2 = primals_3 = getitem = getitem_1 = tangents_1 = fw_graph = joint_graph = full_default = full_default_1 = None
+        getitem_2: "f64[2, 2, 128, 4]" = flex_attention_backward[0]
+        getitem_3: "f64[2, 2, 128, 4]" = flex_attention_backward[1]
+        getitem_4: "f64[2, 2, 128, 4]" = flex_attention_backward[2]; flex_attention_backward = None
         return [getitem_2, getitem_3, getitem_4]

 class <lambda>(torch.nn.Module):
@@ -1429,6 +1429,29 @@ def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3
             """, # noqa: B950
         )

+    @supported_platform
+    def test_nyi_for_non_divisible_seq_lens(self):
+        with self.assertRaisesRegex(
+            NotImplementedError, "NYI: L must be a multiple of 128"
+        ):
+            flex_attention(
+                torch.randn((2, 3, 4)),
+                torch.randn((2, 10, 5)),
+                torch.randn((2, 10, 5)),
+                score_mod=_identity,
+            )
+
+        with self.assertRaisesRegex(
+            NotImplementedError, "NYI: L must be a multiple of 128"
+        ):
+            compiled_flex = torch.compile(flex_attention)
+            compiled_flex(
+                torch.randn((2, 3, 4)),
+                torch.randn((2, 10, 5)),
+                torch.randn((2, 10, 5)),
+                score_mod=_identity,
+            )
+

 common_utils.instantiate_parametrized_tests(TestFlexAttention)
torch/_inductor/kernel/flex_attention.py

Lines changed: 19 additions & 16 deletions
@@ -283,7 +283,7 @@ def convert_output_node_to_buffer(output):
         indices_idx = start_n // SPARSE_KV_MULTIPLE

         cur_block = tl.load(kv_indices + indices_idx, eviction_policy="evict_last")
-        next_block = tl.load(kv_indices + indices_idx + 1, eviction_policy="evict_last")
+        next_block = tl.load(kv_indices + indices_idx + 1, eviction_policy="evict_last", mask=indices_idx + 1 < sparse_kv_num_blocks)
         needs_jump = (start_n + 1) % SPARSE_KV_MULTIPLE == 0
         jump_to_block = (next_block - cur_block ) * SPARSE_KV_BLOCK_SIZE - (SPARSE_KV_MULTIPLE - 1) * BLOCK_N
@@ -554,8 +554,8 @@ def flex_attention(*args, **kwargs):
         sparse_kv_indices,
     ] + list(other_buffers)
     input_gen_fns = {
-        4: create_num_blocks_fake_generator(sparse_kv_indices),  # sparse_kv_num_blocks
-        5: create_indices_fake,  # sparse_kv_indices
+        4: create_num_blocks_fake_generator(sparse_kv_indices),
+        5: create_indices_fake,
     }
     return (
         autotune_select_algorithm(
@@ -779,7 +779,7 @@ def flex_attention_backward_grid(
         # Increment pointers.
         indices_idx = start_n // SPARSE_KV_MULTIPLE
         cur_block = tl.load(kv_indices + indices_idx)
-        next_block = tl.load(kv_indices + indices_idx + 1)
+        next_block = tl.load(kv_indices + indices_idx + 1, mask=indices_idx + 1 < sparse_kv_num_blocks)
         needs_jump = (start_n + 1) % SPARSE_KV_MULTIPLE == 0
         jump_to_block = (next_block - cur_block ) * SPARSE_KV_BLOCK_SIZE - (SPARSE_KV_MULTIPLE - 1) * BLOCK_N2
         offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK_N2
@@ -886,7 +886,7 @@ def flex_attention_backward_grid(
         # Increment pointers.
         indices_idx = start_m // SPARSE_Q_MULTIPLE
         cur_block = tl.load(q_indices + indices_idx)
-        next_block = tl.load(q_indices + indices_idx + 1)
+        next_block = tl.load(q_indices + indices_idx + 1, mask=indices_idx + 1 < sparse_q_num_blocks)
         needs_jump = (start_m + 1) % SPARSE_Q_MULTIPLE == 0
         jump_to_block = (next_block - cur_block ) * SPARSE_Q_BLOCK_SIZE - (SPARSE_Q_MULTIPLE - 1) * BLOCK_M1
         offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK_M1
@@ -999,13 +999,16 @@ def flex_attention_backward(*args, **kwargs):
     configs: List[Tuple[int, int, int, int]] = []
     configs.append(_get_default_config_bwd(query))
     if config.max_autotune:
-        for BLOCK1 in [32, 64]:
-            for BLOCK2 in [32, 64, 128]:
-                if BLOCK2 % BLOCK1 != 0:
-                    continue
-                for w in [4, 8]:
-                    for s in [1, 3, 4, 5]:
-                        configs.append((BLOCK1, BLOCK2, w, s))
+        configs.extend(
+            [
+                (BLOCK1, BLOCK2, w, s)
+                for BLOCK1 in [32, 64]
+                for BLOCK2 in [32, 64, 128]
+                for w in [4, 8]
+                for s in [1, 3, 4, 5]
+                if BLOCK2 % BLOCK1 == 0
+            ]
+        )

     for BLOCK1, BLOCK2, num_warps, num_stages in configs:
         if (
@@ -1066,10 +1069,10 @@ def flex_attention_backward(*args, **kwargs):
         sparse_q_indices,
     ] + list(other_buffers)
     input_gen_fns = {
-        9: create_num_blocks_fake_generator(sparse_kv_indices),  # sparse_kv_num_blocks
-        10: create_indices_fake,
-        11: create_num_blocks_fake_generator(sparse_q_indices),  # sparse_q_num_blocks
-        12: create_indices_fake,
+        8: create_num_blocks_fake_generator(sparse_kv_indices),
+        9: create_indices_fake,
+        10: create_num_blocks_fake_generator(sparse_q_indices),
+        11: create_indices_fake,
     }

     grad_key = autotune_select_algorithm(
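For reference, a minimal standalone Triton sketch of the masked-load pattern applied in the hunks above (the kernel and buffer names here are invented for illustration; only the `mask=` idiom mirrors the fix): the mask suppresses the read that would otherwise run one entry past the end of the block-index table.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _next_block_delta(kv_indices_ptr, out_ptr, num_blocks):
    # One program instance per block-table entry.
    idx = tl.program_id(0)
    cur_block = tl.load(kv_indices_ptr + idx)
    # Masked load: when idx + 1 == num_blocks the read is suppressed and
    # `other=0` is returned instead of touching memory past the table.
    next_block = tl.load(kv_indices_ptr + idx + 1, mask=idx + 1 < num_blocks, other=0)
    tl.store(out_ptr + idx, next_block - cur_block)


# 2048 tokens with a 128-wide sparse block give exactly 16 entries, the
# boundary case from the commit message where an unmasked load ran off the end.
num_blocks = 16
kv_indices = torch.arange(num_blocks, device="cuda", dtype=torch.int32)
deltas = torch.empty_like(kv_indices)
_next_block_delta[(num_blocks,)](kv_indices, deltas, num_blocks)
```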

torch/nn/attention/_flex_attention.py

Lines changed: 7 additions & 8 deletions
@@ -476,6 +476,13 @@ def score_mod(
     Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

     """
+    # Some basic input validation
+    _validate_sdpa_input(query, key, value)
+    if query.size(-2) >= 32:  # use Attention Kernel
+        if query.size(-2) >= 128 & query.size(-2) % 128 != 0:
+            raise NotImplementedError("NYI: S must be <128 or a multiple of 128")
+        if key.size(-2) % 128 != 0:
+            raise NotImplementedError("NYI: L must be a multiple of 128")

     if block_mask is None:
         block_mask = _create_empty_block_mask(query, key, value)
@@ -490,14 +497,6 @@ def score_mod(
         )
         return out

-    # Some basic input validation
-    _validate_sdpa_input(query, key, value)
-    if query.size(-2) >= 32:  # use Attention Kernel
-        if query.size(-2) >= 128 & query.size(-2) % 128 != 0:
-            raise ValueError("NYI: S must be <128 or a multiple of 128")
-        if key.size(-2) % 128 != 0:
-            raise ValueError("NYI: L must be a multiple of 128")
-
     if not torch._dynamo.is_dynamo_supported():
         raise RuntimeError("flex_attention requires dynamo support.")
