Skip to content

Commit 719532b

Browse files
eigen: Disable dtype=object arrays from being referenced
1 parent 54769e6 commit 719532b

File tree

4 files changed

+152
-72
lines changed

4 files changed

+152
-72
lines changed

include/pybind11/eigen.h

Lines changed: 89 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,8 @@ struct eigen_extract_stride<Eigen::Map<PlainObjectType, MapOptions, StrideType>>
113113
template <typename PlainObjectType, int Options, typename StrideType>
114114
struct eigen_extract_stride<Eigen::Ref<PlainObjectType, Options, StrideType>> { using type = StrideType; };
115115

116-
template <typename Scalar> bool is_pyobject_() {
117-
return static_cast<pybind11::detail::npy_api::constants>(npy_format_descriptor<Scalar>::value) == npy_api::NPY_OBJECT_;
118-
}
116+
template <typename Scalar>
117+
using is_pyobject_dtype = std::is_base_of<npy_format_descriptor_object, npy_format_descriptor<Scalar>>;
119118

120119
// Helper struct for extracting information from an Eigen type
121120
template <typename Type_> struct EigenProps {
@@ -149,9 +148,7 @@ template <typename Type_> struct EigenProps {
149148
const auto dims = a.ndim();
150149
if (dims < 1 || dims > 2)
151150
return false;
152-
bool is_pyobject = false;
153-
if (is_pyobject_<Scalar>())
154-
is_pyobject = true;
151+
constexpr bool is_pyobject = is_pyobject_dtype<Scalar>::value;
155152
ssize_t scalar_size = (is_pyobject ? static_cast<ssize_t>(sizeof(PyObject*)) :
156153
static_cast<ssize_t>(sizeof(Scalar)));
157154
if (dims == 2) { // Matrix type: require exact match (or dynamic)
@@ -233,17 +230,27 @@ template <typename props> handle eigen_array_cast(typename props::Type const &sr
233230
src.data(), base);
234231
}
235232
else {
233+
if (base) {
234+
// Should be disabled by upstream calls to this method.
235+
// TODO(eric.cousineau): Write tests to ensure that this is not
236+
// reachable.
237+
throw cast_error(
238+
"dtype=object does not permit array referencing. "
239+
"(NOTE: this generally not be reachable, as upstream APIs "
240+
"should fail before this.");
241+
}
242+
handle empty_base{};
243+
auto policy = return_value_policy::copy;
236244
if (props::vector) {
237245
a = array(
238246
npy_format_descriptor<Scalar>::dtype(),
239247
{ (size_t) src.size() },
240248
nullptr,
241-
base
249+
empty_base
242250
);
243-
auto policy = base ? return_value_policy::automatic_reference : return_value_policy::copy;
244251
for (ssize_t i = 0; i < src.size(); ++i) {
245252
const Scalar src_val = props::fixed_rows ? src(0, i) : src(i, 0);
246-
auto value_ = reinterpret_steal<object>(make_caster<Scalar>::cast(src_val, policy, base));
253+
auto value_ = reinterpret_steal<object>(make_caster<Scalar>::cast(src_val, policy, empty_base));
247254
if (!value_)
248255
return handle();
249256
a.attr("itemset")(i, value_);
@@ -254,12 +261,11 @@ template <typename props> handle eigen_array_cast(typename props::Type const &sr
254261
npy_format_descriptor<Scalar>::dtype(),
255262
{(size_t) src.rows(), (size_t) src.cols()},
256263
nullptr,
257-
base
264+
empty_base
258265
);
259-
auto policy = base ? return_value_policy::automatic_reference : return_value_policy::copy;
260266
for (ssize_t i = 0; i < src.rows(); ++i) {
261267
for (ssize_t j = 0; j < src.cols(); ++j) {
262-
auto value_ = reinterpret_steal<object>(make_caster<Scalar>::cast(src(i, j), policy, base));
268+
auto value_ = reinterpret_steal<object>(make_caster<Scalar>::cast(src(i, j), policy, empty_base));
263269
if (!value_)
264270
return handle();
265271
a.attr("itemset")(i, j, value_);
@@ -323,7 +329,7 @@ struct type_caster<Type, enable_if_t<is_eigen_dense_plain<Type>::value>> {
323329
int result = 0;
324330
// Allocate the new type, then build a numpy reference into it
325331
value = Type(fits.rows, fits.cols);
326-
bool is_pyobject = is_pyobject_<Scalar>();
332+
constexpr bool is_pyobject = is_pyobject_dtype<Scalar>::value;
327333

328334
if (!is_pyobject) {
329335
auto ref = reinterpret_steal<array>(eigen_ref_array<props>(value));
@@ -374,22 +380,40 @@ struct type_caster<Type, enable_if_t<is_eigen_dense_plain<Type>::value>> {
374380
// Cast implementation
375381
template <typename CType>
376382
static handle cast_impl(CType *src, return_value_policy policy, handle parent) {
377-
switch (policy) {
378-
case return_value_policy::take_ownership:
379-
case return_value_policy::automatic:
380-
return eigen_encapsulate<props>(src);
381-
case return_value_policy::move:
382-
return eigen_encapsulate<props>(new CType(std::move(*src)));
383-
case return_value_policy::copy:
384-
return eigen_array_cast<props>(*src);
385-
case return_value_policy::reference:
386-
case return_value_policy::automatic_reference:
387-
return eigen_ref_array<props>(*src);
388-
case return_value_policy::reference_internal:
389-
return eigen_ref_array<props>(*src, parent);
390-
default:
391-
throw cast_error("unhandled return_value_policy: should not happen!");
392-
};
383+
constexpr bool is_pyobject = is_pyobject_dtype<Scalar>::value;
384+
if (!is_pyobject) {
385+
switch (policy) {
386+
case return_value_policy::take_ownership:
387+
case return_value_policy::automatic:
388+
return eigen_encapsulate<props>(src);
389+
case return_value_policy::move:
390+
return eigen_encapsulate<props>(new CType(std::move(*src)));
391+
case return_value_policy::copy:
392+
return eigen_array_cast<props>(*src);
393+
case return_value_policy::reference:
394+
case return_value_policy::automatic_reference:
395+
return eigen_ref_array<props>(*src);
396+
case return_value_policy::reference_internal:
397+
return eigen_ref_array<props>(*src, parent);
398+
default:
399+
throw cast_error("unhandled return_value_policy: should not happen!");
400+
};
401+
} else {
402+
// For arrays of `dtype=object`, referencing is invalid, so we should squash that as soon as possible.
403+
switch (policy) {
404+
case return_value_policy::automatic:
405+
case return_value_policy::move:
406+
case return_value_policy::copy:
407+
case return_value_policy::automatic_reference:
408+
return eigen_array_cast<props>(*src);
409+
case return_value_policy::take_ownership:
410+
case return_value_policy::reference:
411+
case return_value_policy::reference_internal:
412+
throw cast_error("dtype=object arrays must be copied, and cannot be referenced");
413+
default:
414+
throw cast_error("unhandled return_value_policy: should not happen!");
415+
};
416+
}
393417
}
394418

395419
public:
@@ -446,6 +470,7 @@ struct return_value_policy_override<Return, enable_if_t<is_eigen_dense_map<Retur
446470
template <typename MapType> struct eigen_map_caster {
447471
private:
448472
using props = EigenProps<MapType>;
473+
using Scalar = typename props::Scalar;
449474

450475
public:
451476

@@ -456,18 +481,33 @@ template <typename MapType> struct eigen_map_caster {
456481
// that this means you need to ensure you don't destroy the object in some other way (e.g. with
457482
// an appropriate keep_alive, or with a reference to a statically allocated matrix).
458483
static handle cast(const MapType &src, return_value_policy policy, handle parent) {
459-
switch (policy) {
460-
case return_value_policy::copy:
461-
return eigen_array_cast<props>(src);
462-
case return_value_policy::reference_internal:
463-
return eigen_array_cast<props>(src, parent, is_eigen_mutable_map<MapType>::value);
464-
case return_value_policy::reference:
465-
case return_value_policy::automatic:
466-
case return_value_policy::automatic_reference:
467-
return eigen_array_cast<props>(src, none(), is_eigen_mutable_map<MapType>::value);
468-
default:
469-
// move, take_ownership don't make any sense for a ref/map:
470-
pybind11_fail("Invalid return_value_policy for Eigen Map/Ref/Block type");
484+
if (!is_pyobject_dtype<Scalar>::value) {
485+
switch (policy) {
486+
case return_value_policy::copy:
487+
return eigen_array_cast<props>(src);
488+
case return_value_policy::reference_internal:
489+
return eigen_array_cast<props>(src, parent, is_eigen_mutable_map<MapType>::value);
490+
case return_value_policy::reference:
491+
case return_value_policy::automatic:
492+
case return_value_policy::automatic_reference:
493+
return eigen_array_cast<props>(src, none(), is_eigen_mutable_map<MapType>::value);
494+
default:
495+
// move, take_ownership don't make any sense for a ref/map:
496+
pybind11_fail("Invalid return_value_policy for Eigen Map/Ref/Block type");
497+
}
498+
} else {
499+
switch (policy) {
500+
case return_value_policy::copy:
501+
return eigen_array_cast<props>(src);
502+
case return_value_policy::reference_internal:
503+
case return_value_policy::reference:
504+
case return_value_policy::automatic:
505+
case return_value_policy::automatic_reference:
506+
throw cast_error("dtype=object arrays must be copied, and cannot be referenced");
507+
default:
508+
// move, take_ownership don't make any sense for a ref/map:
509+
pybind11_fail("Invalid return_value_policy for Eigen Map/Ref/Block type");
510+
}
471511
}
472512
}
473513

@@ -519,9 +559,14 @@ struct type_caster<
519559
bool need_copy = !isinstance<Array>(src);
520560

521561
EigenConformable<props::row_major> fits;
522-
bool is_pyobject = false;
523-
if (is_pyobject_<Scalar>()) {
524-
is_pyobject = true;
562+
constexpr bool is_pyobject = is_pyobject_dtype<Scalar>::value;
563+
// TODO(eric.cousineau): Make this compile-time once Drake does not use this in any code
564+
// for scalar types.
565+
//static_assert(!(is_pyobject && need_writeable), "dtype=object cannot provide writeable references");
566+
if (is_pyobject && need_writeable) {
567+
throw cast_error("dtype=object cannot provide writeable references");
568+
}
569+
if (is_pyobject) {
525570
need_copy = true;
526571
}
527572
if (!need_copy) {

include/pybind11/numpy.h

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,19 +1230,22 @@ template <typename T, typename SFINAE> struct npy_format_descriptor {
12301230
(::std::vector<::pybind11::detail::field_descriptor> \
12311231
{PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})
12321232

1233+
struct npy_format_descriptor_object {
1234+
public:
1235+
enum { value = npy_api::NPY_OBJECT_ };
1236+
static pybind11::dtype dtype() {
1237+
if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) {
1238+
return reinterpret_borrow<pybind11::dtype>(ptr);
1239+
}
1240+
pybind11_fail("Unsupported buffer format!");
1241+
}
1242+
static constexpr auto name = _("object");
1243+
};
1244+
12331245
#define PYBIND11_NUMPY_OBJECT_DTYPE(Type) \
12341246
namespace pybind11 { namespace detail { \
1235-
template <> struct npy_format_descriptor<Type> { \
1236-
public: \
1237-
enum { value = npy_api::NPY_OBJECT_ }; \
1238-
static pybind11::dtype dtype() { \
1239-
if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) { \
1240-
return reinterpret_borrow<pybind11::dtype>(ptr); \
1241-
} \
1242-
pybind11_fail("Unsupported buffer format!"); \
1243-
} \
1244-
static constexpr auto name = _("object"); \
1245-
}; \
1247+
template <> struct npy_format_descriptor<Type> : \
1248+
public npy_format_descriptor_object {}; \
12461249
}}
12471250

12481251
#endif // __CLION_IDE__

tests/test_eigen.cpp

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,16 @@ void reset_refs() {
6161
reset_ref(get_rm());
6262
}
6363

64+
VectorXADScalar& get_cm_adscalar() {
65+
static VectorXADScalar value(1);
66+
return value;
67+
};
68+
VectorXADScalarR& get_rm_adscalar() {
69+
static VectorXADScalarR value(1);
70+
return value;
71+
};
72+
73+
6474
// Returns element 2,1 from a matrix (used to test copy/nocopy)
6575
double get_elem(Eigen::Ref<const Eigen::MatrixXd> m) { return m(2, 1); };
6676

@@ -106,9 +116,7 @@ TEST_SUBMODULE(eigen, m) {
106116
m.def("double_adscalar_row", [](const VectorXADScalarR &x) -> VectorXADScalarR { return 2.0f * x; });
107117
m.def("double_complex", [](const Eigen::VectorXcf &x) -> Eigen::VectorXcf { return 2.0f * x; });
108118
m.def("double_threec", [](py::EigenDRef<Eigen::Vector3f> x) { x *= 2; });
109-
m.def("double_adscalarc", [](py::EigenDRef<VectorXADScalar> x) { x *= 2; });
110119
m.def("double_threer", [](py::EigenDRef<Eigen::RowVector3f> x) { x *= 2; });
111-
m.def("double_adscalarr", [](py::EigenDRef<VectorXADScalarR> x) { x *= 2; });
112120
m.def("double_mat_cm", [](Eigen::MatrixXf x) -> Eigen::MatrixXf { return 2.0f * x; });
113121
m.def("double_mat_rm", [](DenseMatrixR x) -> DenseMatrixR { return 2.0f * x; });
114122

@@ -130,6 +138,8 @@ TEST_SUBMODULE(eigen, m) {
130138
// Mutators (Eigen maps into numpy variables):
131139
m.def("add_rm", add_rm); // Only takes row-contiguous
132140
m.def("add_cm", add_cm); // Only takes column-contiguous
141+
m.def("add_rm_adscalar", [](py::EigenDRef<VectorXADScalarR> x) { x.array() += 2; });
142+
m.def("add_cm_adscalar", [](py::EigenDRef<VectorXADScalar> x) { x.array() += 2; });
133143
// Overloaded versions that will accept either row or column contiguous:
134144
m.def("add1", add_rm);
135145
m.def("add1", add_cm);
@@ -141,9 +151,17 @@ TEST_SUBMODULE(eigen, m) {
141151
// Return mutable references (numpy maps into eigen variables)
142152
m.def("get_cm_ref", []() { return Eigen::Ref<Eigen::MatrixXd>(get_cm()); });
143153
m.def("get_rm_ref", []() { return Eigen::Ref<MatrixXdR>(get_rm()); });
154+
m.def("get_cm_ref_adscalar", []() {
155+
return py::EigenDRef<VectorXADScalar>(get_cm_adscalar());
156+
});
157+
m.def("get_rm_ref_adscalar", []() {
158+
return py::EigenDRef<VectorXADScalarR>(get_rm_adscalar());
159+
});
144160
// The same references, but non-mutable (numpy maps into eigen variables, but is !writeable)
145161
m.def("get_cm_const_ref", []() { return Eigen::Ref<const Eigen::MatrixXd>(get_cm()); });
146162
m.def("get_rm_const_ref", []() { return Eigen::Ref<const MatrixXdR>(get_rm()); });
163+
m.def("get_cm_const_ref_adscalar", []() { return Eigen::Ref<const VectorXADScalar>(get_cm_adscalar()); });
164+
m.def("get_rm_const_ref_adscalar", []() { return Eigen::Ref<const VectorXADScalarR>(get_rm_adscalar()); });
147165

148166
m.def("reset_refs", reset_refs); // Restores get_{cm,rm}_ref to original values
149167

@@ -153,11 +171,12 @@ TEST_SUBMODULE(eigen, m) {
153171
return m;
154172
}, py::return_value_policy::reference);
155173

156-
// Increments ADScalar Matrix
157-
m.def("incr_adscalar_matrix", [](Eigen::Ref<DenseADScalarMatrixC> m, double v) {
158-
m += DenseADScalarMatrixC::Constant(m.rows(), m.cols(), v);
159-
return m;
160-
}, py::return_value_policy::reference);
174+
// Increments ADScalar Matrix, returns a copy.
175+
m.def("incr_adscalar_matrix", [](const Eigen::Ref<const DenseADScalarMatrixC>& m, double v) {
176+
DenseADScalarMatrixC out = m;
177+
out.array() += v;
178+
return out;
179+
});
161180

162181
// Same, but accepts a matrix of any strides
163182
m.def("incr_matrix_any", [](py::EigenDRef<Eigen::MatrixXd> m, double v) {
@@ -347,10 +366,7 @@ TEST_SUBMODULE(eigen, m) {
347366
m.def("cpp_matrix_shape", [](const MatrixX<ADScalar>& A) {
348367
return py::make_tuple(A.rows(), A.cols());
349368
});
350-
// TODO(eric.cousineau): Unless `dtype=ADScalar` (user-defined) and not
351-
// `dtype=object`, we should kill any usages of `Eigen::Ref<>` or any
352-
// const-references.
353-
m.def("cpp_matrix_shape_ref", [](const Eigen::Ref<MatrixX<ADScalar>>& A) {
369+
m.def("cpp_matrix_shape_ref", [](const Eigen::Ref<const MatrixX<ADScalar>>& A) {
354370
return py::make_tuple(A.rows(), A.cols());
355371
});
356372

tests/test_eigen.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -182,15 +182,31 @@ def test_eigen_passing_adscalar():
182182
incremented_adscalar_mat = conv_double_to_adscalar(m.incr_adscalar_matrix(adscalar_mat, 7.),
183183
vice_versa=True)
184184
np.testing.assert_array_equal(incremented_adscalar_mat, ref + 7)
185-
# The original adscalar_mat remains unchanged in spite of passing by reference.
185+
# The original adscalar_mat remains unchanged in spite of passing by reference, since
186+
# `Eigen::Ref<const CType>` permits copying, and copying is the only valid operation for
187+
# `dtype=object`.
186188
np.testing.assert_array_equal(conv_double_to_adscalar(adscalar_mat, vice_versa=True), ref)
187189

188-
# Changes in Python are not reflected in C++ when internal_reference is returned
190+
# Changes in Python are not reflected in C++ when internal_reference is returned.
191+
# These conversions should be disabled at runtime.
192+
193+
def expect_ref_error(func):
194+
with pytest.raises(RuntimeError) as excinfo:
195+
func()
196+
assert "dtype=object" in str(excinfo)
197+
assert "reachable" not in str(excinfo)
198+
199+
# - Return arguments.
200+
expect_ref_error(lambda: m.get_cm_ref_adscalar())
201+
expect_ref_error(lambda: m.get_rm_ref_adscalar())
202+
expect_ref_error(lambda: m.get_cm_const_ref_adscalar())
203+
expect_ref_error(lambda: m.get_rm_const_ref_adscalar())
204+
# - - Mutable lvalues referenced via `reference_internal`.
189205
return_tester = m.ReturnTester()
190-
a = return_tester.get_ADScalarMat()
191-
a[1, 1] = m.AutoDiffXd(4, np.ones(1))
192-
b = return_tester.get_ADScalarMat()
193-
assert(np.isclose(b[1, 1].value(), 7.))
206+
expect_ref_error(lambda: return_tester.get_ADScalarMat())
207+
# - Input arguments, writeable `Ref<>`s.
208+
expect_ref_error(lambda: m.add_cm_adscalar(adscalar_vec_col))
209+
expect_ref_error(lambda: m.add_rm_adscalar(adscalar_vec_row))
194210

195211
# Checking Issue 1105
196212
assert m.iss1105_col_obj(adscalar_vec_col[:, None])

0 commit comments

Comments
 (0)