Skip to content

Commit f662de8

Browse files
numpy: Add ufunc and custom user dtype API. Ensure py::object and void can be found via dtype
1 parent 2fd0b89 commit f662de8

File tree

1 file changed

+218
-27
lines changed

1 file changed

+218
-27
lines changed

include/pybind11/numpy.h

Lines changed: 218 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,105 @@ struct PyVoidScalarObject_Proxy {
7575
PyObject *base;
7676
};
7777

78+
// UFuncs.
79+
using npy_intp = Py_intptr_t;
80+
81+
typedef void (*PyUFuncGenericFunction)(
82+
char **args, npy_intp *dimensions, npy_intp *strides, void *innerloopdata);
83+
84+
typedef void (PyArray_VectorUnaryFunc)(
85+
void* from_, void* to_, npy_intp n, void* fromarr, void* toarr);
86+
87+
typedef struct {
88+
PyObject_HEAD
89+
int nin;
90+
int nout;
91+
int nargs;
92+
int identity;
93+
PyUFuncGenericFunction *functions;
94+
void **data;
95+
int ntypes;
96+
int reserved1;
97+
const char *name;
98+
char *types;
99+
const char *doc;
100+
void *ptr;
101+
PyObject *obj;
102+
PyObject *userloops;
103+
uint32_t *op_flags;
104+
uint32_t *iter_flags;
105+
} PyUFuncObject;
106+
107+
// Manually defined :(
108+
constexpr int NPY_NTYPES_ABI_COMPATIBLE = 21;
109+
constexpr int NPY_NSORTS = 3;
110+
111+
// TODO(eric.cousineau): Fill this out as needed for type safety.
112+
// TODO(eric.cousineau): Do not define these if NPY headers are present (for debugging).
113+
using PyArray_GetItemFunc = void;
114+
using PyArray_SetItemFunc = void;
115+
using PyArray_CopySwapNFunc = void;
116+
using PyArray_CopySwapFunc = void;
117+
using PyArray_CompareFunc = void;
118+
using PyArray_ArgFunc = void;
119+
using PyArray_DotFunc = void;
120+
using PyArray_ScanFunc = void;
121+
using PyArray_FromStrFunc = void;
122+
using PyArray_NonzeroFunc = void;
123+
using PyArray_FillFunc = void;
124+
using PyArray_FillWithScalarFunc = void;
125+
using PyArray_SortFunc = void;
126+
using PyArray_ArgSortFunc = void;
127+
using PyArray_ScalarKindFunc = void;
128+
using PyArray_FastClipFunc = void;
129+
using PyArray_FastPutmaskFunc = void;
130+
using PyArray_FastTakeFunc = void;
131+
using PyArray_ArgFunc = void;
132+
133+
typedef struct {
134+
PyArray_VectorUnaryFunc *cast[NPY_NTYPES_ABI_COMPATIBLE];
135+
PyArray_GetItemFunc *getitem;
136+
PyArray_SetItemFunc *setitem;
137+
PyArray_CopySwapNFunc *copyswapn;
138+
PyArray_CopySwapFunc *copyswap;
139+
PyArray_CompareFunc *compare;
140+
PyArray_ArgFunc *argmax;
141+
PyArray_DotFunc *dotfunc;
142+
PyArray_ScanFunc *scanfunc;
143+
PyArray_FromStrFunc *fromstr;
144+
PyArray_NonzeroFunc *nonzero;
145+
PyArray_FillFunc *fill;
146+
PyArray_FillWithScalarFunc *fillwithscalar;
147+
PyArray_SortFunc *sort[NPY_NSORTS];
148+
PyArray_ArgSortFunc *argsort[NPY_NSORTS];
149+
PyObject *castdict;
150+
PyArray_ScalarKindFunc *scalarkind;
151+
int **cancastscalarkindto;
152+
int *cancastto;
153+
PyArray_FastClipFunc *fastclip;
154+
PyArray_FastPutmaskFunc *fastputmask;
155+
PyArray_FastTakeFunc *fasttake;
156+
PyArray_ArgFunc *argmin;
157+
} PyArray_ArrFuncs;
158+
159+
using PyArray_ArrayDescr = void;
160+
161+
typedef struct {
162+
PyObject_HEAD
163+
PyTypeObject *typeobj;
164+
char kind;
165+
char type;
166+
char byteorder;
167+
char unused;
168+
int flags;
169+
int type_num;
170+
int elsize;
171+
int alignment;
172+
PyArray_ArrayDescr *subarray;
173+
PyObject *fields;
174+
PyArray_ArrFuncs *f;
175+
} PyArray_Descr;
176+
78177
struct numpy_type_info {
79178
PyObject* dtype_ptr;
80179
std::string format_str;
@@ -109,14 +208,16 @@ inline numpy_internals& get_numpy_internals() {
109208
}
110209

111210
struct npy_api {
112-
enum constants {
211+
enum constants : int {
212+
// Array properties
113213
NPY_ARRAY_C_CONTIGUOUS_ = 0x0001,
114214
NPY_ARRAY_F_CONTIGUOUS_ = 0x0002,
115215
NPY_ARRAY_OWNDATA_ = 0x0004,
116216
NPY_ARRAY_FORCECAST_ = 0x0010,
117217
NPY_ARRAY_ENSUREARRAY_ = 0x0040,
118218
NPY_ARRAY_ALIGNED_ = 0x0100,
119219
NPY_ARRAY_WRITEABLE_ = 0x0400,
220+
// Dtypes
120221
NPY_BOOL_ = 0,
121222
NPY_BYTE_, NPY_UBYTE_,
122223
NPY_SHORT_, NPY_USHORT_,
@@ -126,9 +227,27 @@ struct npy_api {
126227
NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
127228
NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
128229
NPY_OBJECT_ = 17,
129-
NPY_STRING_, NPY_UNICODE_, NPY_VOID_
230+
NPY_STRING_, NPY_UNICODE_, NPY_VOID_,
231+
NPY_USERDEF_ = 256,
232+
// Descriptor flags
233+
NPY_NEEDS_INIT_ = 0x08,
234+
NPY_NEEDS_PYAPI_ = 0x10,
235+
NPY_USE_GETITEM_ = 0x20,
236+
NPY_USE_SETITEM_ = 0x40,
237+
// UFunc
238+
PyUFunc_None_ = -1,
130239
};
131240

241+
typedef enum {
242+
NPY_NOSCALAR_ = -1,
243+
NPY_BOOL_SCALAR_,
244+
NPY_INTPOS_SCALAR_,
245+
NPY_INTNEG_SCALAR_,
246+
NPY_FLOAT_SCALAR_,
247+
NPY_COMPLEX_SCALAR_,
248+
NPY_OBJECT_SCALAR_
249+
} NPY_SCALARKIND;
250+
132251
typedef struct {
133252
Py_intptr_t *ptr;
134253
int len;
@@ -146,6 +265,7 @@ struct npy_api {
146265
return (bool) PyObject_TypeCheck(obj, PyArrayDescr_Type_);
147266
}
148267

268+
// Multiarray.
149269
unsigned int (*PyArray_GetNDArrayCFeatureVersion_)();
150270
PyObject *(*PyArray_DescrFromType_)(int);
151271
PyObject *(*PyArray_NewFromDescr_)
@@ -166,8 +286,29 @@ struct npy_api {
166286
PyObject *(*PyArray_Squeeze_)(PyObject *);
167287
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
168288
PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int);
289+
290+
// - Dtypes
291+
PyTypeObject* PyGenericArrType_Type_;
292+
int (*PyArray_RegisterDataType_)(PyArray_Descr* dtype);
293+
int (*PyArray_RegisterCastFunc_)(PyArray_Descr* descr, int totype, PyArray_VectorUnaryFunc* castfunc);
294+
int (*PyArray_RegisterCanCast_)(PyArray_Descr* descr, int totype, NPY_SCALARKIND scalar);
295+
void (*PyArray_InitArrFuncs_)(PyArray_ArrFuncs *f);
296+
297+
// UFuncs.
298+
PyObject* (*PyUFunc_FromFuncAndData_)(
299+
PyUFuncGenericFunction* func, void** data, char* types, int ntypes,
300+
int nin, int nout, int identity, char* name, char* doc, int unused);
301+
302+
int (*PyUFunc_RegisterLoopForType_)(
303+
PyUFuncObject* ufunc, int usertype, PyUFuncGenericFunction function, int* arg_types, void* data);
304+
305+
int (*PyUFunc_ReplaceLoopBySignature_)(
306+
PyUFuncObject *func, PyUFuncGenericFunction newfunc,
307+
int *signature, PyUFuncGenericFunction *oldfunc);
169308
private:
309+
// TODO(eric.cousineau): Rename to `items` or something, since this now applies to types.
170310
enum functions {
311+
// multiarray
171312
API_PyArray_GetNDArrayCFeatureVersion = 211,
172313
API_PyArray_Type = 2,
173314
API_PyArrayDescr_Type = 3,
@@ -184,38 +325,68 @@ struct npy_api {
184325
API_PyArray_EquivTypes = 182,
185326
API_PyArray_GetArrayParamsFromObject = 278,
186327
API_PyArray_Squeeze = 136,
187-
API_PyArray_SetBaseObject = 282
328+
API_PyArray_SetBaseObject = 282,
329+
// - DTypes
330+
API_PyGenericArrType_Type = 10,
331+
API_PyArray_RegisterDataType = 192,
332+
API_PyArray_RegisterCastFunc = 193,
333+
API_PyArray_RegisterCanCast = 194,
334+
API_PyArray_InitArrFuncs = 195,
335+
// umath
336+
API_PyUFunc_FromFuncAndData = 1,
337+
API_PyUFunc_RegisterLoopForType = 2,
338+
API_PyUFunc_ReplaceLoopBySignature = 30,
188339
};
189340

190-
static npy_api lookup() {
191-
module m = module::import("numpy.core.multiarray");
192-
auto c = m.attr("_ARRAY_API");
341+
static void** get_api_ptr(object c) {
193342
#if PY_MAJOR_VERSION >= 3
194-
void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), NULL);
343+
return (void **) PyCapsule_GetPointer(c.ptr(), NULL);
195344
#else
196-
void **api_ptr = (void **) PyCObject_AsVoidPtr(c.ptr());
345+
return (void **) PyCObject_AsVoidPtr(c.ptr());
197346
#endif
347+
}
348+
349+
static npy_api lookup() {
198350
npy_api api;
199351
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
200-
DECL_NPY_API(PyArray_GetNDArrayCFeatureVersion);
201-
if (api.PyArray_GetNDArrayCFeatureVersion_() < 0x7)
202-
pybind11_fail("pybind11 numpy support requires numpy >= 1.7.0");
203-
DECL_NPY_API(PyArray_Type);
204-
DECL_NPY_API(PyVoidArrType_Type);
205-
DECL_NPY_API(PyArrayDescr_Type);
206-
DECL_NPY_API(PyArray_DescrFromType);
207-
DECL_NPY_API(PyArray_DescrFromScalar);
208-
DECL_NPY_API(PyArray_FromAny);
209-
DECL_NPY_API(PyArray_Resize);
210-
DECL_NPY_API(PyArray_CopyInto);
211-
DECL_NPY_API(PyArray_NewCopy);
212-
DECL_NPY_API(PyArray_NewFromDescr);
213-
DECL_NPY_API(PyArray_DescrNewFromType);
214-
DECL_NPY_API(PyArray_DescrConverter);
215-
DECL_NPY_API(PyArray_EquivTypes);
216-
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
217-
DECL_NPY_API(PyArray_Squeeze);
218-
DECL_NPY_API(PyArray_SetBaseObject);
352+
// multiarray -> _ARRAY_API
353+
{
354+
module multiarray = module::import("numpy.core.multiarray");
355+
auto api_ptr = get_api_ptr(multiarray.attr("_ARRAY_API"));
356+
DECL_NPY_API(PyArray_GetNDArrayCFeatureVersion);
357+
if (api.PyArray_GetNDArrayCFeatureVersion_() < 0x7)
358+
pybind11_fail("pybind11 numpy support requires numpy >= 1.7.0");
359+
DECL_NPY_API(PyArray_Type);
360+
DECL_NPY_API(PyVoidArrType_Type);
361+
DECL_NPY_API(PyArrayDescr_Type);
362+
DECL_NPY_API(PyArray_DescrFromType);
363+
DECL_NPY_API(PyArray_DescrFromScalar);
364+
DECL_NPY_API(PyArray_FromAny);
365+
DECL_NPY_API(PyArray_Resize);
366+
DECL_NPY_API(PyArray_CopyInto);
367+
DECL_NPY_API(PyArray_NewCopy);
368+
DECL_NPY_API(PyArray_NewFromDescr);
369+
DECL_NPY_API(PyArray_DescrNewFromType);
370+
DECL_NPY_API(PyArray_DescrConverter);
371+
DECL_NPY_API(PyArray_EquivTypes);
372+
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
373+
DECL_NPY_API(PyArray_Squeeze);
374+
DECL_NPY_API(PyArray_SetBaseObject);
375+
// - Dtypes
376+
DECL_NPY_API(PyGenericArrType_Type);
377+
DECL_NPY_API(PyArray_RegisterDataType);
378+
DECL_NPY_API(PyArray_InitArrFuncs);
379+
DECL_NPY_API(PyArray_RegisterCastFunc);
380+
DECL_NPY_API(PyArray_RegisterCanCast);
381+
}
382+
// umath -> _UFUNC_API
383+
{
384+
module umath = module::import("numpy.core.umath");
385+
auto api_ptr = get_api_ptr(umath.attr("_UFUNC_API"));
386+
DECL_NPY_API(PyUFunc_FromFuncAndData);
387+
DECL_NPY_API(PyUFunc_RegisterLoopForType);
388+
DECL_NPY_API(PyUFunc_ReplaceLoopBySignature);
389+
}
219390
#undef DECL_NPY_API
220391
return api;
221392
}
@@ -1053,6 +1224,26 @@ template<typename T> struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>
10531224
static pybind11::dtype dtype() { return base_descr::dtype(); }
10541225
};
10551226

1227+
template <>
1228+
struct npy_format_descriptor<object> {
1229+
static pybind11::dtype dtype() {
1230+
if (auto ptr = npy_api::get().PyArray_DescrFromType_(npy_api::NPY_OBJECT_))
1231+
return reinterpret_borrow<pybind11::dtype>(ptr);
1232+
pybind11_fail("Unsupported buffer format!");
1233+
}
1234+
};
1235+
1236+
template <>
1237+
struct npy_format_descriptor<void> {
1238+
static constexpr auto name = detail::_<void>();
1239+
static pybind11::dtype dtype() {
1240+
if (auto ptr = detail::npy_api::get().PyArray_DescrFromType_(
1241+
detail::npy_api::constants::NPY_VOID_))
1242+
return reinterpret_borrow<pybind11::dtype>(ptr);
1243+
pybind11_fail("Unsupported buffer format!");
1244+
}
1245+
};
1246+
10561247
struct field_descriptor {
10571248
const char *name;
10581249
ssize_t offset;

0 commit comments

Comments
 (0)