
Commit 2b43d33

drisspg authored and pytorchmergebot committed
Make FlexAttention API public (pytorch#130755)
# Summary

Makes the prototype API flex_attention public

Pull Request resolved: pytorch#130755
Approved by: https://github.com/Chillee
1 parent cbda8be commit 2b43d33
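After this change the module is importable without the leading underscore. Below is a minimal usage sketch of the now-public path, assuming flex_attention accepts a score_mod callable with the (score, batch, head, q_idx, kv_idx) signature documented later in this diff; the shapes, the eager CPU call, and the causal rule are illustrative assumptions, not taken verbatim from the PR:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention  # public path after this commit


def causal(score, batch, head, q_idx, kv_idx):
    # Hypothetical score_mod: keep the score where the query may attend to the key,
    # otherwise push it to -inf so softmax zeroes that position out.
    return torch.where(q_idx >= kv_idx, score, torch.full_like(score, float("-inf")))


q = torch.randn(2, 4, 256, 64)
k = torch.randn(2, 4, 256, 64)
v = torch.randn(2, 4, 256, 64)

# Eager call for illustration; in practice the op is usually wrapped in torch.compile
# to get the fused kernel.
out = flex_attention(q, k, v, score_mod=causal)
print(out.shape)  # torch.Size([2, 4, 256, 64])
```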

7 files changed, +98 -53 lines changed

docs/source/nn.attention.flex_attention.rst

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+.. role:: hidden
+    :class: hidden-section
+
+======================================
+torch.nn.attention.flex_attention
+======================================
+
+.. currentmodule:: torch.nn.attention.flex_attention
+.. py:module:: torch.nn.attention.flex_attention
+.. autofunction:: flex_attention
+
+BlockMask Utilities
+-------------------
+
+.. autofunction:: create_block_mask
+.. autofunction:: create_mask
+
+BlockMask
+---------
+
+.. autoclass:: BlockMask
+    :members:
+    :undoc-members:

docs/source/nn.attention.rst

Lines changed: 2 additions & 0 deletions
@@ -20,9 +20,11 @@ Submodules
 .. autosummary::
     :nosignatures:
 
+    flex_attention
     bias
 
 .. toctree::
     :hidden:
 
+    nn.attention.flex_attention
     nn.attention.bias

test/inductor/test_flex_attention.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 from torch._inductor import metrics
 from torch._inductor.test_case import TestCase as InductorTestCase
 from torch._inductor.utils import run_and_get_code
-from torch.nn.attention._flex_attention import (
+from torch.nn.attention.flex_attention import (
     _causal,
     _compose,
     _create_empty_block_mask,

test/inductor/test_flex_decoding.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop
 from torch._inductor.test_case import TestCase as InductorTestCase
 from torch._inductor.utils import run_and_get_code
-from torch.nn.attention._flex_attention import (
+from torch.nn.attention.flex_attention import (
     _causal,
     _compose,
     _create_empty_block_mask,

torch/_higher_order_ops/flex_attention.py

Lines changed: 2 additions & 2 deletions
@@ -136,7 +136,7 @@ def _math_attention_inner(
     n = torch.arange(0, scores.size(3), device=scores.device)
 
     captured_buffers_in_dim = (None,) * len(score_mod_other_buffers)
-    from torch.nn.attention._flex_attention import _vmap_for_bhqkv
+    from torch.nn.attention.flex_attention import _vmap_for_bhqkv
 
     # first input is score
     score_mod = _vmap_for_bhqkv(score_mod, prefix=(0,), suffix=captured_buffers_in_dim)
@@ -690,7 +690,7 @@ def sdpa_dense_backward(
     # Gradient of the inline score_mod function, with respect to the scores
     captured_buffers_in_dim = (None,) * len(score_mod_other_buffers)
     out_dims = [0, None, None, None, None] + [None] * len(score_mod_other_buffers)
-    from torch.nn.attention._flex_attention import _vmap_for_bhqkv
+    from torch.nn.attention.flex_attention import _vmap_for_bhqkv
 
     # inputs are [score, b, h, q_idx, kv_idx, gradOut, ...]
     # score and gradOut are "fully" batched

torch/nn/attention/_flex_attention.py renamed to torch/nn/attention/flex_attention.py

Lines changed: 68 additions & 48 deletions
@@ -42,23 +42,35 @@ def inner(score, b, h, m, n):
 _mask_fn_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]
 
 
-class ModificationType(Enum):
+class _ModificationType(Enum):
+    """Enum for the type of modification function.
+    - SCORE_MOD: score_mod function which accepts a score as the first argument
+    - MASK_FN: mask function which does not accept a score and is only used for generating
+    block mask
+    """
+
     SCORE_MOD = 1
     MASK_FN = 2
 
 
 @torch._dynamo.assume_constant_result
-def get_mod_type(fn) -> ModificationType:
+def _get_mod_type(fn: Callable) -> _ModificationType:
+    """Get the type of modification function.
+    This function inspects the number of positional arguments of the function to determine
+    the type of modification function. If the function has 5 positional arguments, it is
+    considered as a score_mod function. If the function has 4 positional arguments, it is
+    considered as a mask function.
+    """
     num_positional_args = sum(
         1
         for param in inspect.signature(fn).parameters.values()
         if param.default == inspect.Parameter.empty
     )
     assert num_positional_args == 5 or num_positional_args == 4
     if num_positional_args == 5:
-        return ModificationType.SCORE_MOD
+        return _ModificationType.SCORE_MOD
     elif num_positional_args == 4:
-        return ModificationType.MASK_FN
+        return _ModificationType.MASK_FN
     else:
         raise AssertionError
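The helper above dispatches purely on arity: five non-defaulted parameters means a score_mod, four means a mask function. A standalone sketch of that rule, using the same inspect-based counting (the names here are stand-ins for illustration, not the module's own):

```python
import inspect
from enum import Enum
from typing import Callable


class ModType(Enum):  # stand-in for the _ModificationType enum in the diff
    SCORE_MOD = 1
    MASK_FN = 2


def classify(fn: Callable) -> ModType:
    # Count parameters without defaults, exactly as the diffed helper does.
    num_positional_args = sum(
        1
        for param in inspect.signature(fn).parameters.values()
        if param.default == inspect.Parameter.empty
    )
    if num_positional_args == 5:
        return ModType.SCORE_MOD
    if num_positional_args == 4:
        return ModType.MASK_FN
    raise AssertionError("expected a 4- or 5-argument callable")


def rel_bias(score, b, h, q_idx, kv_idx):   # 5 args -> score_mod
    return score + (q_idx - kv_idx)


def causal_mask(b, h, q_idx, kv_idx):       # 4 args -> mask function
    return q_idx >= kv_idx


assert classify(rel_bias) is ModType.SCORE_MOD
assert classify(causal_mask) is ModType.MASK_FN
```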

@@ -114,51 +126,59 @@ class BlockMask:
     BlockMask is our format for representing a block-sparse attention mask.
     It is somewhat of a cross in-between BCSR and a non-sparse format.
 
-    ## Basics
+    Basics
+    ------
     A block-sparse mask means that instead of representing the sparsity of
-    individual elements in the mask, we only consider a block sparse if an
-    entire KV_BLOCK_SIZE x Q_BLOCK_SIZE is sparse. This aligns well with
-    hardware, which generally expects to perform contiguous loads and
-    computation.
+    individual elements in the mask, a KV_BLOCK_SIZE x Q_BLOCK_SIZE block is
+    considered sparse only if every element within that block is sparse.
+    This aligns well with hardware, which generally expects to perform
+    contiguous loads and computation.
 
     This format is primarily optimized for 1. simplicity, and 2. kernel
     efficiency. Notably, it is *not* optimized for size, as we believe the mask
     is sufficiently small that its size is not a concern.
 
     The essentials of our format are:
-    num_blocks_in_row: Tensor[ROWS] # Describes the number of blocks present in
-    each row.
-    col_indices: Tensor[ROWS, MAX_BLOCKS_IN_COL] # col_indices[i] is the
-    position of the blocks in index i. The values of this row after
-    col_indices[i][num_blocks_in_row[i]] are undefined.
-
-    For example, to reconstruct the original tensor from this format.
-    ```
-    dense_mask = torch.zeros(ROWS, COLS)
-    for row in range(ROWS):
-        for block_idx in range(num_blocks_in_row[row]):
-            dense_mask[row, col_indices[row, block_idx]] = 1
-    ```
+
+    - num_blocks_in_row: Tensor[ROWS]
+        Describes the number of blocks present in each row.
+
+    - col_indices: Tensor[ROWS, MAX_BLOCKS_IN_COL]
+        `col_indices[i]` is the sequence of block positions for row i. The values of
+        this row after `col_indices[i][num_blocks_in_row[i]]` are undefined.
+
+    For example, to reconstruct the original tensor from this format:
+
+    .. code-block:: python
+
+        dense_mask = torch.zeros(ROWS, COLS)
+        for row in range(ROWS):
+            for block_idx in range(num_blocks_in_row[row]):
+                dense_mask[row, col_indices[row, block_idx]] = 1
 
     Notably, this format makes it easier to implement a reduction along the
     *rows* of the mask.
 
-    ## Details
-    The basics of our format require only kv_num_blocks and kv_indices. But, we have up to 8 tensors on this object. This represents 4 pairs:
-
-    (kv_num_blocks, kv_indices): This is used for the forwards pass of
-    attention, as we reduce along the KV dimension.
-    (q_num_blocks, q_indices): This is required for the backwards pass, as
-    computing dKV requires iterating along the mask along the Q dimension.
-    [OPTIONAL](full_kv_num_blocks, full_kv_indices): This is optional, and is
-    purely an optimization. As it turns out, applying masking to every block is
-    quite expensive! If we specifically know which blocks are "full" and don't
-    require masking at all, then we can skip applying mask_mod to these blocks.
-    This requires the user to split out a separate mask_mod from the score_mod.
-    For causal masks, this is about a 15% speedup.
-    [OPTIONAL](full_q_num_blocks, full_q_indices): Same as above, but for the
-    backwards.
+    Details
+    -------
+    The basics of our format require only kv_num_blocks and kv_indices. But, we
+    have up to 8 tensors on this object. This represents 4 pairs:
+
+    1. (kv_num_blocks, kv_indices): Used for the forwards pass of attention, as
+       we reduce along the KV dimension.
+
+    2. (q_num_blocks, q_indices): Required for the backwards pass, as computing
+       dKV requires iterating along the mask along the Q dimension.
+
+    3. [OPTIONAL] (full_kv_num_blocks, full_kv_indices): This is optional and
+       purely an optimization. As it turns out, applying masking to every block
+       is quite expensive! If we specifically know which blocks are "full" and
+       don't require masking at all, then we can skip applying mask_mod to these
+       blocks. This requires the user to split out a separate mask_mod from the
+       score_mod. For causal masks, this is about a 15% speedup.
 
+    4. [OPTIONAL] (full_q_num_blocks, full_q_indices): Same as above, but for
+       the backwards pass.
     """
     kv_num_blocks: Tensor
     kv_indices: Tensor
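To make the (num_blocks_in_row, col_indices) layout concrete, here is a small self-contained sketch that fabricates the two tensors for a toy causal pattern at block granularity and runs the docstring's reconstruction loop; the specific block counts are made up for illustration:

```python
import torch

ROWS, COLS = 4, 4  # number of query blocks and key/value blocks

# Toy causal pattern at block granularity: row i attends to blocks 0..i.
num_blocks_in_row = torch.tensor([1, 2, 3, 4])
col_indices = torch.tensor([
    [0, 0, 0, 0],   # entries past num_blocks_in_row[i] are undefined padding
    [0, 1, 0, 0],
    [0, 1, 2, 0],
    [0, 1, 2, 3],
])

# Reconstruction loop from the BlockMask docstring.
dense_mask = torch.zeros(ROWS, COLS)
for row in range(ROWS):
    for block_idx in range(num_blocks_in_row[row]):
        dense_mask[row, col_indices[row, block_idx]] = 1

print(dense_mask)
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])
```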
@@ -184,7 +204,7 @@ def __init__(
         full_q_indices: Optional[Tensor],
         KV_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
         Q_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
-        mask_fn=None,
+        mask_fn: Optional[_mask_fn_signature] = None,
     ):
         if kv_indices.dim() < 2:
             raise RuntimeError("BlockMask must have at least 2 dimensions")
@@ -469,7 +489,7 @@ def create_mask(
     r"""This function creates a mask tensor from a mod_fn function.
 
     Args:
-        mod_fn (Callable): Function to modify attention scores.
+        mod_fn (Union[_score_mod_signature, _mask_fn_signature]): Function to modify attention scores.
         B (int): Batch size.
        H (int): Number of heads.
        M (int): Sequence length of query.
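A usage sketch for create_mask based on the Args listed above and the function body shown in the next hunk; the trailing N and device parameters are assumptions about the full signature:

```python
import torch
from torch.nn.attention.flex_attention import create_mask


def causal_mask(b, h, q_idx, kv_idx):
    # Four positional arguments, so create_mask treats this as a mask function.
    return q_idx >= kv_idx


# B=1 batch, H=1 head, M=N=8 query/key positions. Pass the device your tensors
# live on; "cpu" here is an illustrative choice.
mask = create_mask(causal_mask, 1, 1, 8, 8, device="cpu")
print(mask.shape)  # expected: torch.Size([1, 1, 8, 8]) boolean mask
```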
@@ -491,16 +511,16 @@ def create_mask(
         ctx = nullcontext()
     else:
         ctx = TransformGetItemToIndex()  # type: ignore[assignment]
-    mod_type = get_mod_type(mod_fn)
+    mod_type = _get_mod_type(mod_fn)
 
     with ctx:
-        if mod_type == ModificationType.SCORE_MOD:
+        if mod_type == _ModificationType.SCORE_MOD:
             score_mod = mod_fn
             score_mod = _vmap_for_bhqkv(score_mod, prefix=(0,))  # first input is score
             out = score_mod(torch.zeros(B, H, M, N, device=device), b, h, m, n)
             mask = torch.where(torch.isneginf(out), False, True)
             return mask
-        elif mod_type == ModificationType.MASK_FN:
+        elif mod_type == _ModificationType.MASK_FN:
             mask_fn = mod_fn
             mask_fn = _vmap_for_bhqkv(mask_fn, prefix=())
             mask = mask_fn(b, h, m, n)
@@ -515,8 +535,8 @@ def _create_block_mask_inner(
     mod_fn, B, H, M, N, device, KV_BLOCK_SIZE, Q_BLOCK_SIZE, mod_type
 ):
     mask_tensor = create_mask(mod_fn, B, H, M, N, device, _compile=True)
-    mod_type = get_mod_type(mod_fn)
-    if mod_type == ModificationType.MASK_FN:
+    mod_type = _get_mod_type(mod_fn)
+    if mod_type == _ModificationType.MASK_FN:
         mask_fn = mod_fn
     else:
         mask_fn = None
@@ -558,7 +578,7 @@ def create_block_mask(
         block_mask (tuple): A tuple of (kv_num_blocks, kv_indices, q_num_blocks, q_indices,
             KV_BLOCK_SIZE, Q_BLOCK_SIZE) which represents the block mask.
     """
-    mod_type = get_mod_type(fn)
+    mod_type = _get_mod_type(fn)
     inner_func = _create_block_mask_inner
     # This is kind of a temporary hack to workaround some issues
     if _compile:
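Putting create_block_mask together with flex_attention, a hedged sketch: the argument order (fn, B, H, M, N, device) and the block_mask keyword mirror the prototype usage but are assumptions here, not verbatim from this diff. Splitting out a 4-argument mask_mod is what lets the kernel skip masking on "full" blocks, the roughly 15% causal speedup the BlockMask docstring mentions:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


B, H, S, D = 2, 4, 512, 64

# Build the block-sparse mask once; it can be reused across calls with the same shape.
# "cpu" is illustrative; pass the device your tensors live on.
block_mask = create_block_mask(causal_mask, B, H, S, S, device="cpu")

q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)

# block_mask is assumed to be accepted as a keyword argument, as in the prototype API;
# for the fused kernel this is typically run under torch.compile on a GPU.
out = flex_attention(q, k, v, block_mask=block_mask)
print(out.shape)  # torch.Size([2, 4, 512, 64])
```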
@@ -618,14 +638,14 @@ def score_mod(
             score: Tensor,
             batch: Tensor,
             head: Tensor,
-            token_q: Tensor,
-            token_kv: Tensor
+            q_idx: Tensor,
+            kv_idx: Tensor
         ) -> Tensor:
 
     Where:
         - ``score``: A scalar tensor representing the attention score,
           with the same data type and device as the query, key, and value tensors.
-        - ``b``, ``h``, ``q_idx``, ``kv_idx``: Scalar tensors indicating
+        - ``batch``, ``head``, ``q_idx``, ``kv_idx``: Scalar tensors indicating
           the batch index, head index, query index, and key/value index, respectively.
           These should have the ``torch.int`` data type and be located on the same device as the score tensor.
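As a concrete instance of the renamed signature documented above, a relative-position-bias score_mod; the bias form itself is an illustrative choice, not something this PR specifies:

```python
import torch
from torch import Tensor


def relative_position_bias(
    score: Tensor,
    batch: Tensor,
    head: Tensor,
    q_idx: Tensor,
    kv_idx: Tensor,
) -> Tensor:
    # All index arguments arrive as scalar int tensors; the return value must be a
    # scalar score tensor, matching the documented contract.
    return score + (q_idx - kv_idx)


# The function is vmapped over (batch, head, q_idx, kv_idx) internally, so it only
# ever sees scalars; a quick scalar sanity check:
s = relative_position_bias(
    torch.tensor(0.5),
    torch.tensor(0), torch.tensor(0), torch.tensor(3), torch.tensor(1),
)
print(s)  # tensor(2.5000)
```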

torch/testing/_internal/hop_db.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 )
 from torch.testing._internal.common_dtype import all_types_and, custom_types
 from torch.testing._internal.opinfo.core import DecorateInfo
-from torch.nn.attention._flex_attention import flex_attention, _create_empty_block_mask
+from torch.nn.attention.flex_attention import flex_attention, _create_empty_block_mask
 
 def sample_inputs_map(opinfo, device, dtype, requires_grad, **kwargs):
     make_arg = functools.partial(
