Skip to content

Commit f4ef5ae

Browse files
committed
Fix get_scalar_type() for complex dtypes
1 parent 2c8876a commit f4ef5ae

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,9 @@ def is_float_dtype(dtype):
171171

172172

173173
def get_scalar_type(dtype: DataType) -> ScalarType:
174-
if is_int_dtype(dtype):
174+
if dtype in all_int_dtypes:
175175
return int
176-
elif is_float_dtype(dtype):
176+
elif dtype in float_dtypes:
177177
return float
178178
elif dtype in complex_dtypes:
179179
return complex

0 commit comments

Comments
 (0)