Skip to content

Commit 5740192

Browse files
rwgkcopybara-github
authored andcommitted
Add absl::Span std::complex support in absl_casters.h.
PiperOrigin-RevId: 532576894
1 parent 64a813b commit 5740192

File tree

3 files changed

+60
-6
lines changed

3 files changed

+60
-6
lines changed

pybind11_abseil/absl_casters.h

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include <datetime.h> // Python datetime builtin.
3939

4040
#include <cmath>
41+
#include <complex>
4142
#include <cstdint>
4243
#include <tuple>
4344
#include <type_traits>
@@ -385,18 +386,46 @@ template <>
385386
struct type_caster<absl::CivilYear>
386387
: public absl_civil_date_caster<absl::CivilYear> {};
387388

389+
// Using internal namespace to avoid name collisons in case this code is
390+
// accepted upsteam (pybind11).
391+
namespace internal {
392+
393+
template <typename T>
394+
static constexpr bool is_buffer_interface_compatible_type =
395+
std::is_arithmetic<T>::value ||
396+
std::is_same<T, std::complex<float>>::value ||
397+
std::is_same<T, std::complex<double>>::value;
398+
399+
template <typename T, typename SFINAE = void>
400+
struct format_descriptor_char2 {
401+
static constexpr const char c = '\0';
402+
};
403+
404+
template <typename T>
405+
struct format_descriptor_char2<std::complex<T>> : format_descriptor<T> {};
406+
407+
template <typename T>
408+
inline bool buffer_view_matches_format_descriptor(const char* view_format) {
409+
return view_format[0] == format_descriptor<T>::c ||
410+
(view_format[0] == 'Z' &&
411+
view_format[1] == format_descriptor_char2<T>::c);
412+
}
413+
414+
} // namespace internal
415+
388416
// Returns {true, a span referencing the data contained by src} without copying
389417
// or converting the data if possible. Otherwise returns {false, an empty span}.
390-
template <typename T, typename std::enable_if<std::is_arithmetic<T>::value,
391-
bool>::type = true>
418+
template <typename T, typename std::enable_if<
419+
internal::is_buffer_interface_compatible_type<T>,
420+
bool>::type = true>
392421
std::tuple<bool, absl::Span<T>> LoadSpanFromBuffer(handle src) {
393422
Py_buffer view;
394423
int flags = PyBUF_STRIDES | PyBUF_FORMAT;
395424
if (!std::is_const<T>::value) flags |= PyBUF_WRITABLE;
396425
if (PyObject_GetBuffer(src.ptr(), &view, flags) == 0) {
397426
auto cleanup = absl::MakeCleanup([&view] { PyBuffer_Release(&view); });
398427
if (view.ndim == 1 && view.strides[0] == sizeof(T) &&
399-
view.format[0] == format_descriptor<T>::c) {
428+
internal::buffer_view_matches_format_descriptor<T>(view.format)) {
400429
return {true, absl::MakeSpan(static_cast<T*>(view.buf), view.shape[0])};
401430
}
402431
} else {
@@ -405,9 +434,9 @@ std::tuple<bool, absl::Span<T>> LoadSpanFromBuffer(handle src) {
405434
}
406435
return {false, absl::Span<T>()};
407436
}
408-
// If T is not a numeric type, the buffer interface cannot be used.
409-
template <typename T, typename std::enable_if<!std::is_arithmetic<T>::value,
410-
bool>::type = true>
437+
template <typename T, typename std::enable_if<
438+
!internal::is_buffer_interface_compatible_type<T>,
439+
bool>::type = true>
411440
constexpr std::tuple<bool, absl::Span<T>> LoadSpanFromBuffer(handle src) {
412441
return {false, absl::Span<T>()};
413442
}

pybind11_abseil/tests/absl_example.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
// All rights reserved. Use of this source code is governed by a
44
// BSD-style license that can be found in the LICENSE file.
55

6+
#include <pybind11/complex.h>
67
#include <pybind11/pybind11.h>
78
#include <pybind11/stl.h>
89
#include <pybind11/stl_bind.h>
910

11+
#include <complex>
1012
#include <cstddef>
1113
#include <vector>
1214

@@ -263,6 +265,14 @@ void FillSpan(int value, absl::Span<int> output_span) {
263265
for (auto& i : output_span) i = value;
264266
}
265267

268+
template <typename CmplxType, typename NonConstCmplxType =
269+
typename std::remove_const<CmplxType>::type>
270+
NonConstCmplxType SumSpanComplex(absl::Span<CmplxType> input_span) {
271+
NonConstCmplxType sum = 0;
272+
for (auto& i : input_span) sum += i;
273+
return sum;
274+
}
275+
266276
struct ObjectForSpan {
267277
explicit ObjectForSpan(int v) : value(v) {}
268278
int value;
@@ -382,6 +392,11 @@ PYBIND11_MODULE(absl_example, m) {
382392
// Non-const spans can never be converted, so `output_span` could be marked as
383393
// `noconvert`, but that would be redundant (so test that it is not needed).
384394
m.def("fill_span", &FillSpan, arg("value"), arg("output_span"));
395+
m.def("sum_span_complex64", &SumSpanComplex<std::complex<float>>);
396+
m.def("sum_span_const_complex64", &SumSpanComplex<const std::complex<float>>);
397+
m.def("sum_span_complex128", &SumSpanComplex<std::complex<double>>);
398+
m.def("sum_span_const_complex128",
399+
&SumSpanComplex<const std::complex<double>>, arg("input_span"));
385400

386401
// Span of objects.
387402
class_<ObjectForSpan>(m, "ObjectForSpan")

pybind11_abseil/tests/absl_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,16 @@ def test_fill_span_fails_from(self, values):
368368
with self.assertRaises(TypeError):
369369
absl_example.fill_span(42, values)
370370

371+
@parameterized.parameters(
372+
('complex64', absl_example.sum_span_complex64),
373+
('complex64', absl_example.sum_span_const_complex64),
374+
('complex128', absl_example.sum_span_complex128),
375+
('complex128', absl_example.sum_span_const_complex128),
376+
)
377+
def test_complex(self, numpy_type, sum_span_fn):
378+
xs = np.array([x * 1j for x in range(10)]).astype(numpy_type)
379+
self.assertEqual(sum_span_fn(xs), 45j)
380+
371381

372382
def make_native_list_of_objects():
373383
return [absl_example.ObjectForSpan(3), absl_example.ObjectForSpan(5)]

0 commit comments

Comments
 (0)