@@ -49,11 +49,48 @@ std::string _default_device_fp_type(const sycl::device &d)
49
49
}
50
50
}
51
51
52
+ int get_numpy_major_version ()
53
+ {
54
+ namespace py = pybind11;
55
+
56
+ py::module_ numpy = py::module_::import (" numpy" );
57
+ py::str version_string = numpy.attr (" __version__" );
58
+ py::module_ numpy_lib = py::module_::import (" numpy.lib" );
59
+
60
+ py::object numpy_version = numpy_lib.attr (" NumpyVersion" )(version_string);
61
+ int major_version = numpy_version.attr (" major" ).cast <int >();
62
+
63
+ return major_version;
64
+ }
65
+
52
66
std::string _default_device_int_type (const sycl::device &)
53
67
{
54
- return " l" ; // code for numpy.dtype('long') to be consistent
55
- // with NumPy's default integer type across
56
- // platforms.
68
+ const int np_ver = get_numpy_major_version ();
69
+
70
+ if (np_ver >= 2 ) {
71
+ return " i8" ;
72
+ }
73
+ else {
74
+ // code for numpy.dtype('long') to be consistent
75
+ // with NumPy's default integer type across
76
+ // platforms.
77
+ return " l" ;
78
+ }
79
+ }
80
+
81
+ std::string _default_device_uint_type (const sycl::device &)
82
+ {
83
+ const int np_ver = get_numpy_major_version ();
84
+
85
+ if (np_ver >= 2 ) {
86
+ return " u8" ;
87
+ }
88
+ else {
89
+ // code for numpy.dtype('long') to be consistent
90
+ // with NumPy's default integer type across
91
+ // platforms.
92
+ return " L" ;
93
+ }
57
94
}
58
95
59
96
std::string _default_device_complex_type (const sycl::device &d)
@@ -66,15 +103,9 @@ std::string _default_device_complex_type(const sycl::device &d)
66
103
}
67
104
}
68
105
69
- std::string _default_device_bool_type (const sycl::device &)
70
- {
71
- return " b1" ;
72
- }
106
+ std::string _default_device_bool_type (const sycl::device &) { return " b1" ; }
73
107
74
- std::string _default_device_index_type (const sycl::device &)
75
- {
76
- return " i8" ;
77
- }
108
+ std::string _default_device_index_type (const sycl::device &) { return " i8" ; }
78
109
79
110
sycl::device _extract_device (const py::object &arg)
80
111
{
@@ -108,6 +139,12 @@ std::string default_device_int_type(const py::object &arg)
108
139
return _default_device_int_type (d);
109
140
}
110
141
142
+ std::string default_device_uint_type (const py::object &arg)
143
+ {
144
+ const sycl::device &d = _extract_device (arg);
145
+ return _default_device_uint_type (d);
146
+ }
147
+
111
148
std::string default_device_bool_type (const py::object &arg)
112
149
{
113
150
const sycl::device &d = _extract_device (arg);
0 commit comments