@@ -42,23 +42,35 @@ def inner(score, b, h, m, n):
_mask_fn_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]


-class ModificationType(Enum):
+class _ModificationType(Enum):
+    """Enum for the type of modification function.
+    - SCORE_MOD: a score_mod function, which accepts a score as its first argument
+    - MASK_FN: a mask function, which does not accept a score and is only used for
+      generating a block mask
+    """
+
    SCORE_MOD = 1
    MASK_FN = 2


@torch._dynamo.assume_constant_result
-def get_mod_type(fn) -> ModificationType:
+def _get_mod_type(fn: Callable) -> _ModificationType:
+    """Get the type of modification function.
+    This function inspects the number of positional arguments to determine the type
+    of modification function: a function with 5 positional arguments is considered
+    a score_mod function, and one with 4 positional arguments is considered a mask
+    function.
+    """
    num_positional_args = sum(
        1
        for param in inspect.signature(fn).parameters.values()
        if param.default == inspect.Parameter.empty
    )
    assert num_positional_args == 5 or num_positional_args == 4
    if num_positional_args == 5:
-        return ModificationType.SCORE_MOD
+        return _ModificationType.SCORE_MOD
    elif num_positional_args == 4:
-        return ModificationType.MASK_FN
+        return _ModificationType.MASK_FN
    else:
        raise AssertionError
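For readers of this hunk, a brief sketch of how the new helper behaves. This is illustrative only: it assumes `_get_mod_type` and `_ModificationType` are in scope (they are private helpers of this module), and the two example functions are hypothetical.

```python
# Sketch (not part of the patch): _get_mod_type classifies callables purely by arity.
def causal_mask(b, h, q_idx, kv_idx):
    # 4 positional arguments -> classified as a mask function
    return q_idx >= kv_idx

def downweight_later_heads(score, b, h, q_idx, kv_idx):
    # 5 positional arguments -> classified as a score_mod function
    return score / (h + 1)

assert _get_mod_type(causal_mask) == _ModificationType.MASK_FN
assert _get_mod_type(downweight_later_heads) == _ModificationType.SCORE_MOD
```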
@@ -114,51 +126,59 @@ class BlockMask:
    BlockMask is our format for representing a block-sparse attention mask.
    It is somewhat of a cross in-between BCSR and a non-sparse format.

-    ## Basics
+    Basics
+    ------
    A block-sparse mask means that instead of representing the sparsity of
-    individual elements in the mask, we only consider a block sparse if an
-    entire KV_BLOCK_SIZE x Q_BLOCK_SIZE is sparse. This aligns well with
-    hardware, which generally expects to perform contiguous loads and
-    computation.
+    individual elements in the mask, a KV_BLOCK_SIZE x Q_BLOCK_SIZE block is
+    considered sparse only if every element within that block is sparse.
+    This aligns well with hardware, which generally expects to perform
+    contiguous loads and computation.

    This format is primarily optimized for 1. simplicity, and 2. kernel
    efficiency. Notably, it is *not* optimized for size, as we believe the mask
    is sufficiently small that its size is not a concern.

    The essentials of our format are:
-    - num_blocks_in_row: Tensor[ROWS] # Describes the number of blocks present in
-      each row.
-    - col_indices: Tensor[ROWS, MAX_BLOCKS_IN_COL] # col_indices[i] is the
-      position of the blocks in index i. The values of this row after
-      col_indices[i][num_blocks_in_row[i]] are undefined.
-
-    For example, to reconstruct the original tensor from this format.
-    ```
-    dense_mask = torch.zeros(ROWS, COLS)
-    for row in range(ROWS):
-        for block_idx in range(num_blocks_in_row[row]):
-            dense_mask[row, col_indices[row, block_idx]] = 1
-    ```
+
+    - num_blocks_in_row: Tensor[ROWS]
+        Describes the number of blocks present in each row.
+
+    - col_indices: Tensor[ROWS, MAX_BLOCKS_IN_COL]
+        `col_indices[i]` is the sequence of block positions for row i. The values of
+        this row after `col_indices[i][num_blocks_in_row[i]]` are undefined.
+
+    For example, to reconstruct the original tensor from this format:
+
+    .. code-block:: python
+
+        dense_mask = torch.zeros(ROWS, COLS)
+        for row in range(ROWS):
+            for block_idx in range(num_blocks_in_row[row]):
+                dense_mask[row, col_indices[row, block_idx]] = 1

    Notably, this format makes it easier to implement a reduction along the
    *rows* of the mask.

-    ## Details
-    The basics of our format require only kv_num_blocks and kv_indices. But, we have up to 8 tensors on this object. This represents 4 pairs:
-
-    (kv_num_blocks, kv_indices): This is used for the forwards pass of
-    attention, as we reduce along the KV dimension.
-    (q_num_blocks, q_indices): This is required for the backwards pass, as
-    computing dKV requires iterating along the mask along the Q dimension.
-    [OPTIONAL](full_kv_num_blocks, full_kv_indices): This is optional, and is
-    purely an optimization. As it turns out, applying masking to every block is
-    quite expensive! If we specifically know which blocks are "full" and don't
-    require masking at all, then we can skip applying mask_mod to these blocks.
-    This requires the user to split out a separate mask_mod from the score_mod.
-    For causal masks, this is about a 15% speedup.
-    [OPTIONAL](full_q_num_blocks, full_q_indices): Same as above, but for the
-    backwards.
+    Details
+    -------
+    The basics of our format require only kv_num_blocks and kv_indices. But, we
+    have up to 8 tensors on this object. This represents 4 pairs:
+
+    1. (kv_num_blocks, kv_indices): Used for the forwards pass of attention, as
+       we reduce along the KV dimension.
+
+    2. (q_num_blocks, q_indices): Required for the backwards pass, as computing
+       dKV requires iterating along the mask along the Q dimension.
+
+    3. [OPTIONAL] (full_kv_num_blocks, full_kv_indices): This is optional and
+       purely an optimization. As it turns out, applying masking to every block
+       is quite expensive! If we specifically know which blocks are "full" and
+       don't require masking at all, then we can skip applying mask_mod to these
+       blocks. This requires the user to split out a separate mask_mod from the
+       score_mod. For causal masks, this is about a 15% speedup.

+    4. [OPTIONAL] (full_q_num_blocks, full_q_indices): Same as above, but for
+       the backwards pass.
    """
    kv_num_blocks: Tensor
    kv_indices: Tensor
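Item 3 of the Details list above is the motivation for the optional full-block bookkeeping. The following standalone sketch is not part of the patch and uses only plain PyTorch; it shows which blocks of a causal mask are "full", "partial", or "empty", and only the partial ones would need mask_mod applied inside the kernel.

```python
import torch

# Classify 128x128 blocks of a 512x512 causal mask.
Q_LEN, KV_LEN, BLOCK = 512, 512, 128
q_idx = torch.arange(Q_LEN)[:, None]
kv_idx = torch.arange(KV_LEN)[None, :]
dense = q_idx >= kv_idx  # dense causal mask, shape (Q_LEN, KV_LEN)

# Reshape into (num_q_blocks, num_kv_blocks, BLOCK, BLOCK) tiles.
blocks = dense.view(Q_LEN // BLOCK, BLOCK, KV_LEN // BLOCK, BLOCK).permute(0, 2, 1, 3)
full = blocks.flatten(-2).all(dim=-1)     # every element unmasked: mask_mod can be skipped
empty = ~blocks.flatten(-2).any(dim=-1)   # every element masked: block skipped entirely
partial = ~full & ~empty                  # only these blocks pay the cost of mask_mod
```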
@@ -184,7 +204,7 @@ def __init__(
        full_q_indices: Optional[Tensor],
        KV_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
        Q_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
-        mask_fn=None,
+        mask_fn: Optional[_mask_fn_signature] = None,
    ):
        if kv_indices.dim() < 2:
            raise RuntimeError("BlockMask must have at least 2 dimensions")
@@ -469,7 +489,7 @@ def create_mask(
    r"""This function creates a mask tensor from a mod_fn function.

    Args:
-        mod_fn (Callable): Function to modify attention scores.
+        mod_fn (Union[_score_mod_signature, _mask_fn_signature]): Function to modify attention scores.
        B (int): Batch size.
        H (int): Number of heads.
        M (int): Sequence length of query.
@@ -491,16 +511,16 @@ def create_mask(
        ctx = nullcontext()
    else:
        ctx = TransformGetItemToIndex()  # type: ignore[assignment]
-    mod_type = get_mod_type(mod_fn)
+    mod_type = _get_mod_type(mod_fn)

    with ctx:
-        if mod_type == ModificationType.SCORE_MOD:
+        if mod_type == _ModificationType.SCORE_MOD:
            score_mod = mod_fn
            score_mod = _vmap_for_bhqkv(score_mod, prefix=(0,))  # first input is score
            out = score_mod(torch.zeros(B, H, M, N, device=device), b, h, m, n)
            mask = torch.where(torch.isneginf(out), False, True)
            return mask
-        elif mod_type == ModificationType.MASK_FN:
+        elif mod_type == _ModificationType.MASK_FN:
            mask_fn = mod_fn
            mask_fn = _vmap_for_bhqkv(mask_fn, prefix=())
            mask = mask_fn(b, h, m, n)
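As a rough usage sketch of the dispatch above: a 4-argument callable takes the MASK_FN branch and yields a dense boolean mask. This assumes `create_mask` is accessible from this module and accepts the documented B/H/M/N/device parameters by keyword; the shapes and device are illustrative.

```python
def causal(b, h, q_idx, kv_idx):
    # 4 positional arguments -> handled by the MASK_FN branch above
    return q_idx >= kv_idx

# Expected to return a dense boolean tensor of shape (B, H, M, N).
mask = create_mask(causal, B=1, H=1, M=256, N=256, device="cpu")
```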
@@ -515,8 +535,8 @@ def _create_block_mask_inner(
    mod_fn, B, H, M, N, device, KV_BLOCK_SIZE, Q_BLOCK_SIZE, mod_type
):
    mask_tensor = create_mask(mod_fn, B, H, M, N, device, _compile=True)
-    mod_type = get_mod_type(mod_fn)
-    if mod_type == ModificationType.MASK_FN:
+    mod_type = _get_mod_type(mod_fn)
+    if mod_type == _ModificationType.MASK_FN:
        mask_fn = mod_fn
    else:
        mask_fn = None
@@ -558,7 +578,7 @@ def create_block_mask(
        block_mask (tuple): A tuple of (kv_num_blocks, kv_indices, q_num_blocks, q_indices,
        KV_BLOCK_SIZE, Q_BLOCK_SIZE) which represents the block mask.
    """
-    mod_type = get_mod_type(fn)
+    mod_type = _get_mod_type(fn)
    inner_func = _create_block_mask_inner
    # This is kind of a temporary hack to workaround some issues
    if _compile:
@@ -618,14 +638,14 @@ def score_mod(
            score: Tensor,
            batch: Tensor,
            head: Tensor,
-            token_q: Tensor,
-            token_kv: Tensor
+            q_idx: Tensor,
+            kv_idx: Tensor
        ) -> Tensor:

    Where:
        - ``score``: A scalar tensor representing the attention score,
          with the same data type and device as the query, key, and value tensors.
-        - ``b``, ``h``, ``q_idx``, ``kv_idx``: Scalar tensors indicating
+        - ``batch``, ``head``, ``q_idx``, ``kv_idx``: Scalar tensors indicating
          the batch index, head index, query index, and key/value index, respectively.
          These should have the ``torch.int`` data type and be located on the same device as the score tensor.
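To make the renamed parameters concrete, here is a minimal score_mod sketch that matches the documented (score, batch, head, q_idx, kv_idx) signature; the particular bias it applies is an illustrative choice, not something prescribed by this patch.

```python
from torch import Tensor

def relative_positional_bias(
    score: Tensor, batch: Tensor, head: Tensor, q_idx: Tensor, kv_idx: Tensor
) -> Tensor:
    # Penalize attention to key/value positions far from the query position.
    return score - (q_idx - kv_idx).abs()
```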