Skip to content

Commit 220cf9a

Browse files
committed
add vector implementation of the kernel
1 parent b5fc64a commit 220cf9a

File tree

1 file changed

+39
-1
lines changed

1 file changed

+39
-1
lines changed

dpnp/backend/kernels/elementwise_functions/bitwise_count.hpp

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,15 @@
2727

2828
#include <sycl/sycl.hpp>
2929

30+
#include "dpctl4pybind11.hpp"
31+
32+
// dpctl tensor headers
33+
#include "utils/type_utils.hpp"
34+
3035
namespace dpnp::kernels::bitwise_count
3136
{
37+
namespace tu_ns = dpctl::tensor::type_utils;
38+
3239
template <typename argT, typename resT>
3340
struct BitwiseCountFunctor
3441
{
@@ -37,7 +44,7 @@ struct BitwiseCountFunctor
3744
// constant value, if constant
3845
// constexpr resT constant_value = resT{};
3946
// is function defined for sycl::vec
40-
using supports_vec = typename std::false_type;
47+
using supports_vec = typename std::true_type;
4148
// do both argT and resT support subgroup store/load operation
4249
using supports_sg_loadstore = typename std::true_type;
4350

@@ -50,5 +57,36 @@ struct BitwiseCountFunctor
5057
return sycl::popcount(sycl::abs(x));
5158
}
5259
}
60+
61+
template <int vec_sz>
62+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &x) const
63+
{
64+
if constexpr (std::is_unsigned_v<argT>) {
65+
auto const &res_vec = sycl::popcount(x);
66+
67+
using deducedT = typename std::remove_cv_t<
68+
std::remove_reference_t<decltype(res_vec)>>::element_type;
69+
70+
if constexpr (std::is_same_v<resT, deducedT>) {
71+
return res_vec;
72+
}
73+
else {
74+
return tu_ns::vec_cast<std::uint8_t, deducedT, vec_sz>(res_vec);
75+
}
76+
}
77+
else {
78+
auto const &res_vec = sycl::popcount(x);
79+
80+
using deducedT = typename std::remove_cv_t<
81+
std::remove_reference_t<decltype(res_vec)>>::element_type;
82+
83+
if constexpr (std::is_same_v<resT, deducedT>) {
84+
return res_vec;
85+
}
86+
else {
87+
return tu_ns::vec_cast<std::uint8_t, deducedT, vec_sz>(res_vec);
88+
}
89+
}
90+
}
5391
};
5492
} // namespace dpnp::kernels::bitwise_count

0 commit comments

Comments
 (0)