@@ -72,7 +72,7 @@ struct BincountEdges
72
72
template <typename dT>
73
73
bool in_bounds (const dT *val, const boundsT &bounds) const
74
74
{
75
- return check_in_bounds (val[0 ], std::get<0 >(bounds),
75
+ return check_in_bounds (static_cast <T>( val[0 ]) , std::get<0 >(bounds),
76
76
std::get<1 >(bounds));
77
77
}
78
78
@@ -81,13 +81,15 @@ struct BincountEdges
81
81
T max;
82
82
};
83
83
84
- template <typename T, typename HistType = size_t >
84
+ using DefaultHistType = int64_t ;
85
+
86
+ template <typename T, typename HistType = DefaultHistType>
85
87
struct BincountF
86
88
{
87
89
static sycl::event impl (sycl::queue &exec_q,
88
90
const void *vin,
89
- const int64_t min,
90
- const int64_t max,
91
+ const uint64_t min,
92
+ const uint64_t max,
91
93
const void *vweights,
92
94
void *vout,
93
95
const size_t ,
@@ -145,9 +147,12 @@ struct BincountF
145
147
}
146
148
};
147
149
148
- using SupportedTypes = std::tuple<std::tuple<int64_t , int64_t >,
150
+ using SupportedTypes = std::tuple<std::tuple<int64_t , DefaultHistType>,
151
+ std::tuple<uint64_t , DefaultHistType>,
149
152
std::tuple<int64_t , float >,
150
- std::tuple<int64_t , double >>;
153
+ std::tuple<uint64_t , float >,
154
+ std::tuple<int64_t , double >,
155
+ std::tuple<uint64_t , double >>;
151
156
152
157
} // namespace
153
158
@@ -158,8 +163,8 @@ Bincount::Bincount() : dispatch_table("sample", "histogram")
158
163
159
164
std::tuple<sycl::event, sycl::event> Bincount::call (
160
165
const dpctl::tensor::usm_ndarray &sample,
161
- const int64_t min,
162
- const int64_t max,
166
+ const uint64_t min,
167
+ const uint64_t max,
163
168
const std::optional<const dpctl::tensor::usm_ndarray> &weights,
164
169
dpctl::tensor::usm_ndarray &histogram,
165
170
const std::vector<sycl::event> &depends)
0 commit comments