27
27
28
28
#include < sycl/sycl.hpp>
29
29
30
+ #include " dpctl4pybind11.hpp"
31
+
32
+ // dpctl tensor headers
33
+ #include " utils/type_utils.hpp"
34
+
30
35
namespace dpnp ::kernels::bitwise_count
31
36
{
37
+ namespace tu_ns = dpctl::tensor::type_utils;
38
+
32
39
template <typename argT, typename resT>
33
40
struct BitwiseCountFunctor
34
41
{
@@ -37,7 +44,7 @@ struct BitwiseCountFunctor
37
44
// constant value, if constant
38
45
// constexpr resT constant_value = resT{};
39
46
// is function defined for sycl::vec
40
- using supports_vec = typename std::false_type ;
47
+ using supports_vec = typename std::true_type ;
41
48
// do both argT and resT support subgroup store/load operation
42
49
using supports_sg_loadstore = typename std::true_type;
43
50
@@ -50,5 +57,36 @@ struct BitwiseCountFunctor
50
57
return sycl::popcount (sycl::abs (x));
51
58
}
52
59
}
60
+
61
+ template <int vec_sz>
62
+ sycl::vec<resT, vec_sz> operator ()(const sycl::vec<argT, vec_sz> &x) const
63
+ {
64
+ if constexpr (std::is_unsigned_v<argT>) {
65
+ auto const &res_vec = sycl::popcount (x);
66
+
67
+ using deducedT = typename std::remove_cv_t <
68
+ std::remove_reference_t <decltype (res_vec)>>::element_type;
69
+
70
+ if constexpr (std::is_same_v<resT, deducedT>) {
71
+ return res_vec;
72
+ }
73
+ else {
74
+ return tu_ns::vec_cast<std::uint8_t , deducedT, vec_sz>(res_vec);
75
+ }
76
+ }
77
+ else {
78
+ auto const &res_vec = sycl::popcount (x);
79
+
80
+ using deducedT = typename std::remove_cv_t <
81
+ std::remove_reference_t <decltype (res_vec)>>::element_type;
82
+
83
+ if constexpr (std::is_same_v<resT, deducedT>) {
84
+ return res_vec;
85
+ }
86
+ else {
87
+ return tu_ns::vec_cast<std::uint8_t , deducedT, vec_sz>(res_vec);
88
+ }
89
+ }
90
+ }
53
91
};
54
92
} // namespace dpnp::kernels::bitwise_count
0 commit comments