Skip to content

Commit e28bd88

Browse files
Add default_device_index_type(queue_or_dev) utility
This returns default index type for give device. Since all devices are 64-bit devices, it always returns "i8".
1 parent ed93e02 commit e28bd88

File tree

3 files changed

+17
-1
lines changed

3 files changed

+17
-1
lines changed

dpctl/tensor/libtensor/source/device_support_queries.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ std::string _default_device_bool_type(sycl::device)
7171
return "b1";
7272
}
7373

74+
std::string _default_device_index_type(sycl::device)
75+
{
76+
return "i8";
77+
}
78+
7479
sycl::device _extract_device(py::object arg)
7580
{
7681
auto const &api = dpctl::detail::dpctl_capi::get();
@@ -115,6 +120,12 @@ std::string default_device_complex_type(py::object arg)
115120
return _default_device_complex_type(d);
116121
}
117122

123+
std::string default_device_index_type(py::object arg)
124+
{
125+
sycl::device d = _extract_device(arg);
126+
return _default_device_index_type(d);
127+
}
128+
118129
} // namespace py_internal
119130
} // namespace tensor
120131
} // namespace dpctl

dpctl/tensor/libtensor/source/device_support_queries.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ extern std::string default_device_fp_type(py::object);
4141
extern std::string default_device_int_type(py::object);
4242
extern std::string default_device_bool_type(py::object);
4343
extern std::string default_device_complex_type(py::object);
44+
extern std::string default_device_index_type(py::object);
4445

4546
} // namespace py_internal
4647
} // namespace tensor

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,9 +297,13 @@ PYBIND11_MODULE(_tensor_impl, m)
297297

298298
m.def("default_device_complex_type",
299299
dpctl::tensor::py_internal::default_device_complex_type,
300-
"Gives default complex floating point type support by device.",
300+
"Gives default complex floating point type supported by device.",
301301
py::arg("dev"));
302302

303+
m.def("default_device_index_type",
304+
dpctl::tensor::py_internal::default_device_index_type,
305+
"Gives default index type supported by device.", py::arg("dev"));
306+
303307
auto tril_fn = [](dpctl::tensor::usm_ndarray src,
304308
dpctl::tensor::usm_ndarray dst, py::ssize_t k,
305309
sycl::queue exec_q,

0 commit comments

Comments
 (0)