
Commit 795f28a

Chillee authored and pytorchmergebot committed
Ensure that BlockMask length must always exactly match the sequence length in flex_attention (pytorch#141625)

Fixes pytorch#141435
Pull Request resolved: pytorch#141625
Approved by: https://github.com/drisspg
ghstack dependencies: pytorch#138788
1 parent 8eb259f commit 795f28a
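
To illustrate the contract this commit enforces, here is a minimal sketch (not code from the PR; the causal mask_mod and all shapes are illustrative, and a CUDA device is assumed as in the PR's tests):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

B, H, S, D = 2, 4, 1024, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda") for _ in range(3))

# The mask now records seq_lengths=(Q_LEN, KV_LEN); these must match q/k exactly.
block_mask = create_block_mask(causal, B, 1, S, S)
out = flex_attention(q, k, v, block_mask=block_mask)

# Mismatched lengths no longer silently round up or auto-adjust; they raise.
try:
    flex_attention(q[:, :, :512], k, v, block_mask=block_mask)
except ValueError as e:
    print(e)  # "...created for a larger length than you're using it for..."
```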

File tree: 7 files changed, +171 −118 lines

test/inductor/test_flex_attention.py

Lines changed: 85 additions & 64 deletions
Large diffs are not rendered by default.

test/inductor/test_flex_decoding.py

Lines changed: 7 additions & 7 deletions

@@ -579,9 +579,9 @@ def run_test_with_call_paged_attention(
         ref_out = golden_call(q_ref, k_ref, v_ref)
 
         if mask_mod is not None:
-            block_mask = create_block_mask(mask_mod, Q_B, 1, 1, S)
+            block_mask = create_block_mask(mask_mod, Q_B, 1, Q_S, KV_S)
         else:
-            block_mask = create_block_mask(noop_mask, Q_B, 1, 1, S)
+            block_mask = create_block_mask(noop_mask, Q_B, 1, Q_S, KV_S)
 
         compiled_out, _ = self.run_paged_attention(
             score_mod, q, k, v, dtype, block_mask
@@ -682,7 +682,7 @@ def test_builtin_score_mods_different_block_size(
         score_mod: Callable,
         BLOCK_SIZE: Union[int, Tuple[int, int]],
     ):
-        block_mask = create_block_mask(noop_mask, B, 1, S, S, BLOCK_SIZE=BLOCK_SIZE)
+        block_mask = create_block_mask(noop_mask, B, 1, 1, S, BLOCK_SIZE=BLOCK_SIZE)
         self.run_test(score_mod, dtype, block_mask=block_mask)
 
 def input_strides_1(B, H, S, D):
@@ -1098,7 +1098,7 @@ def scoremod_1(qk, b, h, q, kv):
         def scoremod_2(qk, b, h, q, kv):
             return torch.where(q >= kv, qk, -float("inf"))
 
-        block_mask = create_block_mask(noop_mask, 1, 1, 1, S)
+        block_mask = create_block_mask(noop_mask, 1, 1, 4, 1024)
 
         def f(q, k1, k2, v1, v2):
             q2 = flex_attention(q, k1, v1, score_mod=scoremod_1, block_mask=block_mask)
@@ -1167,7 +1167,7 @@ def scoremod_1(qk, b, h, q, kv):
         def scoremod_2(qk, b, h, q, kv):
             return torch.where(q >= kv, qk, -float("inf"))
 
-        block_mask = create_block_mask(noop_mask, 1, 1, 1, S)
+        block_mask = create_block_mask(noop_mask, 1, 1, 4, 1024)
 
         attention1 = functools.partial(
             flex_attention, score_mod=scoremod_1, block_mask=block_mask
@@ -1567,8 +1567,8 @@ def mask_mod(b, h, q_idx, kv_idx):
             mask_mod=mask_mod,
             B=2,
             H=None,
-            Q_LEN=128,
-            KV_LEN=256,
+            Q_LEN=2,
+            KV_LEN=2,
             device="cuda",
         )
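
The pattern these test changes enforce, as a hedged sketch (Q_B, Q_S, and KV_S mirror the tests' variable names; the values and the noop mask_mod are illustrative):

```python
from torch.nn.attention.flex_attention import create_block_mask

def noop(b, h, q_idx, kv_idx):
    return q_idx >= 0  # stand-in for the tests' noop_mask: keep every position

Q_B, Q_S, KV_S = 4, 1, 2048  # decoding: one query token over a long KV cache
block_mask = create_block_mask(noop, Q_B, 1, Q_S, KV_S, device="cpu")
print(block_mask.shape)  # (4, 1, 1, 2048): the exact lengths, not (S, S)
```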

torch/_higher_order_ops/flex_attention.py

Lines changed: 10 additions & 3 deletions

@@ -614,7 +614,7 @@ def forward(
             value,
             out,
             logsumexp,
-            *block_mask[:10],
+            *block_mask[:-1],
             *score_mod_other_buffers,
             *mask_mod_other_buffers,
         ),
@@ -630,6 +630,8 @@ def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> Tuple[Option
             value,
             out,
             logsumexp,
+            query_lengths,
+            kv_lengths,
             kv_num_blocks,
             kv_indices,
             full_kv_num_blocks,
@@ -672,6 +674,8 @@ def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> Tuple[Option
             fw_graph,
             joint_graph,
             (
+                query_lengths,
+                kv_lengths,
                 kv_num_blocks,
                 kv_indices,
                 full_kv_num_blocks,
@@ -708,7 +712,8 @@ def flex_attention_autograd(
 
     with TransformGetItemToIndex():
         input_requires_grad = any(
-            t.requires_grad for t in (query, key, value, *score_mod_other_buffers)
+            isinstance(t, torch.Tensor) and t.requires_grad
+            for t in (query, key, value, *score_mod_other_buffers)
        )
        if torch.is_grad_enabled() and input_requires_grad:
            example_vals = (
@@ -1130,7 +1135,9 @@ def flex_attention_backward_fake_tensor_mode(
     grad_value = torch.empty_like(value)
     grad_score_mod_captured = tuple(
         [
-            torch.empty_like(buffer) if buffer.requires_grad else None
+            torch.empty_like(buffer)
+            if isinstance(buffer, torch.Tensor) and buffer.requires_grad
+            else None
             for buffer in score_mod_other_buffers
         ]
     )
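
The requires_grad changes above exist because a score_mod can close over non-tensor buffers, which have no .requires_grad attribute. A standalone sketch of the guarded check (the helper name is ours, not the PR's):

```python
import torch

def any_input_requires_grad(query, key, value, score_mod_other_buffers):
    # Non-tensor captures (e.g. a plain Python float baked into score_mod)
    # would raise AttributeError on .requires_grad, hence the isinstance guard.
    return any(
        isinstance(t, torch.Tensor) and t.requires_grad
        for t in (query, key, value, *score_mod_other_buffers)
    )

q = torch.randn(2, 4, 8, requires_grad=True)
k, v = torch.randn_like(q), torch.randn_like(q)
print(any_input_requires_grad(q, k, v, (1.0, torch.zeros(4))))  # True
```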

torch/_inductor/kernel/flex_attention.py

Lines changed: 11 additions & 10 deletions

@@ -810,6 +810,8 @@ def flex_attention(
     mask_mod_other_buffers,
 ):
     (
+        _,  # q_length
+        _,  # kv_length
         kv_num_blocks,
         kv_indices,
         full_kv_num_blocks,
@@ -968,12 +970,6 @@ def flex_attention(
     # Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards.
     SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE)
     SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE)
-    assert V.graph.sizevars.evaluate_expr(
-        sympy.Le(seq_len_q, sympy.Mul(kv_indices.get_size()[-2], SPARSE_Q_BLOCK_SIZE))
-    ), "Q seqlen must be smaller than the block_mask size in the Q dimension, considering pass a larger block_mask."
-    assert V.graph.sizevars.evaluate_expr(
-        sympy.Le(seq_len_kv, sympy.Mul(kv_indices.get_size()[-1], SPARSE_KV_BLOCK_SIZE))
-    ), "KV seqlen must be smaller than the block_mask size in the KV dimension, considering pass a larger block_mask."
 
     # Note, we don't need to pass in the captured buffers explicitly
     # because they're implicitly added by the score_mod function
@@ -1509,7 +1505,7 @@ def bwd_dq_block_mn(
     ) | indent_except_first(2) }}
 
     if CHECK_BLOCK_BOUNDARY:
-        mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, float("-inf"))
+        mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False)
     # apply mask for partial masked block
     post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1541,7 +1537,7 @@ def bwd_dq_block_mn(
 
     if not IS_FULL_BLOCKS:
         if CHECK_BLOCK_BOUNDARY:
-            mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, float("-inf"))
+            mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False)
         # (grads) apply mask for partially unmasked block
         ds = tl.where(mask_mod_output, ds, 0.0)
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1691,7 +1687,7 @@ def bwd_dkdv_block_mn(
         n="n",
     ) | indent_except_first(2) }}
     if CHECK_BLOCK_BOUNDARY:
-        mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, float("-inf"))
+        mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False)
     # (grads) apply mask for fully masked block
     post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1749,7 +1745,7 @@ def bwd_dkdv_block_mn(
     dsT = grad_scores
     if not IS_FULL_BLOCKS:
         if CHECK_BLOCK_BOUNDARY:
-            mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, float("-inf"))
+            mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False)
         # (grads) apply mask for partially unmasked block
         dsT = tl.where(mask_mod_output, dsT, 0.0)
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1860,6 +1856,8 @@ def flex_attention_backward(*args, **kwargs):
         mask_mod_other_buffers,
     ) = args
     (
+        _,  # q_length
+        _,  # kv_length
         kv_num_blocks,
         kv_indices,
         full_kv_num_blocks,
@@ -2036,6 +2034,9 @@ def flex_attention_backward(*args, **kwargs):
             or SPARSE_Q_BLOCK_SIZE % BLOCK2 != 0
         ):
             continue
+        if num_warps == 8:
+            # Working around https://github.com/pytorch/pytorch/issues/141603
+            continue
 
         # Performance tuning
         cur_kernel_options = original_kernel_options.copy()
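
Two separate fixes land in this file: the seqlen asserts move out of the lowering (the eager-side check in torch/nn/attention/flex_attention.py now owns that error), and the Triton templates stop padding the boolean mask_mod_output with a float. A PyTorch rendering of the corrected boundary logic (a sketch; the lengths are illustrative):

```python
import torch

KV_LEN = 1000                # true KV length
offs_n = torch.arange(1024)  # the tile overhangs KV_LEN by 24 columns
mask_mod_output = torch.ones(1024, dtype=torch.bool)  # pretend mask_mod result

# Pad the overhanging tail with boolean False ("masked"), not float("-inf"):
mask_mod_output = torch.where(offs_n < KV_LEN, mask_mod_output, False)

# -inf belongs one step later, when the boolean mask hits the float scores:
scores = torch.randn(1024)
post_mod_scores = torch.where(mask_mod_output, scores, float("-inf"))
print(post_mod_scores[-1])  # tensor(-inf): the out-of-range column is masked
```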

torch/_inductor/kernel/flex_decoding.py

Lines changed: 2 additions & 0 deletions

@@ -332,6 +332,8 @@ def create_flex_decoding_kernel(*args, **kwargs):
         mask_mod_other_buffers,
     ) = args
     (
+        _,  # q_length
+        _,  # kv_length
         kv_num_blocks,
         kv_indices,
         full_kv_num_blocks,  # full_kv_num_blocks,

torch/nn/attention/experimental/_paged_attention.py

Lines changed: 2 additions & 0 deletions

@@ -264,13 +264,15 @@ def convert_logical_block_mask(
 
         new_mask_mod = self.get_mask_mod(block_mask.mask_mod)
 
+        seq_lengths = (block_mask.seq_lengths[0], self.n_pages * self.page_size)
         return BlockMask.from_kv_blocks(
             new_kv_num_blocks,
             new_kv_indices,
             new_full_kv_num_blocks,
             new_full_kv_indices,
             block_mask.BLOCK_SIZE,
             new_mask_mod,
+            seq_lengths=seq_lengths,
         )
 
     def get_mask_mod(
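
After conversion, the physical mask keeps the logical query length but spans the whole page pool (n_pages * page_size) on the KV side. A hedged sketch of building such a mask directly via the new seq_lengths argument (the page geometry and block layout here are hypothetical):

```python
import torch
from torch.nn.attention.flex_attention import BlockMask

n_pages, page_size = 64, 128  # hypothetical page-table geometry
B, H, num_q_blocks = 1, 1, 16
num_kv_blocks = n_pages       # one physical KV block per page

kv_num_blocks = torch.full((B, H, num_q_blocks), num_kv_blocks, dtype=torch.int32)
kv_indices = (
    torch.arange(num_kv_blocks, dtype=torch.int32)
    .expand(B, H, num_q_blocks, num_kv_blocks)
    .contiguous()
)

mask = BlockMask.from_kv_blocks(
    kv_num_blocks,
    kv_indices,
    BLOCK_SIZE=page_size,
    seq_lengths=(2048, n_pages * page_size),  # logical Q len, physical KV len
)
print(mask.shape)  # (1, 1, 2048, 8192)
```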

torch/nn/attention/flex_attention.py

Lines changed: 54 additions & 34 deletions

@@ -262,6 +262,7 @@ class BlockMask:
         the backwards pass. These are autogenerated from 2.
     """
 
+    seq_lengths: Tuple[int, int]
     kv_num_blocks: Tensor
     kv_indices: Tensor
     full_kv_num_blocks: Optional[Tensor]
@@ -275,6 +276,7 @@ class BlockMask:
 
     def __init__(
         self,
+        seq_lengths: Tuple[int, int],
         kv_num_blocks: Tensor,
         kv_indices: Tensor,
         full_kv_num_blocks: Optional[Tensor],
@@ -299,6 +301,7 @@ def __init__(
             full_q_indices is None
         ), "full_q_num_blocks and full_q_indices must be both provided or omitted"
 
+        self.seq_lengths = seq_lengths
         self.kv_num_blocks = kv_num_blocks
         self.kv_indices = kv_indices
         self.full_kv_num_blocks = full_kv_num_blocks
@@ -319,6 +322,7 @@ def from_kv_blocks(
         full_kv_indices: Optional[Tensor] = None,
         BLOCK_SIZE: Union[int, Tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE,
         mask_mod: Optional[_mask_mod_signature] = None,
+        seq_lengths: Optional[Tuple[int, int]] = None,
     ):
         """
         Creates a BlockMask instance from key-value block information.
@@ -359,8 +363,13 @@ def from_kv_blocks(
             BLOCK_SIZE = (BLOCK_SIZE, BLOCK_SIZE)
 
         mask_mod = mask_mod if mask_mod is not None else noop_mask
+        if seq_lengths is None:
+            q_length = kv_indices.shape[-2] * BLOCK_SIZE[0]
+            kv_length = q_indices.shape[-2] * BLOCK_SIZE[1]
+            seq_lengths = (q_length, kv_length)
 
         return cls(
+            seq_lengths=seq_lengths,
             kv_num_blocks=kv_num_blocks,
             kv_indices=kv_indices,
             full_kv_num_blocks=full_kv_num_blocks,
@@ -380,11 +389,15 @@ def as_tuple(self, flatten: bool = True):
         Args:
             flatten (bool): If True, it will flatten the tuple of (KV_BLOCK_SIZE, Q_BLOCK_SIZE)
         """
-        block_size = (
-            (self.BLOCK_SIZE[0], self.BLOCK_SIZE[1]) if flatten else (self.BLOCK_SIZE,)
-        )
+        if flatten:
+            block_size = (self.BLOCK_SIZE[0], self.BLOCK_SIZE[1])  # type: ignore[assignment]
+            seq_lengths = (self.seq_lengths[0], self.seq_lengths[1])  # type: ignore[assignment]
+        else:
+            block_size = (self.BLOCK_SIZE,)  # type: ignore[assignment]
+            seq_lengths = (self.seq_lengths,)  # type: ignore[assignment]
 
         return (
+            *seq_lengths,
             self.kv_num_blocks,
             self.kv_indices,
             self.full_kv_num_blocks,
@@ -397,6 +410,11 @@ def as_tuple(self, flatten: bool = True):
             self.mask_mod,
         )
 
+    @property
+    def shape(self):
+        *batch_dims, _, _ = self.kv_indices.shape
+        return tuple(batch_dims) + self.seq_lengths
+
     def __str__(self):
         s = f"BlockMask(shape={self.shape}, sparsity={self.sparsity():.2f}%, \n"
         mask_str = self.to_string().strip()
@@ -457,6 +475,7 @@ def causal_mask(b, h, q_idx, kv_idx):
             new_full_kv_indices,
             BLOCK_SIZE=self.BLOCK_SIZE,
             mask_mod=None,
+            seq_lengths=self.seq_lengths,
         )
 
     def __repr__(self):
@@ -509,14 +528,6 @@ def _adjust(self, new_q_len: int, new_kv_len: int):
             self.mask_mod,
         )
 
-    @property
-    def shape(self):
-        """Returns the shape of the mask."""
-        *batch_dims, q_length, _ = self.kv_indices.shape
-        q_length = self.kv_indices.shape[-2] * self.BLOCK_SIZE[0]
-        kv_length = self.kv_indices.shape[-1] * self.BLOCK_SIZE[1]
-        return tuple(batch_dims + [q_length, kv_length])
-
     def numel(self):
         """Returns the number of elements (not accounting for sparsity) in the mask."""
         shape = self.shape
@@ -739,6 +750,7 @@ def _convert_block_mask_to_mask(
 def _create_sparse_block_from_block_mask(
     block_mask: Tuple[Tensor, Optional[Tensor]],
     mask_mod: Optional[Callable],
+    seq_lengths: Tuple[int, int],
     Q_BLOCK_SIZE: int = _DEFAULT_SPARSE_BLOCK_SIZE,
     KV_BLOCK_SIZE: int = _DEFAULT_SPARSE_BLOCK_SIZE,
 ) -> BlockMask:
@@ -757,6 +769,7 @@
         full_bm[1],
         BLOCK_SIZE=(Q_BLOCK_SIZE, KV_BLOCK_SIZE),
         mask_mod=mask_mod,
+        seq_lengths=seq_lengths,
     )
 
 
@@ -878,7 +891,11 @@ def causal_mask(b, h, q_idx, kv_idx):
         separate_full_blocks=True,
     )
     block_mask = _create_sparse_block_from_block_mask(
-        (partial_block_mask, full_block_mask), mask_mod, Q_BLOCK_SIZE, KV_BLOCK_SIZE
+        (partial_block_mask, full_block_mask),
+        mask_mod,
+        (Q_LEN, KV_LEN),
+        Q_BLOCK_SIZE,
+        KV_BLOCK_SIZE,
     )
     return block_mask
 
@@ -894,6 +911,7 @@ def _create_empty_block_mask(query: Tensor, key: Tensor) -> BlockMask:
         kv_num_blocks=torch.ones([1, 1, 1], dtype=torch.int32, device=device),
         kv_indices=torch.zeros([1, 1, 1, 1], dtype=torch.int32, device=device),
         BLOCK_SIZE=_LARGE_SPARSE_BLOCK_SIZE,
+        seq_lengths=(1, 1),
     )
 
 
@@ -1237,29 +1255,31 @@ def score_mod(
 
     if block_mask is None:
         block_mask = _create_empty_block_mask(query, key)
-    elif (
-        not query.is_nested
-        and (query.requires_grad or key.requires_grad or value.requires_grad)
-        and (
-            query.size(-2)
-            < block_mask.kv_num_blocks.size(-1) * block_mask.BLOCK_SIZE[0]
-            or key.size(-2) < block_mask.kv_indices.size(-1) * block_mask.BLOCK_SIZE[1]
-        )
-    ):
-        new_q_len = _round_up_to_multiple(query.size(-2), block_mask.BLOCK_SIZE[0])
-        new_kv_len = _round_up_to_multiple(key.size(-2), block_mask.BLOCK_SIZE[1])
-        block_mask = block_mask._adjust(new_q_len, new_kv_len)
-    elif query.is_nested and (
-        block_mask.kv_num_blocks.size(-1) * block_mask.BLOCK_SIZE[0]
-        != _round_up_to_multiple(
-            query._values.size(query._ragged_idx - 1), block_mask.BLOCK_SIZE[0]  # type: ignore[attr-defined]
-        )
+
+    if (
+        block_mask.BLOCK_SIZE[0] == _LARGE_SPARSE_BLOCK_SIZE
+        and block_mask.BLOCK_SIZE[1] == _LARGE_SPARSE_BLOCK_SIZE
     ):
-        # TODO: Maybe we want to auto-adjust for this case as well?
-        raise RuntimeError(
-            f"block_mask of shape {block_mask.shape} is not compatible with nested tensor input "
-            f"with total sequence length of {query._values.size(query._ragged_idx - 1)}"  # type: ignore[attr-defined]
-        )
+        # This corresponds to the case where we essentially have a "no-op" block mask.
+        pass
+    else:
+        block_mask_q_len = block_mask.shape[-2]
+        block_mask_kv_len = block_mask.shape[-1]
+        if query.size(-2) > block_mask_q_len or key.size(-2) > block_mask_kv_len:
+            raise ValueError(
+                f"block_mask was created for block_mask.shape={block_mask.shape} but got q_len={query.size(-2)} and kv_len={key.size(-2)}. "
+                "As the block mask was created for a smaller length than you're using it for, you likely need to create a new block mask."
+            )
+        elif (
+            query.size(-2) < block_mask_q_len and key.size(-2) <= block_mask_kv_len
+        ) or (query.size(-2) <= block_mask_q_len and key.size(-2) < block_mask_kv_len):
+            raise ValueError(
+                f"block_mask was created for block_mask.shape={block_mask.shape} but got q_len={query.size(-2)} and kv_len={key.size(-2)}. "
+                "As the block mask was created for a larger length than you're using it for, you can either 1. create a new block mask with the correct length, or 2. 'adjust' the existing block mask to the correct length by calling block_mask._adjust(q_len, kv_len). This essentially 'crops' the block mask to the upper left corner, which does not work for all mask_mods!"
+            )
+        assert query.size(-2) == block_mask_q_len
+        assert key.size(-2) == block_mask_kv_len
+
     if scale is None:
         scale = 1.0 / math.sqrt(query.size(-1))
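
Net effect of the flex_attention.py changes: BlockMask.shape now comes from the stored seq_lengths rather than from rounded-up block counts, and reusing a mask at a different length is an explicit, caveated operation. A quick sketch (the lengths are illustrative):

```python
from torch.nn.attention.flex_attention import create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# 1000 is not a multiple of the default 128 sparse block size.
mask = create_block_mask(causal, 1, 1, 1000, 1000, device="cpu")
print(mask.shape)                  # (1, 1, 1000, 1000): exact, no rounding up
print(mask.kv_indices.shape[-2:])  # torch.Size([8, 8]): blocks still round up

# Using this mask at any other q_len/kv_len now raises ValueError; the error
# points at block_mask._adjust(q_len, kv_len) as the explicit cropping opt-in,
# with the caveat that cropping is not valid for every mask_mod.
```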
