Skip to content

Commit ba09dd8

Browse files
authored
Disable vectorized bitwise_invert for boolean inputs (#1681)
* Removes `sycl::vec` overload in `BitwiseInvertFunctor` This overload would cause sufficiently large boolean arrays to produce unexpected results when cast to another type * Adds a test for fixed bitwise_invert behavior * Re-enable vectorized `bitwise_invert` for integer types
1 parent c994666 commit ba09dd8

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_invert.hpp

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,8 @@ template <typename argT, typename resT> struct BitwiseInvertFunctor
5858

5959
using is_constant = typename std::false_type;
6060
// constexpr resT constant_value = resT{};
61-
using supports_vec = typename std::true_type;
61+
using supports_vec = typename std::negation<std::is_same<argT, bool>>;
6262
using supports_sg_loadstore = typename std::true_type;
63-
;
6463

6564
resT operator()(const argT &in) const
6665
{
@@ -75,16 +74,7 @@ template <typename argT, typename resT> struct BitwiseInvertFunctor
7574
template <int vec_sz>
7675
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in) const
7776
{
78-
if constexpr (std::is_same_v<argT, bool>) {
79-
auto res_vec = !in;
80-
81-
using deducedT = typename std::remove_cv_t<
82-
std::remove_reference_t<decltype(res_vec)>>::element_type;
83-
return vec_cast<resT, deducedT, vec_sz>(res_vec);
84-
}
85-
else {
86-
return ~in;
87-
}
77+
return ~in;
8878
}
8979
};
9080

dpctl/tests/elementwise/test_bitwise_invert.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,13 @@ def test_bitwise_invert_order():
117117
ar1 = dpt.zeros((40, 40), dtype="i4", order="C")[:20, ::-2].mT
118118
r4 = dpt.bitwise_invert(ar1, order="K")
119119
assert r4.strides == (-1, 20)
120+
121+
122+
def test_bitwise_invert_large_boolean():
123+
get_queue_or_skip()
124+
125+
x = dpt.tril(dpt.ones((32, 32), dtype="?"), k=-1)
126+
res = dpt.astype(dpt.bitwise_invert(x), "i4")
127+
128+
assert dpt.all(res >= 0)
129+
assert dpt.all(res <= 1)

0 commit comments

Comments
 (0)