|
| 1 | +/* |
| 2 | + pybind11/detail/numpy_ufunc.h: Simple glue for Python UFuncs |
| 3 | +
|
| 4 | + Copyright (c) 2018 Eric Cousineau <[email protected]> |
| 5 | +
|
| 6 | + All rights reserved. Use of this source code is governed by a |
| 7 | + BSD-style license that can be found in the LICENSE file. |
| 8 | +*/ |
| 9 | + |
| 10 | +#pragma once |
| 11 | + |
| 12 | +#include "../numpy.h" |
| 13 | +#include "inference.h" |
| 14 | + |
| 15 | +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) |
| 16 | +NAMESPACE_BEGIN(detail) |
| 17 | + |
| 18 | +// Utilities |
| 19 | + |
| 20 | +// Builtins registered using numpy/build/{...}/numpy/core/include/numpy/__umath_generated.c |
| 21 | + |
| 22 | +template <typename... Args> |
| 23 | +struct ufunc_ptr { |
| 24 | + PyUFuncGenericFunction func{}; |
| 25 | + void* data{}; |
| 26 | +}; |
| 27 | + |
| 28 | +// Unary ufunc. |
| 29 | +template <typename Arg0, typename Out, typename Func> |
| 30 | +auto ufunc_to_ptr(Func func, type_pack<Arg0, Out>) { |
| 31 | + auto ufunc = []( |
| 32 | + char** args, npy_intp* dimensions, npy_intp* steps, void* data) { |
| 33 | + Func& func = *(Func*)data; |
| 34 | + int step_0 = steps[0]; |
| 35 | + int step_out = steps[1]; |
| 36 | + int n = *dimensions; |
| 37 | + char *in_0 = args[0], *out = args[1]; |
| 38 | + for (int k = 0; k < n; k++) { |
| 39 | + // TODO(eric.cousineau): Support pointers being changed. |
| 40 | + *(Out*)out = func(*(Arg0*)in_0); |
| 41 | + in_0 += step_0; |
| 42 | + out += step_out; |
| 43 | + } |
| 44 | + }; |
| 45 | + // N.B. `new Func(...)` will never be destroyed. |
| 46 | + return ufunc_ptr<Arg0, Out>{ufunc, new Func(func)}; |
| 47 | +} |
| 48 | + |
| 49 | +// Binary ufunc. |
| 50 | +template <typename Arg0, typename Arg1, typename Out, typename Func = void> |
| 51 | +auto ufunc_to_ptr(Func func, type_pack<Arg0, Arg1, Out>) { |
| 52 | + auto ufunc = [](char** args, npy_intp* dimensions, npy_intp* steps, void* data) { |
| 53 | + Func& func = *(Func*)data; |
| 54 | + int step_0 = steps[0]; |
| 55 | + int step_1 = steps[1]; |
| 56 | + int step_out = steps[2]; |
| 57 | + int n = *dimensions; |
| 58 | + char *in_0 = args[0], *in_1 = args[1], *out = args[2]; |
| 59 | + for (int k = 0; k < n; k++) { |
| 60 | + // TODO(eric.cousineau): Support pointers being fed in. |
| 61 | + *(Out*)out = func(*(Arg0*)in_0, *(Arg1*)in_1); |
| 62 | + in_0 += step_0; |
| 63 | + in_1 += step_1; |
| 64 | + out += step_out; |
| 65 | + } |
| 66 | + }; |
| 67 | + // N.B. `new Func(...)` will never be destroyed. |
| 68 | + return ufunc_ptr<Arg0, Arg1, Out>{ufunc, new Func(func)}; |
| 69 | +} |
| 70 | + |
| 71 | +// Generic dispatch. |
| 72 | +template <typename Func> |
| 73 | +auto ufunc_to_ptr(Func func) { |
| 74 | + auto info = detail::function_inference::run(func); |
| 75 | + using Info = decltype(info); |
| 76 | + auto type_args = type_pack_apply<std::decay_t>( |
| 77 | + type_pack_concat( |
| 78 | + typename Info::Args{}, |
| 79 | + type_pack<typename Info::Return>{})); |
| 80 | + return ufunc_to_ptr(func, type_args); |
| 81 | +} |
| 82 | + |
| 83 | +template <typename From, typename To, typename Func> |
| 84 | +void ufunc_register_cast( |
| 85 | + Func&& func, bool allow_coercion, type_pack<From, To> = {}) { |
| 86 | + static auto cast_lambda = detail::function_inference::run(func).func; |
| 87 | + auto cast_func = +[]( |
| 88 | + void* from_, void* to_, npy_intp n, |
| 89 | + void* fromarr, void* toarr) { |
| 90 | + const From* from = (From*)from_; |
| 91 | + To* to = (To*)to_; |
| 92 | + for (npy_intp i = 0; i < n; i++) |
| 93 | + to[i] = cast_lambda(from[i]); |
| 94 | + }; |
| 95 | + auto& api = npy_api::get(); |
| 96 | + auto from = npy_format_descriptor<From>::dtype(); |
| 97 | + int to_num = npy_format_descriptor<To>::dtype().num(); |
| 98 | + auto from_raw = (PyArray_Descr*)from.ptr(); |
| 99 | + if (api.PyArray_RegisterCastFunc_(from_raw, to_num, cast_func) < 0) |
| 100 | + pybind11_fail("ufunc: Cannot register cast"); |
| 101 | + if (allow_coercion) { |
| 102 | + if (api.PyArray_RegisterCanCast_( |
| 103 | + from_raw, to_num, npy_api::NPY_NOSCALAR_) < 0) |
| 104 | + pybind11_fail( |
| 105 | + "ufunc: Cannot register implicit / coercion cast capability"); |
| 106 | + } |
| 107 | +} |
| 108 | + |
| 109 | +NAMESPACE_END(detail) |
| 110 | + |
| 111 | +class ufunc : public object { |
| 112 | +public: |
| 113 | + ufunc(object ptr) : object(ptr) { |
| 114 | + // TODO(eric.cousineau): Check type. |
| 115 | + } |
| 116 | + |
| 117 | + ufunc(detail::PyUFuncObject* ptr) |
| 118 | + : object(reinterpret_borrow<object>((PyObject*)ptr)) |
| 119 | + {} |
| 120 | + |
| 121 | + ufunc(handle scope, const char* name) : scope_{scope}, name_{name} {} |
| 122 | + |
| 123 | + // Gets a NumPy UFunc by name. |
| 124 | + static ufunc get_builtin(const char* name) { |
| 125 | + module numpy = module::import("numpy"); |
| 126 | + return ufunc(numpy.attr(name)); |
| 127 | + } |
| 128 | + |
| 129 | + template <typename Type, typename Func> |
| 130 | + ufunc& def_loop(Func func) { |
| 131 | + do_register<Type>(detail::ufunc_to_ptr(func)); |
| 132 | + return *this; |
| 133 | + } |
| 134 | + |
| 135 | + detail::PyUFuncObject* ptr() const { |
| 136 | + return (detail::PyUFuncObject*)self().ptr(); |
| 137 | + } |
| 138 | + |
| 139 | +private: |
| 140 | + object& self() { return *this; } |
| 141 | + const object& self() const { return *this; } |
| 142 | + |
| 143 | + // Registers a function pointer as a UFunc, mapping types to dtype nums. |
| 144 | + template <typename Type, typename ... Args> |
| 145 | + void do_register(detail::ufunc_ptr<Args...> user) { |
| 146 | + constexpr int N = sizeof...(Args); |
| 147 | + constexpr int nin = N - 1; |
| 148 | + constexpr int nout = 1; |
| 149 | + int dtype = dtype::of<Type>().num(); |
| 150 | + int dtype_args[] = {dtype::of<Args>().num()...}; |
| 151 | + // Determine if we need to make a new ufunc. |
| 152 | + using constants = detail::npy_api::constants; |
| 153 | + auto& api = detail::npy_api::get(); |
| 154 | + if (!self()) { |
| 155 | + if (!name_) |
| 156 | + pybind11_fail("dtype: unspecified name"); |
| 157 | + // TODO(eric.cousineau): Fix unfreed memory with `name`. |
| 158 | + auto leak = new std::string(name_); |
| 159 | + // The following dummy stuff is to allow monkey-patching existing ufuncs. |
| 160 | + // This is a bit sketchy, as calling the wrong thing may cause a segfault. |
| 161 | + // TODO(eric.cousineau): Figure out how to more elegantly specify preallocation... |
| 162 | + // Preallocate to allow replacement? |
| 163 | + constexpr int ntypes = 4; |
| 164 | + static char tinker_types[ntypes] = { |
| 165 | + constants::NPY_BOOL_, |
| 166 | + constants::NPY_INT_, |
| 167 | + constants::NPY_FLOAT_, |
| 168 | + constants::NPY_DOUBLE_, |
| 169 | + }; |
| 170 | + auto dummy_funcs = new detail::PyUFuncGenericFunction[ntypes]; |
| 171 | + auto dummy_data = new void*[ntypes]; |
| 172 | + constexpr int ntotal = (nin + nout) * ntypes; |
| 173 | + auto dummy_types = new char[ntotal]; |
| 174 | + for (int it = 0; it < ntypes; ++it) { |
| 175 | + for (int iarg = 0; iarg < nin + nout; ++iarg) { |
| 176 | + int i = it * (nin + nout) + iarg; |
| 177 | + dummy_types[i] = tinker_types[it]; |
| 178 | + } |
| 179 | + } |
| 180 | + auto h = api.PyUFunc_FromFuncAndData_( |
| 181 | + dummy_funcs, dummy_data, dummy_types, ntypes, |
| 182 | + nin, nout, constants::PyUFunc_None_, &(*leak)[0], "", 0); |
| 183 | + self() = reinterpret_borrow<object>((PyObject*)h); |
| 184 | + scope_.attr(name_) = self(); |
| 185 | + } |
| 186 | + if (N != ptr()->nargs) |
| 187 | + pybind11_fail("ufunc: Argument count mismatch"); |
| 188 | + if (dtype >= constants::NPY_USERDEF_) { |
| 189 | + if (api.PyUFunc_RegisterLoopForType_( |
| 190 | + ptr(), dtype, user.func, dtype_args, user.data) < 0) |
| 191 | + pybind11_fail("ufunc: Failed to register custom ufunc"); |
| 192 | + } else { |
| 193 | + // Hack because NumPy API doesn't allow us convenience for builtin types :( |
| 194 | + if (api.PyUFunc_ReplaceLoopBySignature_( |
| 195 | + ptr(), user.func, dtype_args, nullptr) < 0) |
| 196 | + pybind11_fail("ufunc: Failed ot register builtin ufunc"); |
| 197 | + // Now that we've registered, ensure that we replace the data. |
| 198 | + bool found{}; |
| 199 | + for (int i = 0; i < ptr()->ntypes; ++i) { |
| 200 | + if (ptr()->functions[i] == user.func) { |
| 201 | + found = true; |
| 202 | + ptr()->data[i] = user.data; |
| 203 | + break; |
| 204 | + } |
| 205 | + } |
| 206 | + if (!found) |
| 207 | + pybind11_fail("Can't hack and slash"); |
| 208 | + } |
| 209 | + } |
| 210 | + |
| 211 | + // These are only used if we have something new. |
| 212 | + const char* name_{}; |
| 213 | + handle scope_{}; |
| 214 | +}; |
| 215 | + |
| 216 | +NAMESPACE_END(PYBIND11_NAMESPACE) |
0 commit comments