Skip to content

Commit 6f0969c

Browse files
Merge pull request #1336 from IntelPython/fix-nonzero-result-dtype-win
Fix nonzero result dtype win
2 parents a8d97f3 + e70891b commit 6f0969c

File tree

5 files changed

+23
-6
lines changed

5 files changed

+23
-6
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -586,7 +586,7 @@ def _nonzero_impl(ary):
586586
mask_nelems, dtype=cumsum_dt, sycl_queue=exec_q, order="C"
587587
)
588588
mask_count = ti.mask_positions(ary, cumsum, sycl_queue=exec_q)
589-
indexes_dt = ti.default_device_int_type(exec_q.sycl_device)
589+
indexes_dt = ti.default_device_index_type(exec_q.sycl_device)
590590
indexes = dpt.empty(
591591
(ary.ndim, mask_count),
592592
dtype=indexes_dt,

dpctl/tensor/libtensor/source/device_support_queries.cpp

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -71,6 +71,11 @@ std::string _default_device_bool_type(sycl::device)
7171
return "b1";
7272
}
7373

74+
// Returns the type string "i8" as the default index type. The sycl::device
// parameter is unnamed and is not consulted, so the result is the same for
// every device.
std::string _default_device_index_type(sycl::device)
75+
{
76+
return "i8";
77+
}
78+
7479
sycl::device _extract_device(py::object arg)
7580
{
7681
auto const &api = dpctl::detail::dpctl_capi::get();
@@ -115,6 +120,12 @@ std::string default_device_complex_type(py::object arg)
115120
return _default_device_complex_type(d);
116121
}
117122

123+
// Converts `arg` to a sycl::device via _extract_device and returns the
// default index type string reported by _default_device_index_type for it.
std::string default_device_index_type(py::object arg)
124+
{
125+
sycl::device d = _extract_device(arg);
126+
return _default_device_index_type(d);
127+
}
128+
118129
} // namespace py_internal
119130
} // namespace tensor
120131
} // namespace dpctl

dpctl/tensor/libtensor/source/device_support_queries.hpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,7 @@ extern std::string default_device_fp_type(py::object);
4141
extern std::string default_device_int_type(py::object);
4242
extern std::string default_device_bool_type(py::object);
4343
extern std::string default_device_complex_type(py::object);
44+
extern std::string default_device_index_type(py::object);
4445

4546
} // namespace py_internal
4647
} // namespace tensor

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -297,9 +297,13 @@ PYBIND11_MODULE(_tensor_impl, m)
297297

298298
m.def("default_device_complex_type",
299299
dpctl::tensor::py_internal::default_device_complex_type,
300-
"Gives default complex floating point type support by device.",
300+
"Gives default complex floating point type supported by device.",
301301
py::arg("dev"));
302302

303+
m.def("default_device_index_type",
304+
dpctl::tensor::py_internal::default_device_index_type,
305+
"Gives default index type supported by device.", py::arg("dev"));
306+
303307
auto tril_fn = [](dpctl::tensor::usm_ndarray src,
304308
dpctl::tensor::usm_ndarray dst, py::ssize_t k,
305309
sycl::queue exec_q,

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222

2323
import dpctl
2424
import dpctl.tensor as dpt
25+
import dpctl.tensor._tensor_impl as ti
2526
from dpctl.utils import ExecutionPlacementError
2627

2728
_all_dtypes = [
@@ -1353,7 +1354,7 @@ def test_nonzero_dtype():
13531354
x = dpt.ones((3, 4))
13541355
idx, idy = dpt.nonzero(x)
13551356
# create array using device's
1356-
# default integral data type
1357-
ref = dpt.arange(8)
1358-
assert idx.dtype == ref.dtype
1359-
assert idy.dtype == ref.dtype
1357+
# default index data type
1358+
index_dt = dpt.dtype(ti.default_device_index_type(x.sycl_queue))
1359+
assert idx.dtype == index_dt
1360+
assert idy.dtype == index_dt

0 commit comments

Comments
 (0)