@@ -186,3 +186,47 @@ def test_binary_func_arg_validation():
186
186
with pytest .raises (ValueError ):
187
187
dpt .add (a , Ellipsis )
188
188
dpt .add (a , a , order = "invalid" )
189
+
190
+
191
+ def test_all_data_types ():
192
+ fp16_fp64_types = set ([dpt .float16 , dpt .float64 , dpt .complex128 ])
193
+ fp64_types = set ([dpt .float64 , dpt .complex128 ])
194
+
195
+ all_dts = tu ._all_data_types (True , True )
196
+ assert fp16_fp64_types .issubset (all_dts )
197
+
198
+ all_dts = tu ._all_data_types (True , False )
199
+ assert dpt .float16 in all_dts
200
+ assert not fp64_types .issubset (all_dts )
201
+
202
+ all_dts = tu ._all_data_types (False , True )
203
+ assert dpt .float16 not in all_dts
204
+ assert fp64_types .issubset (all_dts )
205
+
206
+ all_dts = tu ._all_data_types (False , False )
207
+ assert not fp16_fp64_types .issubset (all_dts )
208
+
209
+
210
+ @pytest .mark .parametrize ("fp16" , [True , False ])
211
+ @pytest .mark .parametrize ("fp64" , [True , False ])
212
+ def test_maximal_inexact_types (fp16 , fp64 ):
213
+ assert not tu ._is_maximal_inexact_type (dpt .int32 , fp16 , fp64 )
214
+ assert fp64 == tu ._is_maximal_inexact_type (dpt .float64 , fp16 , fp64 )
215
+ assert fp64 == tu ._is_maximal_inexact_type (dpt .complex128 , fp16 , fp64 )
216
+ assert fp64 != tu ._is_maximal_inexact_type (dpt .float32 , fp16 , fp64 )
217
+ assert fp64 != tu ._is_maximal_inexact_type (dpt .complex64 , fp16 , fp64 )
218
+
219
+
220
+ def test_can_cast_device ():
221
+ assert tu ._can_cast (dpt .int64 , dpt .float64 , True , True )
222
+ # if f8 is available, can't cast i8 to f4
223
+ assert not tu ._can_cast (dpt .int64 , dpt .float32 , True , True )
224
+ assert not tu ._can_cast (dpt .int64 , dpt .float32 , False , True )
225
+ # should be able to cast to f8 when f2 unavailable
226
+ assert tu ._can_cast (dpt .int64 , dpt .float64 , False , True )
227
+ # casting to f4 acceptable when f8 unavailable
228
+ assert tu ._can_cast (dpt .int64 , dpt .float32 , True , False )
229
+ assert tu ._can_cast (dpt .int64 , dpt .float32 , False , False )
230
+ # can't safely cast inexact type to inexact type of lesser precision
231
+ assert not tu ._can_cast (dpt .float32 , dpt .float16 , True , False )
232
+ assert not tu ._can_cast (dpt .float64 , dpt .float32 , False , True )
0 commit comments