Skip to content

Commit a8fed75

Browse files
rwgkcopybara-github
authored andcommitted
Add absl::Span PyObject* support in absl_casters.h.
PiperOrigin-RevId: 532586991
1 parent 5740192 commit a8fed75

File tree

3 files changed

+26
-1
lines changed

3 files changed

+26
-1
lines changed

pybind11_abseil/absl_casters.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include <pybind11/cast.h>
3434
#include <pybind11/pybind11.h>
3535
#include <pybind11/stl.h>
36+
#include <pybind11/type_caster_pyobject_ptr.h>
3637

3738
// Must NOT appear before at least one pybind11 include.
3839
#include <datetime.h> // Python datetime builtin.
@@ -392,10 +393,23 @@ namespace internal {
392393

393394
template <typename T>
394395
static constexpr bool is_buffer_interface_compatible_type =
396+
detail::is_same_ignoring_cvref<T, PyObject*>::value ||
395397
std::is_arithmetic<T>::value ||
396398
std::is_same<T, std::complex<float>>::value ||
397399
std::is_same<T, std::complex<double>>::value;
398400

401+
template <typename T, typename SFINAE = void>
402+
struct format_descriptor_char1 : format_descriptor<T> {};
403+
404+
template <typename T>
405+
struct format_descriptor_char1<
406+
T,
407+
detail::enable_if_t<detail::is_same_ignoring_cvref<T, PyObject*>::value>> {
408+
static constexpr const char c = 'O';
409+
static constexpr const char value[2] = {c, '\0'};
410+
static std::string format() { return std::string(1, c); }
411+
};
412+
399413
template <typename T, typename SFINAE = void>
400414
struct format_descriptor_char2 {
401415
static constexpr const char c = '\0';
@@ -406,7 +420,7 @@ struct format_descriptor_char2<std::complex<T>> : format_descriptor<T> {};
406420

407421
template <typename T>
408422
inline bool buffer_view_matches_format_descriptor(const char* view_format) {
409-
return view_format[0] == format_descriptor<T>::c ||
423+
return view_format[0] == format_descriptor_char1<T>::c ||
410424
(view_format[0] == 'Z' &&
411425
view_format[1] == format_descriptor_char2<T>::c);
412426
}

pybind11_abseil/tests/absl_example.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,12 @@ NonConstCmplxType SumSpanComplex(absl::Span<CmplxType> input_span) {
273273
return sum;
274274
}
275275

276+
std::string PassSpanPyObjectPtr(absl::Span<PyObject*> input_span) {
277+
std::string result;
278+
for (auto& i : input_span) result += str(i);
279+
return result;
280+
}
281+
276282
struct ObjectForSpan {
277283
explicit ObjectForSpan(int v) : value(v) {}
278284
int value;
@@ -397,6 +403,7 @@ PYBIND11_MODULE(absl_example, m) {
397403
m.def("sum_span_complex128", &SumSpanComplex<std::complex<double>>);
398404
m.def("sum_span_const_complex128",
399405
&SumSpanComplex<const std::complex<double>>, arg("input_span"));
406+
m.def("pass_span_pyobject_ptr", &PassSpanPyObjectPtr, arg("span"));
400407

401408
// Span of objects.
402409
class_<ObjectForSpan>(m, "ObjectForSpan")

pybind11_abseil/tests/absl_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,10 @@ def test_complex(self, numpy_type, sum_span_fn):
378378
xs = np.array([x * 1j for x in range(10)]).astype(numpy_type)
379379
self.assertEqual(sum_span_fn(xs), 45j)
380380

381+
def test_pass_span_pyobject_ptr(self):
382+
arr = np.array([-3, 'four', 5.0], dtype=object)
383+
self.assertEqual(absl_example.pass_span_pyobject_ptr(arr), '-3four5.0')
384+
381385

382386
def make_native_list_of_objects():
383387
return [absl_example.ObjectForSpan(3), absl_example.ObjectForSpan(5)]

0 commit comments

Comments
 (0)