Skip to content

Commit b9fb316

Browse files
MaartenBaertpre-commit-ci[bot]rwgk
authored
Add support for array_t<handle> and array_t<object> (#5427)
* Add support for array_t<handle> and array_t<object> * style: pre-commit fixes * Remove loops that aren't strictly needed * Fix compiler warning * Disable GC-dependent checks when running on pypy or graalpy * style: pre-commit fixes * Remove PyValueHolder counter again * Move tests to templates to avoid code duplication * Rerun pre-commit * Restore import that was erroneously removed by pre-commit * Reduce code duplication with more template magic * Bring back `.attr("value")` in `return_array_cpp_loop()` This was meant to further stress-test correctness of refcount handling. All modified test functions were manually leak-checked (`while True`, top command, Python 3.12.3, Ubuntu 24.01, gcc 13.2.0). --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ralf W. Grosse-Kunstleve <[email protected]>
1 parent 08095d9 commit b9fb316

File tree

3 files changed

+105
-43
lines changed

3 files changed

+105
-43
lines changed

include/pybind11/numpy.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1428,7 +1428,11 @@ struct npy_format_descriptor<
14281428
};
14291429

14301430
template <typename T>
1431-
struct npy_format_descriptor<T, enable_if_t<is_same_ignoring_cvref<T, PyObject *>::value>> {
1431+
struct npy_format_descriptor<
1432+
T,
1433+
enable_if_t<is_same_ignoring_cvref<T, PyObject *>::value
1434+
|| ((std::is_same<T, handle>::value || std::is_same<T, object>::value)
1435+
&& sizeof(T) == sizeof(PyObject *))>> {
14321436
static constexpr auto name = const_name("object");
14331437

14341438
static constexpr int value = npy_api::NPY_OBJECT_;

tests/test_numpy_array.cpp

Lines changed: 66 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,55 @@ py::handle auxiliaries(T &&r, T2 &&r2) {
156156
return l.release();
157157
}
158158

159+
template <typename PyObjectType>
160+
PyObjectType convert_to_pyobjecttype(py::object obj);
161+
162+
template <>
163+
PyObject *convert_to_pyobjecttype<PyObject *>(py::object obj) {
164+
return obj.release().ptr();
165+
}
166+
167+
template <>
168+
py::handle convert_to_pyobjecttype<py::handle>(py::object obj) {
169+
return obj.release();
170+
}
171+
172+
template <>
173+
py::object convert_to_pyobjecttype<py::object>(py::object obj) {
174+
return obj;
175+
}
176+
177+
template <typename PyObjectType>
178+
std::string pass_array_return_sum_str_values(const py::array_t<PyObjectType> &objs) {
179+
std::string sum_str_values;
180+
for (const auto &obj : objs) {
181+
sum_str_values += py::str(obj.attr("value"));
182+
}
183+
return sum_str_values;
184+
}
185+
186+
template <typename PyObjectType>
187+
py::list pass_array_return_as_list(const py::array_t<PyObjectType> &objs) {
188+
return objs;
189+
}
190+
191+
template <typename PyObjectType>
192+
py::array_t<PyObjectType> return_array_cpp_loop(const py::list &objs) {
193+
py::size_t arr_size = py::len(objs);
194+
py::array_t<PyObjectType> arr_from_list(static_cast<py::ssize_t>(arr_size));
195+
PyObjectType *data = arr_from_list.mutable_data();
196+
for (py::size_t i = 0; i < arr_size; i++) {
197+
assert(!data[i]);
198+
data[i] = convert_to_pyobjecttype<PyObjectType>(objs[i].attr("value"));
199+
}
200+
return arr_from_list;
201+
}
202+
203+
template <typename PyObjectType>
204+
py::array_t<PyObjectType> return_array_from_list(const py::list &objs) {
205+
return objs;
206+
}
207+
159208
// note: declaration at local scope would create a dangling reference!
160209
static int data_i = 42;
161210

@@ -520,28 +569,21 @@ TEST_SUBMODULE(numpy_array, sm) {
520569
sm.def("round_trip_float", [](double d) { return d; });
521570

522571
sm.def("pass_array_pyobject_ptr_return_sum_str_values",
523-
[](const py::array_t<PyObject *> &objs) {
524-
std::string sum_str_values;
525-
for (const auto &obj : objs) {
526-
sum_str_values += py::str(obj.attr("value"));
527-
}
528-
return sum_str_values;
529-
});
530-
531-
sm.def("pass_array_pyobject_ptr_return_as_list",
532-
[](const py::array_t<PyObject *> &objs) -> py::list { return objs; });
533-
534-
sm.def("return_array_pyobject_ptr_cpp_loop", [](const py::list &objs) {
535-
py::size_t arr_size = py::len(objs);
536-
py::array_t<PyObject *> arr_from_list(static_cast<py::ssize_t>(arr_size));
537-
PyObject **data = arr_from_list.mutable_data();
538-
for (py::size_t i = 0; i < arr_size; i++) {
539-
assert(data[i] == nullptr);
540-
data[i] = py::cast<PyObject *>(objs[i].attr("value"));
541-
}
542-
return arr_from_list;
543-
});
544-
545-
sm.def("return_array_pyobject_ptr_from_list",
546-
[](const py::list &objs) -> py::array_t<PyObject *> { return objs; });
572+
pass_array_return_sum_str_values<PyObject *>);
573+
sm.def("pass_array_handle_return_sum_str_values",
574+
pass_array_return_sum_str_values<py::handle>);
575+
sm.def("pass_array_object_return_sum_str_values",
576+
pass_array_return_sum_str_values<py::object>);
577+
578+
sm.def("pass_array_pyobject_ptr_return_as_list", pass_array_return_as_list<PyObject *>);
579+
sm.def("pass_array_handle_return_as_list", pass_array_return_as_list<py::handle>);
580+
sm.def("pass_array_object_return_as_list", pass_array_return_as_list<py::object>);
581+
582+
sm.def("return_array_pyobject_ptr_cpp_loop", return_array_cpp_loop<PyObject *>);
583+
sm.def("return_array_handle_cpp_loop", return_array_cpp_loop<py::handle>);
584+
sm.def("return_array_object_cpp_loop", return_array_cpp_loop<py::object>);
585+
586+
sm.def("return_array_pyobject_ptr_from_list", return_array_from_list<PyObject *>);
587+
sm.def("return_array_handle_from_list", return_array_from_list<py::handle>);
588+
sm.def("return_array_object_from_list", return_array_from_list<py::object>);
547589
}

tests/test_numpy_array.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -629,45 +629,61 @@ def UnwrapPyValueHolder(vhs):
629629
return [vh.value for vh in vhs]
630630

631631

632-
def test_pass_array_pyobject_ptr_return_sum_str_values_ndarray():
632+
PASS_ARRAY_PYOBJECT_RETURN_SUM_STR_VALUES_FUNCTIONS = [
633+
m.pass_array_pyobject_ptr_return_sum_str_values,
634+
m.pass_array_handle_return_sum_str_values,
635+
m.pass_array_object_return_sum_str_values,
636+
]
637+
638+
639+
@pytest.mark.parametrize(
640+
"pass_array", PASS_ARRAY_PYOBJECT_RETURN_SUM_STR_VALUES_FUNCTIONS
641+
)
642+
def test_pass_array_object_return_sum_str_values_ndarray(pass_array):
633643
# Intentionally all temporaries, do not change.
634644
assert (
635-
m.pass_array_pyobject_ptr_return_sum_str_values(
636-
np.array(WrapWithPyValueHolder(-3, "four", 5.0), dtype=object)
637-
)
645+
pass_array(np.array(WrapWithPyValueHolder(-3, "four", 5.0), dtype=object))
638646
== "-3four5.0"
639647
)
640648

641649

642-
def test_pass_array_pyobject_ptr_return_sum_str_values_list():
650+
@pytest.mark.parametrize(
651+
"pass_array", PASS_ARRAY_PYOBJECT_RETURN_SUM_STR_VALUES_FUNCTIONS
652+
)
653+
def test_pass_array_object_return_sum_str_values_list(pass_array):
643654
# Intentionally all temporaries, do not change.
644-
assert (
645-
m.pass_array_pyobject_ptr_return_sum_str_values(
646-
WrapWithPyValueHolder(2, "three", -4.0)
647-
)
648-
== "2three-4.0"
649-
)
655+
assert pass_array(WrapWithPyValueHolder(2, "three", -4.0)) == "2three-4.0"
650656

651657

652-
def test_pass_array_pyobject_ptr_return_as_list():
658+
@pytest.mark.parametrize(
659+
"pass_array",
660+
[
661+
m.pass_array_pyobject_ptr_return_as_list,
662+
m.pass_array_handle_return_as_list,
663+
m.pass_array_object_return_as_list,
664+
],
665+
)
666+
def test_pass_array_object_return_as_list(pass_array):
653667
# Intentionally all temporaries, do not change.
654668
assert UnwrapPyValueHolder(
655-
m.pass_array_pyobject_ptr_return_as_list(
656-
np.array(WrapWithPyValueHolder(-1, "two", 3.0), dtype=object)
657-
)
669+
pass_array(np.array(WrapWithPyValueHolder(-1, "two", 3.0), dtype=object))
658670
) == [-1, "two", 3.0]
659671

660672

661673
@pytest.mark.parametrize(
662-
("return_array_pyobject_ptr", "unwrap"),
674+
("return_array", "unwrap"),
663675
[
664676
(m.return_array_pyobject_ptr_cpp_loop, list),
677+
(m.return_array_handle_cpp_loop, list),
678+
(m.return_array_object_cpp_loop, list),
665679
(m.return_array_pyobject_ptr_from_list, UnwrapPyValueHolder),
680+
(m.return_array_handle_from_list, UnwrapPyValueHolder),
681+
(m.return_array_object_from_list, UnwrapPyValueHolder),
666682
],
667683
)
668-
def test_return_array_pyobject_ptr_cpp_loop(return_array_pyobject_ptr, unwrap):
684+
def test_return_array_object_cpp_loop(return_array, unwrap):
669685
# Intentionally all temporaries, do not change.
670-
arr_from_list = return_array_pyobject_ptr(WrapWithPyValueHolder(6, "seven", -8.0))
686+
arr_from_list = return_array(WrapWithPyValueHolder(6, "seven", -8.0))
671687
assert isinstance(arr_from_list, np.ndarray)
672688
assert arr_from_list.dtype == np.dtype("O")
673689
assert unwrap(arr_from_list) == [6, "seven", -8.0]

0 commit comments

Comments
 (0)