Skip to content

Commit 906b5f2

Browse files
numpy_dtype_user: Pave way for downstream libs to define user dtypes
1 parent 8ad6a89 commit 906b5f2

File tree

9 files changed

+271
-40
lines changed

9 files changed

+271
-40
lines changed

include/pybind11/cast.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1858,13 +1858,20 @@ template <typename T> struct move_if_unreferenced<T, enable_if_t<all_of<
18581858
>::value>> : std::true_type {};
18591859
template <typename T> using move_never = negation<move_common<T>>;
18601860

1861+
// Trait consulted by `cast_is_temporary_value_reference`: when it is true for a reference or
// pointer `type`, that cast is not treated as returning a reference/pointer to a temporary.
// Primary template: defaults to false. The trailing `SFINAE` parameter leaves room for
// `enable_if_t`-based partial specializations such as the one below.
template <typename type, typename SFINAE = void>
1862+
struct cast_is_known_safe : public std::false_type {};
1863+
1864+
// True whenever the caster selected for `type` derives from `type_caster_generic`.
template <typename type>
1865+
struct cast_is_known_safe<type,
1866+
enable_if_t<std::is_base_of<type_caster_generic, make_caster<type>>::value>> : public std::true_type {};
1867+
18611868
// Detect whether returning a `type` from a cast on type's type_caster is going to result in a
18621869
// reference or pointer to a local variable of the type_caster. Basically, only
18631870
// non-reference/pointer `type`s and reference/pointers marked `cast_is_known_safe` (by default, those from a type_caster_generic) are safe;
18641871
// everything else returns a reference/pointer to a local variable.
18651872
template <typename type> using cast_is_temporary_value_reference = bool_constant<
18661873
(std::is_reference<type>::value || std::is_pointer<type>::value) &&
1867-
!std::is_base_of<type_caster_generic, make_caster<type>>::value
1874+
!cast_is_known_safe<type>::value
18681875
>;
18691876

18701877
// When a value returned from a C++ function is being cast back to Python, we almost always want to

include/pybind11/detail/internals.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,15 @@ struct internals {
8787
/// Additional type information which does not fit into the PyTypeObject.
8888
/// Changes to this struct also require bumping `PYBIND11_INTERNALS_VERSION`.
8989
struct type_info {
90+
using implicit_conversion_func = PyObject *(*)(PyObject *, PyTypeObject *);
91+
9092
PyTypeObject *type;
9193
const std::type_info *cpptype;
9294
size_t type_size, holder_size_in_ptrs;
9395
void *(*operator_new)(size_t);
9496
void (*init_instance)(instance *, holder_erased);
9597
void (*dealloc)(value_and_holder &v_h);
96-
std::vector<PyObject *(*)(PyObject *, PyTypeObject *)> implicit_conversions;
98+
std::vector<implicit_conversion_func> implicit_conversions;
9799
std::vector<std::pair<const std::type_info *, void *(*)(void *)>> implicit_casts;
98100
std::vector<bool (*)(PyObject *, void *&)> *direct_conversions;
99101
buffer_info *(*get_buffer)(PyObject *, void *) = nullptr;

include/pybind11/eigen.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ template <typename props> handle eigen_array_cast(typename props::Type const &sr
220220
constexpr ssize_t elem_size = sizeof(typename props::Scalar);
221221
array a;
222222
using Scalar = typename props::Type::Scalar;
223-
bool is_pyobject = static_cast<pybind11::detail::npy_api::constants>(npy_format_descriptor<Scalar>::value) == npy_api::NPY_OBJECT_;
223+
bool is_pyobject = is_pyobject_dtype<Scalar>::value;
224224

225225
if (!is_pyobject) {
226226
if (props::vector)

include/pybind11/numpy.h

Lines changed: 223 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,105 @@ struct PyVoidScalarObject_Proxy {
7575
PyObject *base;
7676
};
7777

78+
// UFuncs.
79+
using npy_intp = Py_intptr_t;
80+
81+
typedef void (*PyUFuncGenericFunction)(
82+
char **args, npy_intp *dimensions, npy_intp *strides, void *innerloopdata);
83+
84+
typedef void (PyArray_VectorUnaryFunc)(
85+
void* from_, void* to_, npy_intp n, void* fromarr, void* toarr);
86+
87+
typedef struct {
88+
PyObject_HEAD
89+
int nin;
90+
int nout;
91+
int nargs;
92+
int identity;
93+
PyUFuncGenericFunction *functions;
94+
void **data;
95+
int ntypes;
96+
int reserved1;
97+
const char *name;
98+
char *types;
99+
const char *doc;
100+
void *ptr;
101+
PyObject *obj;
102+
PyObject *userloops;
103+
uint32_t *op_flags;
104+
uint32_t *iter_flags;
105+
} PyUFuncObject;
106+
107+
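// Manually defined here rather than taken from NumPy's headers; presumably these must match the values NumPy itself was built with (TODO confirm).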
108+
constexpr int NPY_NTYPES_ABI_COMPATIBLE = 21;
109+
constexpr int NPY_NSORTS = 3;
110+
111+
// TODO(eric.cousineau): Fill this out as needed for type safety.
112+
// TODO(eric.cousineau): Do not define these if NPY headers are present (for debugging).
113+
// Placeholder aliases: every function type below is declared as `void`, so struct members
// written as `PyArray_XxxFunc *` are untyped `void *` slots (see the TODO above about filling
// these out for type safety).
using PyArray_GetItemFunc = void;
114+
using PyArray_SetItemFunc = void;
115+
using PyArray_CopySwapNFunc = void;
116+
using PyArray_CopySwapFunc = void;
117+
using PyArray_CompareFunc = void;
118+
using PyArray_ArgFunc = void;
119+
using PyArray_DotFunc = void;
120+
using PyArray_ScanFunc = void;
121+
using PyArray_FromStrFunc = void;
122+
using PyArray_NonzeroFunc = void;
123+
using PyArray_FillFunc = void;
124+
using PyArray_FillWithScalarFunc = void;
125+
using PyArray_SortFunc = void;
126+
using PyArray_ArgSortFunc = void;
127+
using PyArray_ScalarKindFunc = void;
128+
using PyArray_FastClipFunc = void;
129+
using PyArray_FastPutmaskFunc = void;
130+
using PyArray_FastTakeFunc = void;
131+
// NOTE(review): PyArray_ArgFunc is already declared earlier in this list; this repeat looks redundant.
using PyArray_ArgFunc = void;
132+
133+
typedef struct {
134+
PyArray_VectorUnaryFunc *cast[NPY_NTYPES_ABI_COMPATIBLE];
135+
PyArray_GetItemFunc *getitem;
136+
PyArray_SetItemFunc *setitem;
137+
PyArray_CopySwapNFunc *copyswapn;
138+
PyArray_CopySwapFunc *copyswap;
139+
PyArray_CompareFunc *compare;
140+
PyArray_ArgFunc *argmax;
141+
PyArray_DotFunc *dotfunc;
142+
PyArray_ScanFunc *scanfunc;
143+
PyArray_FromStrFunc *fromstr;
144+
PyArray_NonzeroFunc *nonzero;
145+
PyArray_FillFunc *fill;
146+
PyArray_FillWithScalarFunc *fillwithscalar;
147+
PyArray_SortFunc *sort[NPY_NSORTS];
148+
PyArray_ArgSortFunc *argsort[NPY_NSORTS];
149+
PyObject *castdict;
150+
PyArray_ScalarKindFunc *scalarkind;
151+
int **cancastscalarkindto;
152+
int *cancastto;
153+
PyArray_FastClipFunc *fastclip;
154+
PyArray_FastPutmaskFunc *fastputmask;
155+
PyArray_FastTakeFunc *fasttake;
156+
PyArray_ArgFunc *argmin;
157+
} PyArray_ArrFuncs;
158+
159+
using PyArray_ArrayDescr = void;
160+
161+
// Local declaration of a dtype descriptor object, named after NumPy's PyArray_Descr. It is the
// parameter type of the PyArray_RegisterDataType_ / PyArray_RegisterCastFunc_ /
// PyArray_RegisterCanCast_ function pointers declared in npy_api.
// NOTE(review): presumably this must be layout-compatible with the struct in the installed
// NumPy headers. NumPy's ndarraytypes.h appears to declare `char flags` directly after
// `byteorder` (no `unused` + `int flags` pair) and a `PyObject *names` member between `fields`
// and `f`; if so, the offsets of `flags`, `type_num`, ... and `f` differ here. Verify against
// the targeted NumPy version before reading or writing these fields.
typedef struct {
162+
PyObject_HEAD
163+
PyTypeObject *typeobj;
164+
char kind;
165+
char type;
166+
char byteorder;
167+
char unused;
168+
int flags;
169+
int type_num;
170+
int elsize;
171+
int alignment;
172+
PyArray_ArrayDescr *subarray;
173+
PyObject *fields;
174+
PyArray_ArrFuncs *f;
175+
} PyArray_Descr;
176+
78177
struct numpy_type_info {
79178
PyObject* dtype_ptr;
80179
std::string format_str;
@@ -109,14 +208,16 @@ inline numpy_internals& get_numpy_internals() {
109208
}
110209

111210
struct npy_api {
112-
enum constants {
211+
enum constants : int {
212+
// Array properties
113213
NPY_ARRAY_C_CONTIGUOUS_ = 0x0001,
114214
NPY_ARRAY_F_CONTIGUOUS_ = 0x0002,
115215
NPY_ARRAY_OWNDATA_ = 0x0004,
116216
NPY_ARRAY_FORCECAST_ = 0x0010,
117217
NPY_ARRAY_ENSUREARRAY_ = 0x0040,
118218
NPY_ARRAY_ALIGNED_ = 0x0100,
119219
NPY_ARRAY_WRITEABLE_ = 0x0400,
220+
// Dtypes
120221
NPY_BOOL_ = 0,
121222
NPY_BYTE_, NPY_UBYTE_,
122223
NPY_SHORT_, NPY_USHORT_,
@@ -126,9 +227,27 @@ struct npy_api {
126227
NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
127228
NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
128229
NPY_OBJECT_ = 17,
129-
NPY_STRING_, NPY_UNICODE_, NPY_VOID_
230+
NPY_STRING_, NPY_UNICODE_, NPY_VOID_,
231+
NPY_USERDEF_ = 256,
232+
// Descriptor flags
233+
NPY_NEEDS_INIT_ = 0x08,
234+
NPY_NEEDS_PYAPI_ = 0x10,
235+
NPY_USE_GETITEM_ = 0x20,
236+
NPY_USE_SETITEM_ = 0x40,
237+
// UFunc
238+
PyUFunc_None_ = -1,
130239
};
131240

241+
// Scalar-kind categories; used as the `scalar` argument type of PyArray_RegisterCanCast_.
// Enumerators start at -1 (NPY_NOSCALAR_) and increase by one, so NPY_BOOL_SCALAR_ is 0 and
// NPY_OBJECT_SCALAR_ is 5.
typedef enum {
242+
NPY_NOSCALAR_ = -1,
243+
NPY_BOOL_SCALAR_,
244+
NPY_INTPOS_SCALAR_,
245+
NPY_INTNEG_SCALAR_,
246+
NPY_FLOAT_SCALAR_,
247+
NPY_COMPLEX_SCALAR_,
248+
NPY_OBJECT_SCALAR_
249+
} NPY_SCALARKIND;
250+
132251
typedef struct {
133252
Py_intptr_t *ptr;
134253
int len;
@@ -146,6 +265,7 @@ struct npy_api {
146265
return (bool) PyObject_TypeCheck(obj, PyArrayDescr_Type_);
147266
}
148267

268+
// Multiarray.
149269
unsigned int (*PyArray_GetNDArrayCFeatureVersion_)();
150270
PyObject *(*PyArray_DescrFromType_)(int);
151271
PyObject *(*PyArray_NewFromDescr_)
@@ -166,8 +286,29 @@ struct npy_api {
166286
PyObject *(*PyArray_Squeeze_)(PyObject *);
167287
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
168288
PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int);
289+
290+
// - Dtypes
291+
PyTypeObject* PyGenericArrType_Type_;
292+
int (*PyArray_RegisterDataType_)(PyArray_Descr* dtype);
293+
int (*PyArray_RegisterCastFunc_)(PyArray_Descr* descr, int totype, PyArray_VectorUnaryFunc* castfunc);
294+
int (*PyArray_RegisterCanCast_)(PyArray_Descr* descr, int totype, NPY_SCALARKIND scalar);
295+
void (*PyArray_InitArrFuncs_)(PyArray_ArrFuncs *f);
296+
297+
// UFuncs.
298+
PyObject* (*PyUFunc_FromFuncAndData_)(
299+
PyUFuncGenericFunction* func, void** data, char* types, int ntypes,
300+
int nin, int nout, int identity, char* name, char* doc, int unused);
301+
302+
int (*PyUFunc_RegisterLoopForType_)(
303+
PyUFuncObject* ufunc, int usertype, PyUFuncGenericFunction function, int* arg_types, void* data);
304+
305+
int (*PyUFunc_ReplaceLoopBySignature_)(
306+
PyUFuncObject *func, PyUFuncGenericFunction newfunc,
307+
int *signature, PyUFuncGenericFunction *oldfunc);
169308
private:
309+
// TODO(eric.cousineau): Rename to `items` or something, since this now applies to types.
170310
enum functions {
311+
// multiarray
171312
API_PyArray_GetNDArrayCFeatureVersion = 211,
172313
API_PyArray_Type = 2,
173314
API_PyArrayDescr_Type = 3,
@@ -184,38 +325,68 @@ struct npy_api {
184325
API_PyArray_EquivTypes = 182,
185326
API_PyArray_GetArrayParamsFromObject = 278,
186327
API_PyArray_Squeeze = 136,
187-
API_PyArray_SetBaseObject = 282
328+
API_PyArray_SetBaseObject = 282,
329+
// - DTypes
330+
API_PyGenericArrType_Type = 10,
331+
API_PyArray_RegisterDataType = 192,
332+
API_PyArray_RegisterCastFunc = 193,
333+
API_PyArray_RegisterCanCast = 194,
334+
API_PyArray_InitArrFuncs = 195,
335+
// umath
336+
API_PyUFunc_FromFuncAndData = 1,
337+
API_PyUFunc_RegisterLoopForType = 2,
338+
API_PyUFunc_ReplaceLoopBySignature = 30,
188339
};
189340

190-
static npy_api lookup() {
191-
module m = module::import("numpy.core.multiarray");
192-
auto c = m.attr("_ARRAY_API");
341+
static void** get_api_ptr(object c) {
193342
#if PY_MAJOR_VERSION >= 3
194-
void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), NULL);
343+
return (void **) PyCapsule_GetPointer(c.ptr(), NULL);
195344
#else
196-
void **api_ptr = (void **) PyCObject_AsVoidPtr(c.ptr());
345+
return (void **) PyCObject_AsVoidPtr(c.ptr());
197346
#endif
347+
}
348+
349+
static npy_api lookup() {
198350
npy_api api;
199351
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
200-
DECL_NPY_API(PyArray_GetNDArrayCFeatureVersion);
201-
if (api.PyArray_GetNDArrayCFeatureVersion_() < 0x7)
202-
pybind11_fail("pybind11 numpy support requires numpy >= 1.7.0");
203-
DECL_NPY_API(PyArray_Type);
204-
DECL_NPY_API(PyVoidArrType_Type);
205-
DECL_NPY_API(PyArrayDescr_Type);
206-
DECL_NPY_API(PyArray_DescrFromType);
207-
DECL_NPY_API(PyArray_DescrFromScalar);
208-
DECL_NPY_API(PyArray_FromAny);
209-
DECL_NPY_API(PyArray_Resize);
210-
DECL_NPY_API(PyArray_CopyInto);
211-
DECL_NPY_API(PyArray_NewCopy);
212-
DECL_NPY_API(PyArray_NewFromDescr);
213-
DECL_NPY_API(PyArray_DescrNewFromType);
214-
DECL_NPY_API(PyArray_DescrConverter);
215-
DECL_NPY_API(PyArray_EquivTypes);
216-
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
217-
DECL_NPY_API(PyArray_Squeeze);
218-
DECL_NPY_API(PyArray_SetBaseObject);
352+
// multiarray -> _ARRAY_API
353+
{
354+
module multiarray = module::import("numpy.core.multiarray");
355+
auto api_ptr = get_api_ptr(multiarray.attr("_ARRAY_API"));
356+
DECL_NPY_API(PyArray_GetNDArrayCFeatureVersion);
357+
if (api.PyArray_GetNDArrayCFeatureVersion_() < 0x7)
358+
pybind11_fail("pybind11 numpy support requires numpy >= 1.7.0");
359+
DECL_NPY_API(PyArray_Type);
360+
DECL_NPY_API(PyVoidArrType_Type);
361+
DECL_NPY_API(PyArrayDescr_Type);
362+
DECL_NPY_API(PyArray_DescrFromType);
363+
DECL_NPY_API(PyArray_DescrFromScalar);
364+
DECL_NPY_API(PyArray_FromAny);
365+
DECL_NPY_API(PyArray_Resize);
366+
DECL_NPY_API(PyArray_CopyInto);
367+
DECL_NPY_API(PyArray_NewCopy);
368+
DECL_NPY_API(PyArray_NewFromDescr);
369+
DECL_NPY_API(PyArray_DescrNewFromType);
370+
DECL_NPY_API(PyArray_DescrConverter);
371+
DECL_NPY_API(PyArray_EquivTypes);
372+
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
373+
DECL_NPY_API(PyArray_Squeeze);
374+
DECL_NPY_API(PyArray_SetBaseObject);
375+
// - Dtypes
376+
DECL_NPY_API(PyGenericArrType_Type);
377+
DECL_NPY_API(PyArray_RegisterDataType);
378+
DECL_NPY_API(PyArray_InitArrFuncs);
379+
DECL_NPY_API(PyArray_RegisterCastFunc);
380+
DECL_NPY_API(PyArray_RegisterCanCast);
381+
}
382+
// umath -> _UFUNC_API
383+
{
384+
module umath = module::import("numpy.core.umath");
385+
auto api_ptr = get_api_ptr(umath.attr("_UFUNC_API"));
386+
DECL_NPY_API(PyUFunc_FromFuncAndData);
387+
DECL_NPY_API(PyUFunc_RegisterLoopForType);
388+
DECL_NPY_API(PyUFunc_ReplaceLoopBySignature);
389+
}
219390
#undef DECL_NPY_API
220391
return api;
221392
}
@@ -465,6 +636,11 @@ class dtype : public object {
465636
return detail::array_descriptor_proxy(m_ptr)->kind;
466637
}
467638

639+
/// Type index for builtin or user-registered dtypes.
640+
int num() const {
641+
return detail::array_descriptor_proxy(m_ptr)->type_num;
642+
}
643+
468644
private:
469645
static object _dtype_from_pep3118() {
470646
static PyObject *obj = module::import("numpy.core._internal")
@@ -1049,6 +1225,26 @@ template<typename T> struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>
10491225
static pybind11::dtype dtype() { return base_descr::dtype(); }
10501226
};
10511227

1228+
// Format descriptor for `object`: dtype() asks NumPy for the descriptor registered under
// NPY_OBJECT_ and fails via pybind11_fail if the lookup returns null.
// NOTE(review): unlike the `void` specialization, no `name` member is defined here — confirm
// whether any caller needs it.
template <>
1229+
struct npy_format_descriptor<object> {
1230+
static pybind11::dtype dtype() {
1231+
if (auto ptr = npy_api::get().PyArray_DescrFromType_(npy_api::NPY_OBJECT_))
1232+
// NOTE(review): NumPy documents PyArray_DescrFromType as returning a new reference;
// reinterpret_borrow takes an additional one, so the returned reference is never
// released — confirm whether reinterpret_steal is intended.
return reinterpret_borrow<pybind11::dtype>(ptr);
1233+
pybind11_fail("Unsupported buffer format!");
1234+
}
1235+
};
1236+
1237+
// Format descriptor for `void`: exposes a `name` built from detail::_<void>() and a dtype()
// that asks NumPy for the descriptor registered under NPY_VOID_, failing via pybind11_fail if
// the lookup returns null.
template <>
1238+
struct npy_format_descriptor<void> {
1239+
static constexpr auto name = detail::_<void>();
1240+
static pybind11::dtype dtype() {
1241+
if (auto ptr = detail::npy_api::get().PyArray_DescrFromType_(
1242+
detail::npy_api::constants::NPY_VOID_))
1243+
// NOTE(review): NumPy documents PyArray_DescrFromType as returning a new reference;
// reinterpret_borrow takes an additional one, so the returned reference is never
// released — confirm whether reinterpret_steal is intended.
return reinterpret_borrow<pybind11::dtype>(ptr);
1244+
pybind11_fail("Unsupported buffer format!");
1245+
}
1246+
};
1247+
10521248
struct field_descriptor {
10531249
const char *name;
10541250
ssize_t offset;

0 commit comments

Comments
 (0)