Skip to content

Commit b12583a

Browse files
authored
pybind11 type caster for sycl::half (#1655)
* Implements type caster for sycl::half and removes unboxing_helper.hpp `py::cast<sycl_half>` being available makes PythonObjectUnboxer redundant * Apply changes per review to sycl::half caster * Remove unnecessary check for py_err and unneeded PyErr_Clear call per PR feedback
1 parent aaf444e commit b12583a

File tree

4 files changed

+52
-81
lines changed

4 files changed

+52
-81
lines changed

dpctl/apis/include/dpctl4pybind11.hpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,53 @@ struct type_caster<sycl::kernel_bundle<sycl::bundle_state::executable>>
671671
DPCTL_TYPE_CASTER(sycl::kernel_bundle<sycl::bundle_state::executable>,
672672
_("dpctl.program.SyclProgram"));
673673
};
674+
675+
/* This type caster associates
676+
* ``sycl::half`` C++ class with Python :class:`float` for the purposes
677+
* of generation of Python bindings by pybind11.
678+
*/
679+
template <> struct type_caster<sycl::half>
680+
{
681+
public:
682+
bool load(handle src, bool convert)
683+
{
684+
double py_value;
685+
686+
if (!src) {
687+
return false;
688+
}
689+
690+
PyObject *source = src.ptr();
691+
692+
if (convert || PyFloat_Check(source)) {
693+
py_value = PyFloat_AsDouble(source);
694+
}
695+
else {
696+
return false;
697+
}
698+
699+
bool py_err = (py_value == double(-1)) && PyErr_Occurred();
700+
701+
if (py_err) {
702+
PyErr_Clear();
703+
if (convert && (PyNumber_Check(source) != 0)) {
704+
auto tmp = reinterpret_steal<object>(PyNumber_Float(source));
705+
return load(tmp, false);
706+
}
707+
return false;
708+
}
709+
value = static_cast<sycl::half>(py_value);
710+
return true;
711+
}
712+
713+
static handle cast(sycl::half src, return_value_policy, handle)
714+
{
715+
return PyFloat_FromDouble(static_cast<double>(src));
716+
}
717+
718+
PYBIND11_TYPE_CASTER(sycl::half, _("float"));
719+
};
720+
674721
} // namespace detail
675722
} // namespace pybind11
676723

dpctl/tensor/libtensor/source/full_ctor.cpp

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
#include "utils/type_utils.hpp"
3737

3838
#include "full_ctor.hpp"
39-
#include "unboxing_helper.hpp"
4039

4140
namespace py = pybind11;
4241
namespace td_ns = dpctl::tensor::type_dispatch;
@@ -79,14 +78,7 @@ sycl::event full_contig_impl(sycl::queue &exec_q,
7978
char *dst_p,
8079
const std::vector<sycl::event> &depends)
8180
{
82-
dstTy fill_v;
83-
84-
PythonObjectUnboxer<dstTy> unboxer{};
85-
try {
86-
fill_v = unboxer(py_value);
87-
} catch (const py::error_already_set &e) {
88-
throw;
89-
}
81+
dstTy fill_v = py::cast<dstTy>(py_value);
9082

9183
using dpctl::tensor::kernels::constructors::full_contig_impl;
9284

dpctl/tensor/libtensor/source/linear_sequences.cpp

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
#include "utils/type_utils.hpp"
3737

3838
#include "linear_sequences.hpp"
39-
#include "unboxing_helper.hpp"
4039

4140
namespace py = pybind11;
4241
namespace td_ns = dpctl::tensor::type_dispatch;
@@ -86,16 +85,8 @@ sycl::event lin_space_step_impl(sycl::queue &exec_q,
8685
char *array_data,
8786
const std::vector<sycl::event> &depends)
8887
{
89-
Ty start_v;
90-
Ty step_v;
91-
92-
const auto &unboxer = PythonObjectUnboxer<Ty>{};
93-
try {
94-
start_v = unboxer(start);
95-
step_v = unboxer(step);
96-
} catch (const py::error_already_set &e) {
97-
throw;
98-
}
88+
Ty start_v = py::cast<Ty>(start);
89+
Ty step_v = py::cast<Ty>(step);
9990

10091
using dpctl::tensor::kernels::constructors::lin_space_step_impl;
10192

@@ -143,14 +134,8 @@ sycl::event lin_space_affine_impl(sycl::queue &exec_q,
143134
char *array_data,
144135
const std::vector<sycl::event> &depends)
145136
{
146-
Ty start_v, end_v;
147-
const auto &unboxer = PythonObjectUnboxer<Ty>{};
148-
try {
149-
start_v = unboxer(start);
150-
end_v = unboxer(end);
151-
} catch (const py::error_already_set &e) {
152-
throw;
153-
}
137+
Ty start_v = py::cast<Ty>(start);
138+
Ty end_v = py::cast<Ty>(end);
154139

155140
using dpctl::tensor::kernels::constructors::lin_space_affine_impl;
156141

dpctl/tensor/libtensor/source/unboxing_helper.hpp

Lines changed: 0 additions & 53 deletions
This file was deleted.

0 commit comments

Comments
 (0)