Skip to content

Commit 31fe60a

Browse files
numpy_dtype_user: Allow custom user dtypes to be defined
1 parent f662de8 commit 31fe60a

File tree

5 files changed

+1132
-0
lines changed

5 files changed

+1132
-0
lines changed

include/pybind11/detail/numpy_ufunc.h

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
/*
2+
pybind11/detail/numpy_ufunc.h: Simple glue for Python UFuncs
3+
4+
Copyright (c) 2018 Eric Cousineau <[email protected]>
5+
6+
All rights reserved. Use of this source code is governed by a
7+
BSD-style license that can be found in the LICENSE file.
8+
*/
9+
10+
#pragma once
11+
12+
#include "../numpy.h"
13+
#include "inference.h"
14+
15+
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
16+
NAMESPACE_BEGIN(detail)
17+
18+
// Utilities
19+
20+
// Builtins registered using numpy/build/{...}/numpy/core/include/numpy/__umath_generated.c
21+
22+
template <typename... Args>
23+
struct ufunc_ptr {
24+
PyUFuncGenericFunction func{};
25+
void* data{};
26+
};
27+
28+
// Unary ufunc.
29+
template <typename Arg0, typename Out, typename Func>
30+
auto ufunc_to_ptr(Func func, type_pack<Arg0, Out>) {
31+
auto ufunc = [](
32+
char** args, npy_intp* dimensions, npy_intp* steps, void* data) {
33+
Func& func = *(Func*)data;
34+
int step_0 = steps[0];
35+
int step_out = steps[1];
36+
int n = *dimensions;
37+
char *in_0 = args[0], *out = args[1];
38+
for (int k = 0; k < n; k++) {
39+
// TODO(eric.cousineau): Support pointers being changed.
40+
*(Out*)out = func(*(Arg0*)in_0);
41+
in_0 += step_0;
42+
out += step_out;
43+
}
44+
};
45+
// N.B. `new Func(...)` will never be destroyed.
46+
return ufunc_ptr<Arg0, Out>{ufunc, new Func(func)};
47+
}
48+
49+
// Binary ufunc.
50+
template <typename Arg0, typename Arg1, typename Out, typename Func = void>
51+
auto ufunc_to_ptr(Func func, type_pack<Arg0, Arg1, Out>) {
52+
auto ufunc = [](char** args, npy_intp* dimensions, npy_intp* steps, void* data) {
53+
Func& func = *(Func*)data;
54+
int step_0 = steps[0];
55+
int step_1 = steps[1];
56+
int step_out = steps[2];
57+
int n = *dimensions;
58+
char *in_0 = args[0], *in_1 = args[1], *out = args[2];
59+
for (int k = 0; k < n; k++) {
60+
// TODO(eric.cousineau): Support pointers being fed in.
61+
*(Out*)out = func(*(Arg0*)in_0, *(Arg1*)in_1);
62+
in_0 += step_0;
63+
in_1 += step_1;
64+
out += step_out;
65+
}
66+
};
67+
// N.B. `new Func(...)` will never be destroyed.
68+
return ufunc_ptr<Arg0, Arg1, Out>{ufunc, new Func(func)};
69+
}
70+
71+
// Generic dispatch.
72+
template <typename Func>
73+
auto ufunc_to_ptr(Func func) {
74+
auto info = detail::function_inference::run(func);
75+
using Info = decltype(info);
76+
auto type_args = type_pack_apply<std::decay_t>(
77+
type_pack_concat(
78+
typename Info::Args{},
79+
type_pack<typename Info::Return>{}));
80+
return ufunc_to_ptr(func, type_args);
81+
}
82+
83+
template <typename From, typename To, typename Func>
84+
void ufunc_register_cast(
85+
Func&& func, bool allow_coercion, type_pack<From, To> = {}) {
86+
static auto cast_lambda = detail::function_inference::run(func).func;
87+
auto cast_func = +[](
88+
void* from_, void* to_, npy_intp n,
89+
void* fromarr, void* toarr) {
90+
const From* from = (From*)from_;
91+
To* to = (To*)to_;
92+
for (npy_intp i = 0; i < n; i++)
93+
to[i] = cast_lambda(from[i]);
94+
};
95+
auto& api = npy_api::get();
96+
auto from = npy_format_descriptor<From>::dtype();
97+
int to_num = npy_format_descriptor<To>::dtype().num();
98+
auto from_raw = (PyArray_Descr*)from.ptr();
99+
if (api.PyArray_RegisterCastFunc_(from_raw, to_num, cast_func) < 0)
100+
pybind11_fail("ufunc: Cannot register cast");
101+
if (allow_coercion) {
102+
if (api.PyArray_RegisterCanCast_(
103+
from_raw, to_num, npy_api::NPY_NOSCALAR_) < 0)
104+
pybind11_fail(
105+
"ufunc: Cannot register implicit / coercion cast capability");
106+
}
107+
}
108+
109+
NAMESPACE_END(detail)
110+
111+
class ufunc : public object {
112+
public:
113+
ufunc(object ptr) : object(ptr) {
114+
// TODO(eric.cousineau): Check type.
115+
}
116+
117+
ufunc(detail::PyUFuncObject* ptr)
118+
: object(reinterpret_borrow<object>((PyObject*)ptr))
119+
{}
120+
121+
ufunc(handle scope, const char* name) : scope_{scope}, name_{name} {}
122+
123+
// Gets a NumPy UFunc by name.
124+
static ufunc get_builtin(const char* name) {
125+
module numpy = module::import("numpy");
126+
return ufunc(numpy.attr(name));
127+
}
128+
129+
template <typename Type, typename Func>
130+
ufunc& def_loop(Func func) {
131+
do_register<Type>(detail::ufunc_to_ptr(func));
132+
return *this;
133+
}
134+
135+
detail::PyUFuncObject* ptr() const {
136+
return (detail::PyUFuncObject*)self().ptr();
137+
}
138+
139+
private:
140+
object& self() { return *this; }
141+
const object& self() const { return *this; }
142+
143+
// Registers a function pointer as a UFunc, mapping types to dtype nums.
144+
template <typename Type, typename ... Args>
145+
void do_register(detail::ufunc_ptr<Args...> user) {
146+
constexpr int N = sizeof...(Args);
147+
constexpr int nin = N - 1;
148+
constexpr int nout = 1;
149+
int dtype = dtype::of<Type>().num();
150+
int dtype_args[] = {dtype::of<Args>().num()...};
151+
// Determine if we need to make a new ufunc.
152+
using constants = detail::npy_api::constants;
153+
auto& api = detail::npy_api::get();
154+
if (!self()) {
155+
if (!name_)
156+
pybind11_fail("dtype: unspecified name");
157+
// TODO(eric.cousineau): Fix unfreed memory with `name`.
158+
auto leak = new std::string(name_);
159+
// The following dummy stuff is to allow monkey-patching existing ufuncs.
160+
// This is a bit sketchy, as calling the wrong thing may cause a segfault.
161+
// TODO(eric.cousineau): Figure out how to more elegantly specify preallocation...
162+
// Preallocate to allow replacement?
163+
constexpr int ntypes = 4;
164+
static char tinker_types[ntypes] = {
165+
constants::NPY_BOOL_,
166+
constants::NPY_INT_,
167+
constants::NPY_FLOAT_,
168+
constants::NPY_DOUBLE_,
169+
};
170+
auto dummy_funcs = new detail::PyUFuncGenericFunction[ntypes];
171+
auto dummy_data = new void*[ntypes];
172+
constexpr int ntotal = (nin + nout) * ntypes;
173+
auto dummy_types = new char[ntotal];
174+
for (int it = 0; it < ntypes; ++it) {
175+
for (int iarg = 0; iarg < nin + nout; ++iarg) {
176+
int i = it * (nin + nout) + iarg;
177+
dummy_types[i] = tinker_types[it];
178+
}
179+
}
180+
auto h = api.PyUFunc_FromFuncAndData_(
181+
dummy_funcs, dummy_data, dummy_types, ntypes,
182+
nin, nout, constants::PyUFunc_None_, &(*leak)[0], "", 0);
183+
self() = reinterpret_borrow<object>((PyObject*)h);
184+
scope_.attr(name_) = self();
185+
}
186+
if (N != ptr()->nargs)
187+
pybind11_fail("ufunc: Argument count mismatch");
188+
if (dtype >= constants::NPY_USERDEF_) {
189+
if (api.PyUFunc_RegisterLoopForType_(
190+
ptr(), dtype, user.func, dtype_args, user.data) < 0)
191+
pybind11_fail("ufunc: Failed to register custom ufunc");
192+
} else {
193+
// Hack because NumPy API doesn't allow us convenience for builtin types :(
194+
if (api.PyUFunc_ReplaceLoopBySignature_(
195+
ptr(), user.func, dtype_args, nullptr) < 0)
196+
pybind11_fail("ufunc: Failed ot register builtin ufunc");
197+
// Now that we've registered, ensure that we replace the data.
198+
bool found{};
199+
for (int i = 0; i < ptr()->ntypes; ++i) {
200+
if (ptr()->functions[i] == user.func) {
201+
found = true;
202+
ptr()->data[i] = user.data;
203+
break;
204+
}
205+
}
206+
if (!found)
207+
pybind11_fail("Can't hack and slash");
208+
}
209+
}
210+
211+
// These are only used if we have something new.
212+
const char* name_{};
213+
handle scope_{};
214+
};
215+
216+
NAMESPACE_END(PYBIND11_NAMESPACE)

0 commit comments

Comments
 (0)