Skip to content

Commit e70891b

Browse files
Adjusted test_nonzero_dtype to use default index type as reference
1 parent bf22a28 commit e70891b

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import dpctl
2424
import dpctl.tensor as dpt
25+
import dpctl.tensor._tensor_impl as ti
2526
from dpctl.utils import ExecutionPlacementError
2627

2728
_all_dtypes = [
@@ -1353,7 +1354,7 @@ def test_nonzero_dtype():
13531354
x = dpt.ones((3, 4))
13541355
idx, idy = dpt.nonzero(x)
13551356
# create array using device's
1356-
# default integral data type
1357-
ref = dpt.arange(8)
1358-
assert idx.dtype == ref.dtype
1359-
assert idy.dtype == ref.dtype
1357+
# default index data type
1358+
index_dt = dpt.dtype(ti.default_device_index_type(x.sycl_queue))
1359+
assert idx.dtype == index_dt
1360+
assert idy.dtype == index_dt

0 commit comments

Comments
 (0)