|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
1 | 7 | import unittest
|
2 | 8 |
|
3 | 9 | import torch
|
@@ -57,23 +63,25 @@ def test_simple(self, is_dynamic_shape=False):
|
57 | 63 | self.quantized_sdpa = SDPACustom(self.quantized_kv_cache, self.dim)
|
58 | 64 | float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
|
59 | 65 | quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
|
60 |
| - self.assertTrue( |
61 |
| - torch.allclose( |
62 |
| - float_out, |
63 |
| - quantized_out, |
64 |
| - ) |
| 66 | + torch.testing.assert_close( |
| 67 | + float_out, |
| 68 | + quantized_out, |
| 69 | + # Had to relax rtol: switching to custom_sdpa means we now use the |
| 70 | + # dequantized k and v instead of the original k and v, which |
| 71 | + # leads to larger differences in the output. |
| 72 | + # A subsequent diff in the stack will address this issue. |
| 73 | + rtol=1e-01, |
| 74 | + atol=1e-03, |
65 | 75 | )
|
66 | 76 |
|
67 | 77 | input_pos = torch.tensor([3], dtype=torch.int64)
|
68 | 78 | self.seq_len = 1
|
69 | 79 | q, k, v = self._init_kv()
|
70 | 80 | float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
|
71 | 81 | quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
|
72 |
| - self.assertTrue( |
73 |
| - torch.allclose( |
74 |
| - float_out, |
75 |
| - quantized_out, |
76 |
| - rtol=1e-03, |
77 |
| - atol=1e-03, |
78 |
| - ) |
| 82 | + torch.testing.assert_close( |
| 83 | + float_out, |
| 84 | + quantized_out, |
| 85 | + rtol=1e-03, |
| 86 | + atol=1e-03, |
79 | 87 | )
|
0 commit comments