Skip to content

Commit 95b345a

Browse files
Default int type queries behave based on NumPy version
Introduced default_device_uint_type, parallel to default_device_int_type. For NumPy >= 2 (as checked at runtime), it returns "i8" (since dpctl is only supported on x86_64) or unsigned "u8", while for NumPy < 2 it returns long ("l"), or 'unsigned long' ("L").
1 parent 7b64374 commit 95b345a

File tree

3 files changed

+54
-4
lines changed

3 files changed

+54
-4
lines changed

dpctl/tensor/libtensor/source/device_support_queries.cpp

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,48 @@ std::string _default_device_fp_type(const sycl::device &d)
4949
}
5050
}
5151

52+
int get_numpy_major_version()
53+
{
54+
namespace py = pybind11;
55+
56+
py::module_ numpy = py::module_::import("numpy");
57+
py::str version_string = numpy.attr("__version__");
58+
py::module_ numpy_lib = py::module_::import("numpy.lib");
59+
60+
py::object numpy_version = numpy_lib.attr("NumpyVersion")(version_string);
61+
int major_version = numpy_version.attr("major").cast<int>();
62+
63+
return major_version;
64+
}
65+
5266
std::string _default_device_int_type(const sycl::device &)
5367
{
54-
return "l"; // code for numpy.dtype('long') to be consistent
55-
// with NumPy's default integer type across
56-
// platforms.
68+
const int np_ver = get_numpy_major_version();
69+
70+
if (np_ver >= 2) {
71+
return "i8";
72+
}
73+
else {
74+
// code for numpy.dtype('long') to be consistent
75+
// with NumPy's default integer type across
76+
// platforms.
77+
return "l";
78+
}
79+
}
80+
81+
std::string _default_device_uint_type(const sycl::device &)
82+
{
83+
const int np_ver = get_numpy_major_version();
84+
85+
if (np_ver >= 2) {
86+
return "u8";
87+
}
88+
else {
89+
// code for numpy.dtype('long') to be consistent
90+
// with NumPy's default integer type across
91+
// platforms.
92+
return "L";
93+
}
5794
}
5895

5996
std::string _default_device_complex_type(const sycl::device &d)
@@ -108,6 +145,12 @@ std::string default_device_int_type(const py::object &arg)
108145
return _default_device_int_type(d);
109146
}
110147

148+
std::string default_device_uint_type(const py::object &arg)
149+
{
150+
const sycl::device &d = _extract_device(arg);
151+
return _default_device_uint_type(d);
152+
}
153+
111154
std::string default_device_bool_type(const py::object &arg)
112155
{
113156
const sycl::device &d = _extract_device(arg);

dpctl/tensor/libtensor/source/device_support_queries.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ namespace py_internal
3939

4040
extern std::string default_device_fp_type(const py::object &);
4141
extern std::string default_device_int_type(const py::object &);
42+
extern std::string default_device_uint_type(const py::object &);
4243
extern std::string default_device_bool_type(const py::object &);
4344
extern std::string default_device_complex_type(const py::object &);
4445
extern std::string default_device_index_type(const py::object &);

dpctl/tensor/libtensor/source/tensor_ctors.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,13 @@ PYBIND11_MODULE(_tensor_impl, m)
331331

332332
m.def("default_device_int_type",
333333
dpctl::tensor::py_internal::default_device_int_type,
334-
"Gives default integer type supported by device.", py::arg("dev"));
334+
"Gives default signed integer type supported by device.",
335+
py::arg("dev"));
336+
337+
m.def("default_device_uint_type",
338+
dpctl::tensor::py_internal::default_device_uint_type,
339+
"Gives default unsigned integer type supported by device.",
340+
py::arg("dev"));
335341

336342
m.def("default_device_bool_type",
337343
dpctl::tensor::py_internal::default_device_bool_type,

0 commit comments

Comments
 (0)