|
26 | 26 | #include <algorithm>
|
27 | 27 | #include <complex>
|
28 | 28 | #include <cstdint>
|
29 |
| -#include <pybind11/complex.h> |
30 | 29 | #include <pybind11/pybind11.h>
|
31 | 30 | #include <pybind11/stl.h>
|
32 | 31 | #include <thread>
|
|
37 | 36 | #include "copy_and_cast_usm_to_usm.hpp"
|
38 | 37 | #include "copy_for_reshape.hpp"
|
39 | 38 | #include "copy_numpy_ndarray_into_usm_ndarray.hpp"
|
| 39 | +#include "device_support_queries.hpp" |
40 | 40 | #include "eye_ctor.hpp"
|
41 | 41 | #include "full_ctor.hpp"
|
42 | 42 | #include "linear_sequences.hpp"
|
@@ -102,36 +102,6 @@ void init_dispatch_vectors(void)
|
102 | 102 | return;
|
103 | 103 | }
|
104 | 104 |
|
105 |
| -std::string get_default_device_fp_type(sycl::device d) |
106 |
| -{ |
107 |
| - if (d.has(sycl::aspect::fp64)) { |
108 |
| - return "f8"; |
109 |
| - } |
110 |
| - else { |
111 |
| - return "f4"; |
112 |
| - } |
113 |
| -} |
114 |
| - |
115 |
| -std::string get_default_device_int_type(sycl::device) |
116 |
| -{ |
117 |
| - return "i8"; |
118 |
| -} |
119 |
| - |
120 |
| -std::string get_default_device_complex_type(sycl::device d) |
121 |
| -{ |
122 |
| - if (d.has(sycl::aspect::fp64)) { |
123 |
| - return "c16"; |
124 |
| - } |
125 |
| - else { |
126 |
| - return "c8"; |
127 |
| - } |
128 |
| -} |
129 |
| - |
130 |
| -std::string get_default_device_bool_type(sycl::device) |
131 |
| -{ |
132 |
| - return "b1"; |
133 |
| -} |
134 |
| - |
135 | 105 | } // namespace
|
136 | 106 |
|
137 | 107 | PYBIND11_MODULE(_tensor_impl, m)
|
@@ -209,57 +179,43 @@ PYBIND11_MODULE(_tensor_impl, m)
|
209 | 179 | py::arg("k"), py::arg("dst"), py::arg("sycl_queue"),
|
210 | 180 | py::arg("depends") = py::list());
|
211 | 181 |
|
212 |
| - m.def("default_device_fp_type", [](sycl::queue q) -> std::string { |
213 |
| - return get_default_device_fp_type(q.get_device()); |
214 |
| - }); |
215 |
| - m.def("default_device_fp_type_device", [](sycl::device dev) -> std::string { |
216 |
| - return get_default_device_fp_type(dev); |
217 |
| - }); |
218 |
| - |
219 |
| - m.def("default_device_int_type", [](sycl::queue q) -> std::string { |
220 |
| - return get_default_device_int_type(q.get_device()); |
221 |
| - }); |
222 |
| - m.def("default_device_int_type_device", |
223 |
| - [](sycl::device dev) -> std::string { |
224 |
| - return get_default_device_int_type(dev); |
225 |
| - }); |
226 |
| - |
227 |
| - m.def("default_device_bool_type", [](sycl::queue q) -> std::string { |
228 |
| - return get_default_device_bool_type(q.get_device()); |
229 |
| - }); |
230 |
| - m.def("default_device_bool_type_device", |
231 |
| - [](sycl::device dev) -> std::string { |
232 |
| - return get_default_device_bool_type(dev); |
233 |
| - }); |
234 |
| - |
235 |
| - m.def("default_device_complex_type", [](sycl::queue q) -> std::string { |
236 |
| - return get_default_device_complex_type(q.get_device()); |
237 |
| - }); |
238 |
| - m.def("default_device_complex_type_device", |
239 |
| - [](sycl::device dev) -> std::string { |
240 |
| - return get_default_device_complex_type(dev); |
241 |
| - }); |
242 |
| - m.def( |
243 |
| - "_tril", |
244 |
| - [](dpctl::tensor::usm_ndarray src, dpctl::tensor::usm_ndarray dst, |
245 |
| - py::ssize_t k, sycl::queue exec_q, |
246 |
| - const std::vector<sycl::event> depends) |
247 |
| - -> std::pair<sycl::event, sycl::event> { |
248 |
| - return usm_ndarray_triul(exec_q, src, dst, 'l', k, depends); |
249 |
| - }, |
250 |
| - "Tril helper function.", py::arg("src"), py::arg("dst"), |
251 |
| - py::arg("k") = 0, py::arg("sycl_queue"), |
252 |
| - py::arg("depends") = py::list()); |
| 182 | + m.def("default_device_fp_type", |
| 183 | + dpctl::tensor::py_internal::default_device_fp_type, |
| 184 | + "Gives default floating point type supported by device.", |
| 185 | + py::arg("dev")); |
| 186 | + |
| 187 | + m.def("default_device_int_type", |
| 188 | + dpctl::tensor::py_internal::default_device_int_type, |
| 189 | + "Gives default integer type supported by device.", py::arg("dev")); |
| 190 | + |
| 191 | + m.def("default_device_bool_type", |
| 192 | + dpctl::tensor::py_internal::default_device_bool_type, |
| 193 | + "Gives default boolean type supported by device.", py::arg("dev")); |
| 194 | + |
| 195 | + m.def("default_device_complex_type", |
| 196 | + dpctl::tensor::py_internal::default_device_complex_type, |
| 197 | + "Gives default complex floating point type support by device.", |
| 198 | + py::arg("dev")); |
| 199 | + |
| 200 | + auto tril_fn = [](dpctl::tensor::usm_ndarray src, |
| 201 | + dpctl::tensor::usm_ndarray dst, py::ssize_t k, |
| 202 | + sycl::queue exec_q, |
| 203 | + const std::vector<sycl::event> depends) |
| 204 | + -> std::pair<sycl::event, sycl::event> { |
| 205 | + return usm_ndarray_triul(exec_q, src, dst, 'l', k, depends); |
| 206 | + }; |
| 207 | + m.def("_tril", tril_fn, "Tril helper function.", py::arg("src"), |
| 208 | + py::arg("dst"), py::arg("k") = 0, py::arg("sycl_queue"), |
| 209 | + py::arg("depends") = py::list()); |
253 | 210 |
|
254 |
| - m.def( |
255 |
| - "_triu", |
256 |
| - [](dpctl::tensor::usm_ndarray src, dpctl::tensor::usm_ndarray dst, |
257 |
| - py::ssize_t k, sycl::queue exec_q, |
258 |
| - const std::vector<sycl::event> depends) |
259 |
| - -> std::pair<sycl::event, sycl::event> { |
260 |
| - return usm_ndarray_triul(exec_q, src, dst, 'u', k, depends); |
261 |
| - }, |
262 |
| - "Triu helper function.", py::arg("src"), py::arg("dst"), |
263 |
| - py::arg("k") = 0, py::arg("sycl_queue"), |
264 |
| - py::arg("depends") = py::list()); |
| 211 | + auto triu_fn = [](dpctl::tensor::usm_ndarray src, |
| 212 | + dpctl::tensor::usm_ndarray dst, py::ssize_t k, |
| 213 | + sycl::queue exec_q, |
| 214 | + const std::vector<sycl::event> depends) |
| 215 | + -> std::pair<sycl::event, sycl::event> { |
| 216 | + return usm_ndarray_triul(exec_q, src, dst, 'u', k, depends); |
| 217 | + }; |
| 218 | + m.def("_triu", triu_fn, "Triu helper function.", py::arg("src"), |
| 219 | + py::arg("dst"), py::arg("k") = 0, py::arg("sycl_queue"), |
| 220 | + py::arg("depends") = py::list()); |
265 | 221 | }
|
0 commit comments