@@ -75,6 +75,105 @@ struct PyVoidScalarObject_Proxy {
75
75
PyObject *base;
76
76
};
77
77
78
+ // UFuncs.
79
+ using npy_intp = Py_intptr_t;
80
+
81
+ typedef void (*PyUFuncGenericFunction)(
82
+ char **args, npy_intp *dimensions, npy_intp *strides, void *innerloopdata);
83
+
84
+ typedef void (PyArray_VectorUnaryFunc)(
85
+ void * from_, void * to_, npy_intp n, void * fromarr, void * toarr);
86
+
87
+ typedef struct {
88
+ PyObject_HEAD
89
+ int nin;
90
+ int nout;
91
+ int nargs;
92
+ int identity;
93
+ PyUFuncGenericFunction *functions;
94
+ void **data;
95
+ int ntypes;
96
+ int reserved1;
97
+ const char *name;
98
+ char *types;
99
+ const char *doc;
100
+ void *ptr;
101
+ PyObject *obj;
102
+ PyObject *userloops;
103
+ uint32_t *op_flags;
104
+ uint32_t *iter_flags;
105
+ } PyUFuncObject;
106
+
107
+ // Manually defined :(
108
+ constexpr int NPY_NTYPES_ABI_COMPATIBLE = 21 ;
109
+ constexpr int NPY_NSORTS = 3 ;
110
+
111
+ // TODO(eric.cousineau): Fill this out as needed for type safety.
112
+ // TODO(eric.cousineau): Do not define these if NPY headers are present (for debugging).
113
+ using PyArray_GetItemFunc = void ;
114
+ using PyArray_SetItemFunc = void ;
115
+ using PyArray_CopySwapNFunc = void ;
116
+ using PyArray_CopySwapFunc = void ;
117
+ using PyArray_CompareFunc = void ;
118
+ using PyArray_ArgFunc = void ;
119
+ using PyArray_DotFunc = void ;
120
+ using PyArray_ScanFunc = void ;
121
+ using PyArray_FromStrFunc = void ;
122
+ using PyArray_NonzeroFunc = void ;
123
+ using PyArray_FillFunc = void ;
124
+ using PyArray_FillWithScalarFunc = void ;
125
+ using PyArray_SortFunc = void ;
126
+ using PyArray_ArgSortFunc = void ;
127
+ using PyArray_ScalarKindFunc = void ;
128
+ using PyArray_FastClipFunc = void ;
129
+ using PyArray_FastPutmaskFunc = void ;
130
+ using PyArray_FastTakeFunc = void ;
131
+ using PyArray_ArgFunc = void ;
132
+
133
+ typedef struct {
134
+ PyArray_VectorUnaryFunc *cast[NPY_NTYPES_ABI_COMPATIBLE];
135
+ PyArray_GetItemFunc *getitem;
136
+ PyArray_SetItemFunc *setitem;
137
+ PyArray_CopySwapNFunc *copyswapn;
138
+ PyArray_CopySwapFunc *copyswap;
139
+ PyArray_CompareFunc *compare;
140
+ PyArray_ArgFunc *argmax;
141
+ PyArray_DotFunc *dotfunc;
142
+ PyArray_ScanFunc *scanfunc;
143
+ PyArray_FromStrFunc *fromstr;
144
+ PyArray_NonzeroFunc *nonzero;
145
+ PyArray_FillFunc *fill;
146
+ PyArray_FillWithScalarFunc *fillwithscalar;
147
+ PyArray_SortFunc *sort[NPY_NSORTS];
148
+ PyArray_ArgSortFunc *argsort[NPY_NSORTS];
149
+ PyObject *castdict;
150
+ PyArray_ScalarKindFunc *scalarkind;
151
+ int **cancastscalarkindto;
152
+ int *cancastto;
153
+ PyArray_FastClipFunc *fastclip;
154
+ PyArray_FastPutmaskFunc *fastputmask;
155
+ PyArray_FastTakeFunc *fasttake;
156
+ PyArray_ArgFunc *argmin;
157
+ } PyArray_ArrFuncs;
158
+
159
+ using PyArray_ArrayDescr = void ;
160
+
161
+ typedef struct {
162
+ PyObject_HEAD
163
+ PyTypeObject *typeobj;
164
+ char kind;
165
+ char type;
166
+ char byteorder;
167
+ char unused;
168
+ int flags;
169
+ int type_num;
170
+ int elsize;
171
+ int alignment;
172
+ PyArray_ArrayDescr *subarray;
173
+ PyObject *fields;
174
+ PyArray_ArrFuncs *f;
175
+ } PyArray_Descr;
176
+
78
177
struct numpy_type_info {
79
178
PyObject* dtype_ptr;
80
179
std::string format_str;
@@ -109,14 +208,16 @@ inline numpy_internals& get_numpy_internals() {
109
208
}
110
209
111
210
struct npy_api {
112
- enum constants {
211
+ enum constants : int {
212
+ // Array properties
113
213
NPY_ARRAY_C_CONTIGUOUS_ = 0x0001 ,
114
214
NPY_ARRAY_F_CONTIGUOUS_ = 0x0002 ,
115
215
NPY_ARRAY_OWNDATA_ = 0x0004 ,
116
216
NPY_ARRAY_FORCECAST_ = 0x0010 ,
117
217
NPY_ARRAY_ENSUREARRAY_ = 0x0040 ,
118
218
NPY_ARRAY_ALIGNED_ = 0x0100 ,
119
219
NPY_ARRAY_WRITEABLE_ = 0x0400 ,
220
+ // Dtypes
120
221
NPY_BOOL_ = 0 ,
121
222
NPY_BYTE_, NPY_UBYTE_,
122
223
NPY_SHORT_, NPY_USHORT_,
@@ -126,9 +227,27 @@ struct npy_api {
126
227
NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
127
228
NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
128
229
NPY_OBJECT_ = 17 ,
129
- NPY_STRING_, NPY_UNICODE_, NPY_VOID_
230
+ NPY_STRING_, NPY_UNICODE_, NPY_VOID_,
231
+ NPY_USERDEF_ = 256 ,
232
+ // Descriptor flags
233
+ NPY_NEEDS_INIT_ = 0x08 ,
234
+ NPY_NEEDS_PYAPI_ = 0x10 ,
235
+ NPY_USE_GETITEM_ = 0x20 ,
236
+ NPY_USE_SETITEM_ = 0x40 ,
237
+ // UFunc
238
+ PyUFunc_None_ = -1 ,
130
239
};
131
240
241
+ typedef enum {
242
+ NPY_NOSCALAR_ = -1 ,
243
+ NPY_BOOL_SCALAR_,
244
+ NPY_INTPOS_SCALAR_,
245
+ NPY_INTNEG_SCALAR_,
246
+ NPY_FLOAT_SCALAR_,
247
+ NPY_COMPLEX_SCALAR_,
248
+ NPY_OBJECT_SCALAR_
249
+ } NPY_SCALARKIND;
250
+
132
251
typedef struct {
133
252
Py_intptr_t *ptr;
134
253
int len;
@@ -146,6 +265,7 @@ struct npy_api {
146
265
return (bool ) PyObject_TypeCheck (obj, PyArrayDescr_Type_);
147
266
}
148
267
268
+ // Multiarray.
149
269
unsigned int (*PyArray_GetNDArrayCFeatureVersion_)();
150
270
PyObject *(*PyArray_DescrFromType_)(int );
151
271
PyObject *(*PyArray_NewFromDescr_)
@@ -166,8 +286,29 @@ struct npy_api {
166
286
PyObject *(*PyArray_Squeeze_)(PyObject *);
167
287
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
168
288
PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int , int );
289
+
290
+ // - Dtypes
291
+ PyTypeObject* PyGenericArrType_Type_;
292
+ int (*PyArray_RegisterDataType_)(PyArray_Descr* dtype);
293
+ int (*PyArray_RegisterCastFunc_)(PyArray_Descr* descr, int totype, PyArray_VectorUnaryFunc* castfunc);
294
+ int (*PyArray_RegisterCanCast_)(PyArray_Descr* descr, int totype, NPY_SCALARKIND scalar);
295
+ void (*PyArray_InitArrFuncs_)(PyArray_ArrFuncs *f);
296
+
297
+ // UFuncs.
298
+ PyObject* (*PyUFunc_FromFuncAndData_)(
299
+ PyUFuncGenericFunction* func, void ** data, char * types, int ntypes,
300
+ int nin, int nout, int identity, char * name, char * doc, int unused);
301
+
302
+ int (*PyUFunc_RegisterLoopForType_)(
303
+ PyUFuncObject* ufunc, int usertype, PyUFuncGenericFunction function, int * arg_types, void * data);
304
+
305
+ int (*PyUFunc_ReplaceLoopBySignature_)(
306
+ PyUFuncObject *func, PyUFuncGenericFunction newfunc,
307
+ int *signature, PyUFuncGenericFunction *oldfunc);
169
308
private:
309
+ // TODO(eric.cousineau): Rename to `items` or something, since this now applies to types.
170
310
enum functions {
311
+ // multiarray
171
312
API_PyArray_GetNDArrayCFeatureVersion = 211 ,
172
313
API_PyArray_Type = 2 ,
173
314
API_PyArrayDescr_Type = 3 ,
@@ -184,38 +325,68 @@ struct npy_api {
184
325
API_PyArray_EquivTypes = 182 ,
185
326
API_PyArray_GetArrayParamsFromObject = 278 ,
186
327
API_PyArray_Squeeze = 136 ,
187
- API_PyArray_SetBaseObject = 282
328
+ API_PyArray_SetBaseObject = 282 ,
329
+ // - DTypes
330
+ API_PyGenericArrType_Type = 10 ,
331
+ API_PyArray_RegisterDataType = 192 ,
332
+ API_PyArray_RegisterCastFunc = 193 ,
333
+ API_PyArray_RegisterCanCast = 194 ,
334
+ API_PyArray_InitArrFuncs = 195 ,
335
+ // umath
336
+ API_PyUFunc_FromFuncAndData = 1 ,
337
+ API_PyUFunc_RegisterLoopForType = 2 ,
338
+ API_PyUFunc_ReplaceLoopBySignature = 30 ,
188
339
};
189
340
190
- static npy_api lookup () {
191
- module m = module::import (" numpy.core.multiarray" );
192
- auto c = m.attr (" _ARRAY_API" );
341
+ static void ** get_api_ptr (object c) {
193
342
#if PY_MAJOR_VERSION >= 3
194
- void **api_ptr = (void **) PyCapsule_GetPointer (c.ptr (), NULL );
343
+ return (void **) PyCapsule_GetPointer (c.ptr (), NULL );
195
344
#else
196
- void **api_ptr = (void **) PyCObject_AsVoidPtr (c.ptr ());
345
+ return (void **) PyCObject_AsVoidPtr (c.ptr ());
197
346
#endif
347
+ }
348
+
349
+ static npy_api lookup () {
198
350
npy_api api;
199
351
#define DECL_NPY_API (Func ) api.Func##_ = (decltype (api.Func##_)) api_ptr[API_##Func];
200
- DECL_NPY_API (PyArray_GetNDArrayCFeatureVersion);
201
- if (api.PyArray_GetNDArrayCFeatureVersion_ () < 0x7 )
202
- pybind11_fail (" pybind11 numpy support requires numpy >= 1.7.0" );
203
- DECL_NPY_API (PyArray_Type);
204
- DECL_NPY_API (PyVoidArrType_Type);
205
- DECL_NPY_API (PyArrayDescr_Type);
206
- DECL_NPY_API (PyArray_DescrFromType);
207
- DECL_NPY_API (PyArray_DescrFromScalar);
208
- DECL_NPY_API (PyArray_FromAny);
209
- DECL_NPY_API (PyArray_Resize);
210
- DECL_NPY_API (PyArray_CopyInto);
211
- DECL_NPY_API (PyArray_NewCopy);
212
- DECL_NPY_API (PyArray_NewFromDescr);
213
- DECL_NPY_API (PyArray_DescrNewFromType);
214
- DECL_NPY_API (PyArray_DescrConverter);
215
- DECL_NPY_API (PyArray_EquivTypes);
216
- DECL_NPY_API (PyArray_GetArrayParamsFromObject);
217
- DECL_NPY_API (PyArray_Squeeze);
218
- DECL_NPY_API (PyArray_SetBaseObject);
352
+ // multiarray -> _ARRAY_API
353
+ {
354
+ module multiarray = module::import (" numpy.core.multiarray" );
355
+ auto api_ptr = get_api_ptr (multiarray.attr (" _ARRAY_API" ));
356
+ DECL_NPY_API (PyArray_GetNDArrayCFeatureVersion);
357
+ if (api.PyArray_GetNDArrayCFeatureVersion_ () < 0x7 )
358
+ pybind11_fail (" pybind11 numpy support requires numpy >= 1.7.0" );
359
+ DECL_NPY_API (PyArray_Type);
360
+ DECL_NPY_API (PyVoidArrType_Type);
361
+ DECL_NPY_API (PyArrayDescr_Type);
362
+ DECL_NPY_API (PyArray_DescrFromType);
363
+ DECL_NPY_API (PyArray_DescrFromScalar);
364
+ DECL_NPY_API (PyArray_FromAny);
365
+ DECL_NPY_API (PyArray_Resize);
366
+ DECL_NPY_API (PyArray_CopyInto);
367
+ DECL_NPY_API (PyArray_NewCopy);
368
+ DECL_NPY_API (PyArray_NewFromDescr);
369
+ DECL_NPY_API (PyArray_DescrNewFromType);
370
+ DECL_NPY_API (PyArray_DescrConverter);
371
+ DECL_NPY_API (PyArray_EquivTypes);
372
+ DECL_NPY_API (PyArray_GetArrayParamsFromObject);
373
+ DECL_NPY_API (PyArray_Squeeze);
374
+ DECL_NPY_API (PyArray_SetBaseObject);
375
+ // - Dtypes
376
+ DECL_NPY_API (PyGenericArrType_Type);
377
+ DECL_NPY_API (PyArray_RegisterDataType);
378
+ DECL_NPY_API (PyArray_InitArrFuncs);
379
+ DECL_NPY_API (PyArray_RegisterCastFunc);
380
+ DECL_NPY_API (PyArray_RegisterCanCast);
381
+ }
382
+ // umath -> _UFUNC_API
383
+ {
384
+ module umath = module::import (" numpy.core.umath" );
385
+ auto api_ptr = get_api_ptr (umath.attr (" _UFUNC_API" ));
386
+ DECL_NPY_API (PyUFunc_FromFuncAndData);
387
+ DECL_NPY_API (PyUFunc_RegisterLoopForType);
388
+ DECL_NPY_API (PyUFunc_ReplaceLoopBySignature);
389
+ }
219
390
#undef DECL_NPY_API
220
391
return api;
221
392
}
@@ -465,6 +636,11 @@ class dtype : public object {
465
636
return detail::array_descriptor_proxy (m_ptr)->kind ;
466
637
}
467
638
639
+ // / Type index for builtin or user-registered dtypes.
640
+ int num () const {
641
+ return detail::array_descriptor_proxy (m_ptr)->type_num ;
642
+ }
643
+
468
644
private:
469
645
static object _dtype_from_pep3118 () {
470
646
static PyObject *obj = module::import (" numpy.core._internal" )
@@ -1049,6 +1225,26 @@ template<typename T> struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>
1049
1225
static pybind11::dtype dtype () { return base_descr::dtype (); }
1050
1226
};
1051
1227
1228
+ template <>
1229
+ struct npy_format_descriptor <object> {
1230
+ static pybind11::dtype dtype () {
1231
+ if (auto ptr = npy_api::get ().PyArray_DescrFromType_ (npy_api::NPY_OBJECT_))
1232
+ return reinterpret_borrow<pybind11::dtype>(ptr);
1233
+ pybind11_fail (" Unsupported buffer format!" );
1234
+ }
1235
+ };
1236
+
1237
+ template <>
1238
+ struct npy_format_descriptor <void > {
1239
+ static constexpr auto name = detail::_<void >();
1240
+ static pybind11::dtype dtype () {
1241
+ if (auto ptr = detail::npy_api::get ().PyArray_DescrFromType_ (
1242
+ detail::npy_api::constants::NPY_VOID_))
1243
+ return reinterpret_borrow<pybind11::dtype>(ptr);
1244
+ pybind11_fail (" Unsupported buffer format!" );
1245
+ }
1246
+ };
1247
+
1052
1248
struct field_descriptor {
1053
1249
const char *name;
1054
1250
ssize_t offset;
0 commit comments