Skip to content

Commit 47ac24b

Browse files
authored
Changing sdpa_with_kv_cache tests to use a wider dynamic range.
Differential Revision: D61403179 Pull Request resolved: #4892
1 parent 0a8547a commit 47ac24b

File tree

1 file changed

+78
-13
lines changed

1 file changed

+78
-13
lines changed

extension/llm/custom_ops/test_sdpa_with_kv_cache.py

Lines changed: 78 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -392,17 +392,50 @@ def setUp(self):
392392
self.max_seq_len = 2048
393393
self.setup_caches()
394394

395+
def _scale_tensor(self, tensor, min_value, max_value, scale=True):
396+
normalized_tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())
397+
398+
scaled_tensor = normalized_tensor * (max_value - min_value) + min_value
399+
400+
return scaled_tensor if scale else tensor
401+
395402
def _test_sdpa_common(
396-
self, n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, next_iter_seq_len=1
403+
self,
404+
n_heads_kv,
405+
n_heads_q,
406+
head_dim,
407+
max_seq_len,
408+
seq_len,
409+
next_iter_seq_len=1,
410+
scale_tensors=False,
397411
):
412+
# Range arbitrarily chosen to reproduce a numerical error on x86 in some of the long context tests
413+
tensor_scale_max = 20
414+
tensor_scale_min = -20
398415
self.n_heads_kv = n_heads_kv
399416
self.n_heads_q = n_heads_q
400417
self.head_dim = head_dim
401418
self.max_seq_len = max_seq_len
402419
self.setup_caches()
403-
q = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
404-
k = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
405-
v = torch.rand((1, seq_len, self.n_heads_kv, self.head_dim))
420+
q = self._scale_tensor(
421+
torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
422+
tensor_scale_max,
423+
tensor_scale_min,
424+
scale_tensors,
425+
)
426+
k = self._scale_tensor(
427+
torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
428+
tensor_scale_max,
429+
tensor_scale_min,
430+
scale_tensors,
431+
)
432+
v = self._scale_tensor(
433+
torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
434+
tensor_scale_max,
435+
tensor_scale_min,
436+
scale_tensors,
437+
)
438+
406439
start_pos = 0
407440
attn_mask = self.mask[start_pos : start_pos + seq_len, :]
408441
attn_mask = attn_mask[:, : start_pos + seq_len]
@@ -412,11 +445,27 @@ def _test_sdpa_common(
412445
op_output = torch.ops.llama.sdpa_with_kv_cache(
413446
q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
414447
)
415-
self.assertTrue(torch.allclose(ref_output, op_output))
448+
self.assertTrue(torch.allclose(ref_output, op_output, atol=1e-6))
449+
450+
q = self._scale_tensor(
451+
torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim)),
452+
tensor_scale_max,
453+
tensor_scale_min,
454+
scale_tensors,
455+
)
456+
k = self._scale_tensor(
457+
torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim)),
458+
tensor_scale_max,
459+
tensor_scale_min,
460+
scale_tensors,
461+
)
462+
v = self._scale_tensor(
463+
torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim)),
464+
tensor_scale_max,
465+
tensor_scale_min,
466+
scale_tensors,
467+
)
416468

417-
q = torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim))
418-
k = torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim))
419-
v = torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim))
420469
start_pos = seq_len
421470
seq_len = q.size(1)
422471
attn_mask = self.mask[start_pos : start_pos + seq_len, :]
@@ -427,7 +476,7 @@ def _test_sdpa_common(
427476
op_output = torch.ops.llama.sdpa_with_kv_cache(
428477
q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
429478
)
430-
self.assertTrue(torch.allclose(ref_output, op_output))
479+
self.assertTrue(torch.allclose(ref_output, op_output, atol=1e-6))
431480

432481

433482
class SDPATestForLargeSeqLength(SDPATestCommon):
@@ -438,7 +487,9 @@ def test_sdpa_with_cache_seq_len_130(self):
438487
head_dim = 128
439488
max_seq_len = 2048
440489
seq_len = 130
441-
self._test_sdpa_common(n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len)
490+
self._test_sdpa_common(
491+
n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, True
492+
)
442493

443494
def test_sdpa_with_cache_seq_len_small(self):
444495
n_heads_kv = 4
@@ -462,7 +513,9 @@ def test_sdpa_with_cache_seq_len_130_gqa(self):
462513
head_dim = 128
463514
max_seq_len = 2048
464515
seq_len = 130
465-
self._test_sdpa_common(n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len)
516+
self._test_sdpa_common(
517+
n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, True
518+
)
466519

467520
def test_sdpa_with_cache_seq_len_llava_example_gqa(self):
468521
n_heads_kv = 16
@@ -483,7 +536,13 @@ def test_sdpa_with_cache_seq_len_130(self):
483536
seq_len = 130
484537
next_iter_seq_len = 17
485538
self._test_sdpa_common(
486-
n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, next_iter_seq_len
539+
n_heads_kv,
540+
n_heads_q,
541+
head_dim,
542+
max_seq_len,
543+
seq_len,
544+
next_iter_seq_len,
545+
True,
487546
)
488547

489548
def test_sdpa_with_cache_seq_len_llava_example(self):
@@ -505,7 +564,13 @@ def test_sdpa_with_cache_seq_len_130_gqa(self):
505564
seq_len = 130
506565
next_iter_seq_len = 33
507566
self._test_sdpa_common(
508-
n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, next_iter_seq_len
567+
n_heads_kv,
568+
n_heads_q,
569+
head_dim,
570+
max_seq_len,
571+
seq_len,
572+
next_iter_seq_len,
573+
True,
509574
)
510575

511576
def test_sdpa_with_cache_seq_len_llava_example_gqa(self):

0 commit comments

Comments
 (0)