Skip to content

Commit bbf7b76

Browse files
committed
Update base for Update on "Don't quantize the current token for attention"
Differential Revision: [D63497872](https://our.internmc.facebook.com/intern/diff/D63497872/) [ghstack-poisoned]
1 parent b690a36 commit bbf7b76

File tree

1 file changed

+20
-12
lines changed

1 file changed

+20
-12
lines changed

examples/models/llama2/source_transformation/test_sdpa_with_quantized_kv_cache.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
17
import unittest
28

39
import torch
@@ -57,23 +63,25 @@ def test_simple(self, is_dynamic_shape=False):
5763
self.quantized_sdpa = SDPACustom(self.quantized_kv_cache, self.dim)
5864
float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
5965
quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
60-
self.assertTrue(
61-
torch.allclose(
62-
float_out,
63-
quantized_out,
64-
)
66+
torch.testing.assert_close(
67+
float_out,
68+
quantized_out,
69+
# had to adjust rtol because switching to using custom_sdpa means we
70+
# will use dequantized k and v instead of original k and v
71+
# this leads to larger differences in the output.
72+
# subsequent diff in the stack will address this issue.
73+
rtol=1e-01,
74+
atol=1e-03,
6575
)
6676

6777
input_pos = torch.tensor([3], dtype=torch.int64)
6878
self.seq_len = 1
6979
q, k, v = self._init_kv()
7080
float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
7181
quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
72-
self.assertTrue(
73-
torch.allclose(
74-
float_out,
75-
quantized_out,
76-
rtol=1e-03,
77-
atol=1e-03,
78-
)
82+
torch.testing.assert_close(
83+
float_out,
84+
quantized_out,
85+
rtol=1e-03,
86+
atol=1e-03,
7987
)

0 commit comments

Comments (0)