@@ -262,6 +262,7 @@ class BlockMask:
262
262
the backwards pass. These are autogenerated from 2.
263
263
"""
264
264
265
+ seq_lengths : Tuple [int , int ]
265
266
kv_num_blocks : Tensor
266
267
kv_indices : Tensor
267
268
full_kv_num_blocks : Optional [Tensor ]
@@ -275,6 +276,7 @@ class BlockMask:
275
276
276
277
def __init__ (
277
278
self ,
279
+ seq_lengths : Tuple [int , int ],
278
280
kv_num_blocks : Tensor ,
279
281
kv_indices : Tensor ,
280
282
full_kv_num_blocks : Optional [Tensor ],
@@ -299,6 +301,7 @@ def __init__(
299
301
full_q_indices is None
300
302
), "full_q_num_blocks and full_q_indices must be both provided or omitted"
301
303
304
+ self .seq_lengths = seq_lengths
302
305
self .kv_num_blocks = kv_num_blocks
303
306
self .kv_indices = kv_indices
304
307
self .full_kv_num_blocks = full_kv_num_blocks
@@ -319,6 +322,7 @@ def from_kv_blocks(
319
322
full_kv_indices : Optional [Tensor ] = None ,
320
323
BLOCK_SIZE : Union [int , Tuple [int , int ]] = _DEFAULT_SPARSE_BLOCK_SIZE ,
321
324
mask_mod : Optional [_mask_mod_signature ] = None ,
325
+ seq_lengths : Optional [Tuple [int , int ]] = None ,
322
326
):
323
327
"""
324
328
Creates a BlockMask instance from key-value block information.
@@ -359,8 +363,13 @@ def from_kv_blocks(
359
363
BLOCK_SIZE = (BLOCK_SIZE , BLOCK_SIZE )
360
364
361
365
mask_mod = mask_mod if mask_mod is not None else noop_mask
366
+ if seq_lengths is None :
367
+ q_length = kv_indices .shape [- 2 ] * BLOCK_SIZE [0 ]
368
+ kv_length = q_indices .shape [- 2 ] * BLOCK_SIZE [1 ]
369
+ seq_lengths = (q_length , kv_length )
362
370
363
371
return cls (
372
+ seq_lengths = seq_lengths ,
364
373
kv_num_blocks = kv_num_blocks ,
365
374
kv_indices = kv_indices ,
366
375
full_kv_num_blocks = full_kv_num_blocks ,
@@ -380,11 +389,15 @@ def as_tuple(self, flatten: bool = True):
380
389
Args:
381
390
flatten (bool): If True, it will flatten the tuple of (KV_BLOCK_SIZE, Q_BLOCK_SIZE)
382
391
"""
383
- block_size = (
384
- (self .BLOCK_SIZE [0 ], self .BLOCK_SIZE [1 ]) if flatten else (self .BLOCK_SIZE ,)
385
- )
392
+ if flatten :
393
+ block_size = (self .BLOCK_SIZE [0 ], self .BLOCK_SIZE [1 ]) # type: ignore[assignment]
394
+ seq_lengths = (self .seq_lengths [0 ], self .seq_lengths [1 ]) # type: ignore[assignment]
395
+ else :
396
+ block_size = (self .BLOCK_SIZE ,) # type: ignore[assignment]
397
+ seq_lengths = (self .seq_lengths ,) # type: ignore[assignment]
386
398
387
399
return (
400
+ * seq_lengths ,
388
401
self .kv_num_blocks ,
389
402
self .kv_indices ,
390
403
self .full_kv_num_blocks ,
@@ -397,6 +410,11 @@ def as_tuple(self, flatten: bool = True):
397
410
self .mask_mod ,
398
411
)
399
412
413
+ @property
414
+ def shape (self ):
415
+ * batch_dims , _ , _ = self .kv_indices .shape
416
+ return tuple (batch_dims ) + self .seq_lengths
417
+
400
418
def __str__ (self ):
401
419
s = f"BlockMask(shape={ self .shape } , sparsity={ self .sparsity ():.2f} %, \n "
402
420
mask_str = self .to_string ().strip ()
@@ -457,6 +475,7 @@ def causal_mask(b, h, q_idx, kv_idx):
457
475
new_full_kv_indices ,
458
476
BLOCK_SIZE = self .BLOCK_SIZE ,
459
477
mask_mod = None ,
478
+ seq_lengths = self .seq_lengths ,
460
479
)
461
480
462
481
def __repr__ (self ):
@@ -509,14 +528,6 @@ def _adjust(self, new_q_len: int, new_kv_len: int):
509
528
self .mask_mod ,
510
529
)
511
530
512
- @property
513
- def shape (self ):
514
- """Returns the shape of the mask."""
515
- * batch_dims , q_length , _ = self .kv_indices .shape
516
- q_length = self .kv_indices .shape [- 2 ] * self .BLOCK_SIZE [0 ]
517
- kv_length = self .kv_indices .shape [- 1 ] * self .BLOCK_SIZE [1 ]
518
- return tuple (batch_dims + [q_length , kv_length ])
519
-
520
531
def numel (self ):
521
532
"""Returns the number of elements (not accounting for sparsity) in the mask."""
522
533
shape = self .shape
@@ -739,6 +750,7 @@ def _convert_block_mask_to_mask(
739
750
def _create_sparse_block_from_block_mask (
740
751
block_mask : Tuple [Tensor , Optional [Tensor ]],
741
752
mask_mod : Optional [Callable ],
753
+ seq_lengths : Tuple [int , int ],
742
754
Q_BLOCK_SIZE : int = _DEFAULT_SPARSE_BLOCK_SIZE ,
743
755
KV_BLOCK_SIZE : int = _DEFAULT_SPARSE_BLOCK_SIZE ,
744
756
) -> BlockMask :
@@ -757,6 +769,7 @@ def _create_sparse_block_from_block_mask(
757
769
full_bm [1 ],
758
770
BLOCK_SIZE = (Q_BLOCK_SIZE , KV_BLOCK_SIZE ),
759
771
mask_mod = mask_mod ,
772
+ seq_lengths = seq_lengths ,
760
773
)
761
774
762
775
@@ -878,7 +891,11 @@ def causal_mask(b, h, q_idx, kv_idx):
878
891
separate_full_blocks = True ,
879
892
)
880
893
block_mask = _create_sparse_block_from_block_mask (
881
- (partial_block_mask , full_block_mask ), mask_mod , Q_BLOCK_SIZE , KV_BLOCK_SIZE
894
+ (partial_block_mask , full_block_mask ),
895
+ mask_mod ,
896
+ (Q_LEN , KV_LEN ),
897
+ Q_BLOCK_SIZE ,
898
+ KV_BLOCK_SIZE ,
882
899
)
883
900
return block_mask
884
901
@@ -894,6 +911,7 @@ def _create_empty_block_mask(query: Tensor, key: Tensor) -> BlockMask:
894
911
kv_num_blocks = torch .ones ([1 , 1 , 1 ], dtype = torch .int32 , device = device ),
895
912
kv_indices = torch .zeros ([1 , 1 , 1 , 1 ], dtype = torch .int32 , device = device ),
896
913
BLOCK_SIZE = _LARGE_SPARSE_BLOCK_SIZE ,
914
+ seq_lengths = (1 , 1 ),
897
915
)
898
916
899
917
@@ -1237,29 +1255,31 @@ def score_mod(
1237
1255
1238
1256
if block_mask is None :
1239
1257
block_mask = _create_empty_block_mask (query , key )
1240
- elif (
1241
- not query .is_nested
1242
- and (query .requires_grad or key .requires_grad or value .requires_grad )
1243
- and (
1244
- query .size (- 2 )
1245
- < block_mask .kv_num_blocks .size (- 1 ) * block_mask .BLOCK_SIZE [0 ]
1246
- or key .size (- 2 ) < block_mask .kv_indices .size (- 1 ) * block_mask .BLOCK_SIZE [1 ]
1247
- )
1248
- ):
1249
- new_q_len = _round_up_to_multiple (query .size (- 2 ), block_mask .BLOCK_SIZE [0 ])
1250
- new_kv_len = _round_up_to_multiple (key .size (- 2 ), block_mask .BLOCK_SIZE [1 ])
1251
- block_mask = block_mask ._adjust (new_q_len , new_kv_len )
1252
- elif query .is_nested and (
1253
- block_mask .kv_num_blocks .size (- 1 ) * block_mask .BLOCK_SIZE [0 ]
1254
- != _round_up_to_multiple (
1255
- query ._values .size (query ._ragged_idx - 1 ), block_mask .BLOCK_SIZE [0 ] # type: ignore[attr-defined]
1256
- )
1258
+
1259
+ if (
1260
+ block_mask .BLOCK_SIZE [0 ] == _LARGE_SPARSE_BLOCK_SIZE
1261
+ and block_mask .BLOCK_SIZE [1 ] == _LARGE_SPARSE_BLOCK_SIZE
1257
1262
):
1258
- # TODO: Maybe we want to auto-adjust for this case as well?
1259
- raise RuntimeError (
1260
- f"block_mask of shape { block_mask .shape } is not compatible with nested tensor input "
1261
- f"with total sequence length of { query ._values .size (query ._ragged_idx - 1 )} " # type: ignore[attr-defined]
1262
- )
1263
+ # This corresponds to the case where we essentially have a "no-op" block mask.
1264
+ pass
1265
+ else :
1266
+ block_mask_q_len = block_mask .shape [- 2 ]
1267
+ block_mask_kv_len = block_mask .shape [- 1 ]
1268
+ if query .size (- 2 ) > block_mask_q_len or key .size (- 2 ) > block_mask_kv_len :
1269
+ raise ValueError (
1270
+ f"block_mask was created for block_mask.shape={ block_mask .shape } but got q_len={ query .size (- 2 )} and kv_len={ key .size (- 2 )} . "
1271
+ "As the block mask was created for a smaller length than you're using it for, you likely need to create a new block mask."
1272
+ )
1273
+ elif (
1274
+ query .size (- 2 ) < block_mask_q_len and key .size (- 2 ) <= block_mask_kv_len
1275
+ ) or (query .size (- 2 ) <= block_mask_q_len and key .size (- 2 ) < block_mask_kv_len ):
1276
+ raise ValueError (
1277
+ f"block_mask was created for block_mask.shape={ block_mask .shape } but got q_len={ query .size (- 2 )} and kv_len={ key .size (- 2 )} . "
1278
+ "As the block mask was created for a larger length than you're using it for, you can either 1. create a new block mask with the correct length, or 2. 'adjust' the existing block mask to the correct length by calling block_mask._adjust(q_len, kv_len). This essentially 'crops' the block mask to the upper left corner, which does not work for all mask_mods!"
1279
+ )
1280
+ assert query .size (- 2 ) == block_mask_q_len
1281
+ assert key .size (- 2 ) == block_mask_kv_len
1282
+
1263
1283
if scale is None :
1264
1284
scale = 1.0 / math .sqrt (query .size (- 1 ))
1265
1285
0 commit comments