38
38
#include < datetime.h> // Python datetime builtin.
39
39
40
40
#include < cmath>
41
+ #include < complex>
41
42
#include < cstdint>
42
43
#include < tuple>
43
44
#include < type_traits>
@@ -385,18 +386,46 @@ template <>
385
386
struct type_caster <absl::CivilYear>
386
387
: public absl_civil_date_caster<absl::CivilYear> {};
387
388
389
+ // Using internal namespace to avoid name collisons in case this code is
390
+ // accepted upsteam (pybind11).
391
+ namespace internal {
392
+
393
+ template <typename T>
394
+ static constexpr bool is_buffer_interface_compatible_type =
395
+ std::is_arithmetic<T>::value ||
396
+ std::is_same<T, std::complex<float >>::value ||
397
+ std::is_same<T, std::complex<double >>::value;
398
+
399
+ template <typename T, typename SFINAE = void >
400
+ struct format_descriptor_char2 {
401
+ static constexpr const char c = ' \0 ' ;
402
+ };
403
+
404
+ template <typename T>
405
+ struct format_descriptor_char2 <std::complex<T>> : format_descriptor<T> {};
406
+
407
+ template <typename T>
408
+ inline bool buffer_view_matches_format_descriptor (const char * view_format) {
409
+ return view_format[0 ] == format_descriptor<T>::c ||
410
+ (view_format[0 ] == ' Z' &&
411
+ view_format[1 ] == format_descriptor_char2<T>::c);
412
+ }
413
+
414
+ } // namespace internal
415
+
388
416
// Returns {true, a span referencing the data contained by src} without copying
389
417
// or converting the data if possible. Otherwise returns {false, an empty span}.
390
- template <typename T, typename std::enable_if<std::is_arithmetic<T>::value,
391
- bool >::type = true >
418
+ template <typename T, typename std::enable_if<
419
+ internal::is_buffer_interface_compatible_type<T>,
420
+ bool >::type = true >
392
421
std::tuple<bool , absl::Span<T>> LoadSpanFromBuffer (handle src) {
393
422
Py_buffer view;
394
423
int flags = PyBUF_STRIDES | PyBUF_FORMAT;
395
424
if (!std::is_const<T>::value) flags |= PyBUF_WRITABLE;
396
425
if (PyObject_GetBuffer (src.ptr (), &view, flags) == 0 ) {
397
426
auto cleanup = absl::MakeCleanup ([&view] { PyBuffer_Release (&view); });
398
427
if (view.ndim == 1 && view.strides [0 ] == sizeof (T) &&
399
- view. format [ 0 ] == format_descriptor <T>::c ) {
428
+ internal::buffer_view_matches_format_descriptor <T>(view. format ) ) {
400
429
return {true , absl::MakeSpan (static_cast <T*>(view.buf ), view.shape [0 ])};
401
430
}
402
431
} else {
@@ -405,9 +434,9 @@ std::tuple<bool, absl::Span<T>> LoadSpanFromBuffer(handle src) {
405
434
}
406
435
return {false , absl::Span<T>()};
407
436
}
408
- // If T is not a numeric type, the buffer interface cannot be used.
409
- template < typename T, typename std::enable_if<!std::is_arithmetic <T>::value ,
410
- bool >::type = true >
437
+ template < typename T, typename std::enable_if<
438
+ !internal::is_buffer_interface_compatible_type <T>,
439
+ bool >::type = true >
411
440
constexpr std::tuple<bool , absl::Span<T>> LoadSpanFromBuffer (handle src) {
412
441
return {false , absl::Span<T>()};
413
442
}
0 commit comments