@@ -365,3 +365,135 @@ def test_sdpa_with_cache_mqa_3(self):
             q, k, v, self.k_cache, self.v_cache, 1, 1, None, 0, False
         )
         self.assertTrue(torch.allclose(ref_output, op_output))
+
+
+class SDPATestForLargeSeqLength(unittest.TestCase):
+
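+    # Every test below follows the same two-phase pattern: a multi-token
+    # prefill starting from a zeroed cache, then a single-token decode step,
+    # each compared against the eager reference implementation.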
+    def setup_caches(self):
+        self.k_cache = torch.zeros(
+            (1, self.max_seq_len, self.n_heads_kv, self.head_dim)
+        )
+        self.v_cache = torch.zeros(
+            (1, self.max_seq_len, self.n_heads_kv, self.head_dim)
+        )
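+        # Causal mask: -inf strictly above the diagonal, 0 on and below it.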
+        self.mask = torch.full(
+            (self.max_seq_len, self.max_seq_len),
+            float("-inf"),
+        )
+        self.mask = torch.triu(self.mask, diagonal=1)
+
+    def setUp(self):
+        torch.manual_seed(42)
+        self.n_heads_kv = 32
+        self.n_heads_q = 32
+        self.head_dim = 128
+        self.max_seq_len = 2048
+        self.setup_caches()
+
+    def test_sdpa_with_cache_seq_len_130(self):
+        self.n_heads_kv = 32
+        self.n_heads_q = 32
+        self.head_dim = 128
+        self.max_seq_len = 2048
+        self.setup_caches()
+        seq_len = 130
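+        # Prefill: 130 tokens in one call, starting from an empty cache.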
+        q = torch.rand((1, seq_len, self.n_heads_q, self.head_dim))
+        k = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
+        v = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
+        start_pos = 0
+        # Mask rows select the query positions, columns the visible cache span.
+        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
+        attn_mask = attn_mask[:, : start_pos + seq_len]
+        ref_output = _sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
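+        # Decode: a single token appended at position 130.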
+        q = torch.rand((1, 1, self.n_heads_q, self.head_dim))
+        k = torch.rand((1, 1, self.n_heads_kv, self.head_dim))
+        v = torch.rand((1, 1, self.n_heads_kv, self.head_dim))
+        start_pos = seq_len
+        seq_len = q.size(1)
+        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
+        attn_mask = attn_mask[:, : start_pos + seq_len]
+        ref_output = _sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
+    def test_sdpa_with_cache_seq_len_small(self):
+        self.n_heads_kv = 4
+        self.n_heads_q = 4
+        self.head_dim = 4
+        self.max_seq_len = 8
+        self.setup_caches()
+        q = torch.rand((1, 4, self.n_heads_q, self.head_dim))
+        k = torch.rand((1, 4, self.n_heads_q, self.head_dim))
+        v = torch.rand((1, 4, self.n_heads_q, self.head_dim))
+        start_pos = 0
+        seq_len = q.size(1)
+        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
+        attn_mask = attn_mask[:, : start_pos + seq_len]
+        ref_output = _sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
+        q = torch.rand((1, 1, self.n_heads_q, self.head_dim))
+        k = torch.rand((1, 1, self.n_heads_q, self.head_dim))
+        v = torch.rand((1, 1, self.n_heads_q, self.head_dim))
+        start_pos = 4
+        seq_len = q.size(1)
+        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
+        attn_mask = attn_mask[:, : start_pos + seq_len]
+        ref_output = _sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
+    def test_sdpa_with_cache_seq_len_llava_example(self):
+        self.n_heads_kv = 32
+        self.n_heads_q = 32
+        self.head_dim = 128
+        self.max_seq_len = 2048
+        self.setup_caches()
+        seq_len = 634
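+        # 634 is presumably the prompt-prefill length from the LLaVA example
+        # this test is named after.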
+        q = torch.rand((1, seq_len, self.n_heads_q, self.head_dim))
+        k = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
+        v = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
+        start_pos = 0
+        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
+        attn_mask = attn_mask[:, : start_pos + seq_len]
+        ref_output = _sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
+        q = torch.rand((1, 1, self.n_heads_q, self.head_dim))
+        k = torch.rand((1, 1, self.n_heads_kv, self.head_dim))
+        v = torch.rand((1, 1, self.n_heads_kv, self.head_dim))
+        start_pos = seq_len
+        seq_len = q.size(1)
+        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
+        attn_mask = attn_mask[:, : start_pos + seq_len]
+        ref_output = _sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
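
For context on what these tests check: `_sdpa_with_kv_cache_ref` is not shown in this hunk, but its call sites pin down the contract. Below is a minimal sketch of what such a reference plausibly does, inferred only from the shapes and arguments used above; it is an illustration under those assumptions, not the actual helper from this file.

```python
import torch
import torch.nn.functional as F


def _sdpa_with_kv_cache_ref(q, k, v, k_cache, v_cache, attn_mask, start_pos, seq_len):
    # Sketch only: the real helper in the test file may differ in detail.
    # Write the new keys/values into the caches at [start_pos, start_pos + seq_len).
    k_cache[:, start_pos : start_pos + seq_len] = k
    v_cache[:, start_pos : start_pos + seq_len] = v
    # Attend over every cached position so far, in (batch, heads, seq, dim) layout.
    attn_q = q.transpose(1, 2)
    attn_k = k_cache[:, : start_pos + seq_len].transpose(1, 2)
    attn_v = v_cache[:, : start_pos + seq_len].transpose(1, 2)
    out = F.scaled_dot_product_attention(attn_q, attn_k, attn_v, attn_mask=attn_mask)
    # Return to the (batch, seq, heads, dim) layout the tests compare against.
    return out.transpose(1, 2)
```

Note how the two-step mask slicing in the tests produces exactly the block this sketch needs: rows `start_pos : start_pos + seq_len` select the current query positions, and columns `: start_pos + seq_len` restrict attention to positions already written to the cache. For the decode steps (`seq_len == 1`), that slice is a single all-zero row, since a new token may attend to every position before it.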