Skip to content

Commit 3b0fa4c

Browse files
committed
Add uint64 as supported sample type
1 parent b91e43b commit 3b0fa4c

File tree

2 files changed

+17
-12
lines changed

2 files changed

+17
-12
lines changed

dpnp/backend/extensions/statistics/bincount.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ struct BincountEdges
7272
template <typename dT>
7373
bool in_bounds(const dT *val, const boundsT &bounds) const
7474
{
75-
return check_in_bounds(val[0], std::get<0>(bounds),
75+
return check_in_bounds(static_cast<T>(val[0]), std::get<0>(bounds),
7676
std::get<1>(bounds));
7777
}
7878

@@ -81,13 +81,15 @@ struct BincountEdges
8181
T max;
8282
};
8383

84-
template <typename T, typename HistType = size_t>
84+
using DefaultHistType = int64_t;
85+
86+
template <typename T, typename HistType = DefaultHistType>
8587
struct BincountF
8688
{
8789
static sycl::event impl(sycl::queue &exec_q,
8890
const void *vin,
89-
const int64_t min,
90-
const int64_t max,
91+
const uint64_t min,
92+
const uint64_t max,
9193
const void *vweights,
9294
void *vout,
9395
const size_t,
@@ -145,9 +147,12 @@ struct BincountF
145147
}
146148
};
147149

148-
using SupportedTypes = std::tuple<std::tuple<int64_t, int64_t>,
150+
using SupportedTypes = std::tuple<std::tuple<int64_t, DefaultHistType>,
151+
std::tuple<uint64_t, DefaultHistType>,
149152
std::tuple<int64_t, float>,
150-
std::tuple<int64_t, double>>;
153+
std::tuple<uint64_t, float>,
154+
std::tuple<int64_t, double>,
155+
std::tuple<uint64_t, double>>;
151156

152157
} // namespace
153158

@@ -158,8 +163,8 @@ Bincount::Bincount() : dispatch_table("sample", "histogram")
158163

159164
std::tuple<sycl::event, sycl::event> Bincount::call(
160165
const dpctl::tensor::usm_ndarray &sample,
161-
const int64_t min,
162-
const int64_t max,
166+
const uint64_t min,
167+
const uint64_t max,
163168
const std::optional<const dpctl::tensor::usm_ndarray> &weights,
164169
dpctl::tensor::usm_ndarray &histogram,
165170
const std::vector<sycl::event> &depends)

dpnp/backend/extensions/statistics/bincount.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ struct Bincount
3939
{
4040
using FnT = sycl::event (*)(sycl::queue &,
4141
const void *,
42-
const int64_t,
43-
const int64_t,
42+
const uint64_t,
43+
const uint64_t,
4444
const void *,
4545
void *,
4646
const size_t,
@@ -53,8 +53,8 @@ struct Bincount
5353

5454
std::tuple<sycl::event, sycl::event>
5555
call(const dpctl::tensor::usm_ndarray &input,
56-
const int64_t min,
57-
const int64_t max,
56+
const uint64_t min,
57+
const uint64_t max,
5858
const std::optional<const dpctl::tensor::usm_ndarray> &weights,
5959
dpctl::tensor::usm_ndarray &output,
6060
const std::vector<sycl::event> &depends);

0 commit comments

Comments
 (0)