@@ -392,17 +392,50 @@ def setUp(self):
392
392
self .max_seq_len = 2048
393
393
self .setup_caches ()
394
394
395
+ def _scale_tensor (self , tensor , min_value , max_value , scale = True ):
396
+ normalized_tensor = (tensor - tensor .min ()) / (tensor .max () - tensor .min ())
397
+
398
+ scaled_tensor = normalized_tensor * (max_value - min_value ) + min_value
399
+
400
+ return scaled_tensor if scale else tensor
401
+
395
402
def _test_sdpa_common (
396
- self , n_heads_kv , n_heads_q , head_dim , max_seq_len , seq_len , next_iter_seq_len = 1
403
+ self ,
404
+ n_heads_kv ,
405
+ n_heads_q ,
406
+ head_dim ,
407
+ max_seq_len ,
408
+ seq_len ,
409
+ next_iter_seq_len = 1 ,
410
+ scale_tensors = False ,
397
411
):
412
+ # Range arbitrarily chosen to reproduce a numerical error on x86 in some of the long context tests
413
+ tensor_scale_max = 20
414
+ tensor_scale_min = - 20
398
415
self .n_heads_kv = n_heads_kv
399
416
self .n_heads_q = n_heads_q
400
417
self .head_dim = head_dim
401
418
self .max_seq_len = max_seq_len
402
419
self .setup_caches ()
403
- q = torch .rand ((1 , seq_len , self .n_heads_kv , self .head_dim ))
404
- k = torch .rand ((1 , seq_len , self .n_heads_kv , self .head_dim ))
405
- v = torch .rand ((1 , seq_len , self .n_heads_kv , self .head_dim ))
420
+ q = self ._scale_tensor (
421
+ torch .rand ((1 , seq_len , self .n_heads_kv , self .head_dim )),
422
+ tensor_scale_max ,
423
+ tensor_scale_min ,
424
+ scale_tensors ,
425
+ )
426
+ k = self ._scale_tensor (
427
+ torch .rand ((1 , seq_len , self .n_heads_kv , self .head_dim )),
428
+ tensor_scale_max ,
429
+ tensor_scale_min ,
430
+ scale_tensors ,
431
+ )
432
+ v = self ._scale_tensor (
433
+ torch .rand ((1 , seq_len , self .n_heads_kv , self .head_dim )),
434
+ tensor_scale_max ,
435
+ tensor_scale_min ,
436
+ scale_tensors ,
437
+ )
438
+
406
439
start_pos = 0
407
440
attn_mask = self .mask [start_pos : start_pos + seq_len , :]
408
441
attn_mask = attn_mask [:, : start_pos + seq_len ]
@@ -412,11 +445,27 @@ def _test_sdpa_common(
412
445
op_output = torch .ops .llama .sdpa_with_kv_cache (
413
446
q , k , v , self .k_cache , self .v_cache , start_pos , seq_len , None , 0 , True
414
447
)
415
- self .assertTrue (torch .allclose (ref_output , op_output ))
448
+ self .assertTrue (torch .allclose (ref_output , op_output , atol = 1e-6 ))
449
+
450
+ q = self ._scale_tensor (
451
+ torch .rand ((1 , next_iter_seq_len , self .n_heads_kv , self .head_dim )),
452
+ tensor_scale_max ,
453
+ tensor_scale_min ,
454
+ scale_tensors ,
455
+ )
456
+ k = self ._scale_tensor (
457
+ torch .rand ((1 , next_iter_seq_len , self .n_heads_kv , self .head_dim )),
458
+ tensor_scale_max ,
459
+ tensor_scale_min ,
460
+ scale_tensors ,
461
+ )
462
+ v = self ._scale_tensor (
463
+ torch .rand ((1 , next_iter_seq_len , self .n_heads_kv , self .head_dim )),
464
+ tensor_scale_max ,
465
+ tensor_scale_min ,
466
+ scale_tensors ,
467
+ )
416
468
417
- q = torch .rand ((1 , next_iter_seq_len , self .n_heads_kv , self .head_dim ))
418
- k = torch .rand ((1 , next_iter_seq_len , self .n_heads_kv , self .head_dim ))
419
- v = torch .rand ((1 , next_iter_seq_len , self .n_heads_kv , self .head_dim ))
420
469
start_pos = seq_len
421
470
seq_len = q .size (1 )
422
471
attn_mask = self .mask [start_pos : start_pos + seq_len , :]
@@ -427,7 +476,7 @@ def _test_sdpa_common(
427
476
op_output = torch .ops .llama .sdpa_with_kv_cache (
428
477
q , k , v , self .k_cache , self .v_cache , start_pos , seq_len , None , 0 , True
429
478
)
430
- self .assertTrue (torch .allclose (ref_output , op_output ))
479
+ self .assertTrue (torch .allclose (ref_output , op_output , atol = 1e-6 ))
431
480
432
481
433
482
class SDPATestForLargeSeqLength (SDPATestCommon ):
@@ -438,7 +487,9 @@ def test_sdpa_with_cache_seq_len_130(self):
438
487
head_dim = 128
439
488
max_seq_len = 2048
440
489
seq_len = 130
441
- self ._test_sdpa_common (n_heads_kv , n_heads_q , head_dim , max_seq_len , seq_len )
490
+ self ._test_sdpa_common (
491
+ n_heads_kv , n_heads_q , head_dim , max_seq_len , seq_len , True
492
+ )
442
493
443
494
def test_sdpa_with_cache_seq_len_small (self ):
444
495
n_heads_kv = 4
@@ -462,7 +513,9 @@ def test_sdpa_with_cache_seq_len_130_gqa(self):
462
513
head_dim = 128
463
514
max_seq_len = 2048
464
515
seq_len = 130
465
- self ._test_sdpa_common (n_heads_kv , n_heads_q , head_dim , max_seq_len , seq_len )
516
+ self ._test_sdpa_common (
517
+ n_heads_kv , n_heads_q , head_dim , max_seq_len , seq_len , True
518
+ )
466
519
467
520
def test_sdpa_with_cache_seq_len_llava_example_gqa (self ):
468
521
n_heads_kv = 16
@@ -483,7 +536,13 @@ def test_sdpa_with_cache_seq_len_130(self):
483
536
seq_len = 130
484
537
next_iter_seq_len = 17
485
538
self ._test_sdpa_common (
486
- n_heads_kv , n_heads_q , head_dim , max_seq_len , seq_len , next_iter_seq_len
539
+ n_heads_kv ,
540
+ n_heads_q ,
541
+ head_dim ,
542
+ max_seq_len ,
543
+ seq_len ,
544
+ next_iter_seq_len ,
545
+ True ,
487
546
)
488
547
489
548
def test_sdpa_with_cache_seq_len_llava_example (self ):
@@ -505,7 +564,13 @@ def test_sdpa_with_cache_seq_len_130_gqa(self):
505
564
seq_len = 130
506
565
next_iter_seq_len = 33
507
566
self ._test_sdpa_common (
508
- n_heads_kv , n_heads_q , head_dim , max_seq_len , seq_len , next_iter_seq_len
567
+ n_heads_kv ,
568
+ n_heads_q ,
569
+ head_dim ,
570
+ max_seq_len ,
571
+ seq_len ,
572
+ next_iter_seq_len ,
573
+ True ,
509
574
)
510
575
511
576
def test_sdpa_with_cache_seq_len_llava_example_gqa (self ):
0 commit comments