Skip to content

Commit a81b7f7

Browse files
committed
Increase tolerance bound for FP16
1 parent c5cb551 commit a81b7f7

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

tests/py/dynamo/conversion/test_attention.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ def forward(self, query, key, value):
2323
key = torch.rand(key_shape, dtype=torch.float16)
2424
value = torch.rand(key_shape, dtype=torch.float16)
2525
inputs.extend([query, key, value])
26-
self.run_test(SDPA(), inputs, precision=torch.float16)
26+
self.run_test(SDPA(), inputs, rtol=1e-2, atol=1e-2, precision=torch.float16)
2727

2828
@parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))])
2929
def test_sdpa_causal(self, query_shape, key_shape):
@@ -38,7 +38,7 @@ def forward(self, query, key, value):
3838
key = torch.rand(key_shape, dtype=torch.float16)
3939
value = torch.rand(key_shape, dtype=torch.float16)
4040
inputs.extend([query, key, value])
41-
self.run_test(SDPA(), inputs, precision=torch.float16)
41+
self.run_test(SDPA(), inputs, rtol=1e-2, atol=1e-2, precision=torch.float16)
4242

4343

4444
@unittest.skipIf(
@@ -69,6 +69,8 @@ def forward(self, query, key, value):
6969
self.run_test(
7070
SDPA(),
7171
inputs,
72+
rtol=1e-2,
73+
atol=1e-2,
7274
precision=torch.float16,
7375
enable_passes=True,
7476
)
@@ -99,6 +101,8 @@ def forward(self, query, key, value):
99101
self.run_test(
100102
SDPA(),
101103
inputs,
104+
rtol=1e-2,
105+
atol=1e-2,
102106
precision=torch.float16,
103107
enable_passes=True,
104108
)

0 commit comments

Comments (0)