@@ -23,7 +23,7 @@ def forward(self, query, key, value):
23
23
key = torch .rand (key_shape , dtype = torch .float16 )
24
24
value = torch .rand (key_shape , dtype = torch .float16 )
25
25
inputs .extend ([query , key , value ])
26
- self .run_test (SDPA (), inputs , precision = torch .float16 )
26
+ self .run_test (SDPA (), inputs , rtol = 1e-2 , atol = 1e-2 , precision = torch .float16 )
27
27
28
28
@parameterized .expand ([((32 , 8 , 128 , 64 ), (32 , 8 , 128 , 64 ))])
29
29
def test_sdpa_causal (self , query_shape , key_shape ):
@@ -38,7 +38,7 @@ def forward(self, query, key, value):
38
38
key = torch .rand (key_shape , dtype = torch .float16 )
39
39
value = torch .rand (key_shape , dtype = torch .float16 )
40
40
inputs .extend ([query , key , value ])
41
- self .run_test (SDPA (), inputs , precision = torch .float16 )
41
+ self .run_test (SDPA (), inputs , rtol = 1e-2 , atol = 1e-2 , precision = torch .float16 )
42
42
43
43
44
44
@unittest .skipIf (
@@ -69,6 +69,8 @@ def forward(self, query, key, value):
69
69
self .run_test (
70
70
SDPA (),
71
71
inputs ,
72
+ rtol = 1e-2 ,
73
+ atol = 1e-2 ,
72
74
precision = torch .float16 ,
73
75
enable_passes = True ,
74
76
)
@@ -99,6 +101,8 @@ def forward(self, query, key, value):
99
101
self .run_test (
100
102
SDPA (),
101
103
inputs ,
104
+ rtol = 1e-2 ,
105
+ atol = 1e-2 ,
102
106
precision = torch .float16 ,
103
107
enable_passes = True ,
104
108
)
0 commit comments