31
31
#include < pybind11/stl.h>
32
32
#include < vector>
33
33
34
+ #include " elementwise_functions_type_utils.hpp"
34
35
#include " simplify_iteration_space.hpp"
35
36
#include " utils/memory_overlap.hpp"
36
37
#include " utils/offset_utils.hpp"
@@ -46,56 +47,7 @@ namespace tensor
46
47
namespace py_internal
47
48
{
48
49
49
- namespace
50
- {
51
- inline py::dtype _dtype_from_typenum (td_ns::typenum_t dst_typenum_t )
52
- {
53
- switch (dst_typenum_t ) {
54
- case td_ns::typenum_t ::BOOL:
55
- return py::dtype (" ?" );
56
- case td_ns::typenum_t ::INT8:
57
- return py::dtype (" i1" );
58
- case td_ns::typenum_t ::UINT8:
59
- return py::dtype (" u1" );
60
- case td_ns::typenum_t ::INT16:
61
- return py::dtype (" i2" );
62
- case td_ns::typenum_t ::UINT16:
63
- return py::dtype (" u2" );
64
- case td_ns::typenum_t ::INT32:
65
- return py::dtype (" i4" );
66
- case td_ns::typenum_t ::UINT32:
67
- return py::dtype (" u4" );
68
- case td_ns::typenum_t ::INT64:
69
- return py::dtype (" i8" );
70
- case td_ns::typenum_t ::UINT64:
71
- return py::dtype (" u8" );
72
- case td_ns::typenum_t ::HALF:
73
- return py::dtype (" f2" );
74
- case td_ns::typenum_t ::FLOAT:
75
- return py::dtype (" f4" );
76
- case td_ns::typenum_t ::DOUBLE:
77
- return py::dtype (" f8" );
78
- case td_ns::typenum_t ::CFLOAT:
79
- return py::dtype (" c8" );
80
- case td_ns::typenum_t ::CDOUBLE:
81
- return py::dtype (" c16" );
82
- default :
83
- throw py::value_error (" Unrecognized dst_typeid" );
84
- }
85
- }
86
-
87
- inline int _result_typeid (int arg_typeid, const int *fn_output_id)
88
- {
89
- if (arg_typeid < 0 || arg_typeid >= td_ns::num_types) {
90
- throw py::value_error (" Input typeid " + std::to_string (arg_typeid) +
91
- " is outside of expected bounds." );
92
- }
93
-
94
- return fn_output_id[arg_typeid];
95
- }
96
-
97
- } // end of anonymous namespace
98
-
50
+ /* ! @brief Template implementing Python API for unary elementwise functions */
99
51
template <typename output_typesT,
100
52
typename contig_dispatchT,
101
53
typename strided_dispatchT>
@@ -297,6 +249,8 @@ py_unary_ufunc(const dpctl::tensor::usm_ndarray &src,
297
249
strided_fn_ev);
298
250
}
299
251
252
+ /* ! @brief Template implementing Python API for querying of type support by
253
+ * unary elementwise functions */
300
254
template <typename output_typesT>
301
255
py::object py_unary_ufunc_result_type (const py::dtype &input_dtype,
302
256
const output_typesT &output_types)
@@ -312,15 +266,17 @@ py::object py_unary_ufunc_result_type(const py::dtype &input_dtype,
312
266
throw py::value_error (e.what ());
313
267
}
314
268
269
+ using dpctl::tensor::py_internal::type_utils::_result_typeid;
315
270
int dst_typeid = _result_typeid (src_typeid, output_types);
316
271
317
272
if (dst_typeid < 0 ) {
318
273
auto res = py::none ();
319
274
return py::cast<py::object>(res);
320
275
}
321
276
else {
322
- auto dst_typenum_t = static_cast <td_ns:: typenum_t >(dst_typeid) ;
277
+ using dpctl::tensor::py_internal::type_utils::_dtype_from_typenum ;
323
278
279
+ auto dst_typenum_t = static_cast <td_ns::typenum_t >(dst_typeid);
324
280
auto dt = _dtype_from_typenum (dst_typenum_t );
325
281
326
282
return py::cast<py::object>(dt);
@@ -338,6 +294,8 @@ bool isEqual(Container const &c, std::initializer_list<T> const &l)
338
294
}
339
295
} // namespace
340
296
297
+ /* ! @brief Template implementing Python API for binary elementwise
298
+ * functions */
341
299
template <typename output_typesT,
342
300
typename contig_dispatchT,
343
301
typename strided_dispatchT,
@@ -605,6 +563,7 @@ std::pair<sycl::event, sycl::event> py_binary_ufunc(
605
563
strided_fn_ev);
606
564
}
607
565
566
+ /* ! @brief Type querying for binary elementwise functions */
608
567
template <typename output_typesT>
609
568
py::object py_binary_ufunc_result_type (const py::dtype &input1_dtype,
610
569
const py::dtype &input2_dtype,
@@ -636,8 +595,9 @@ py::object py_binary_ufunc_result_type(const py::dtype &input1_dtype,
636
595
return py::cast<py::object>(res);
637
596
}
638
597
else {
639
- auto dst_typenum_t = static_cast <td_ns:: typenum_t >(dst_typeid) ;
598
+ using dpctl::tensor::py_internal::type_utils::_dtype_from_typenum ;
640
599
600
+ auto dst_typenum_t = static_cast <td_ns::typenum_t >(dst_typeid);
641
601
auto dt = _dtype_from_typenum (dst_typenum_t );
642
602
643
603
return py::cast<py::object>(dt);
0 commit comments