Skip to content

Commit 3955e32

Browse files
authored
Merge branch 'master' into impl-bartlett
2 parents 9434ebd + acd33e4 commit 3955e32

File tree

7 files changed

+233
-95
lines changed

7 files changed

+233
-95
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1515

1616
### Changed
1717

18+
* Allowed input array of `uint64` dtype in `dpnp.bincount` [#2361](https://github.com/IntelPython/dpnp/pull/2361)
19+
1820
### Fixed
1921

2022

dpnp/backend/extensions/statistics/bincount.cpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ struct BincountEdges
7272
template <typename dT>
7373
bool in_bounds(const dT *val, const boundsT &bounds) const
7474
{
75-
return check_in_bounds(val[0], std::get<0>(bounds),
75+
return check_in_bounds(static_cast<T>(val[0]), std::get<0>(bounds),
7676
std::get<1>(bounds));
7777
}
7878

@@ -81,16 +81,17 @@ struct BincountEdges
8181
T max;
8282
};
8383

84-
template <typename T, typename HistType = size_t>
84+
using DefaultHistType = int64_t;
85+
86+
template <typename T, typename HistType = DefaultHistType>
8587
struct BincountF
8688
{
8789
static sycl::event impl(sycl::queue &exec_q,
8890
const void *vin,
89-
const int64_t min,
90-
const int64_t max,
91+
const uint64_t min,
92+
const uint64_t max,
9193
const void *vweights,
9294
void *vout,
93-
const size_t,
9495
const size_t size,
9596
const std::vector<sycl::event> &depends)
9697
{
@@ -145,9 +146,12 @@ struct BincountF
145146
}
146147
};
147148

148-
using SupportedTypes = std::tuple<std::tuple<int64_t, int64_t>,
149+
using SupportedTypes = std::tuple<std::tuple<int64_t, DefaultHistType>,
150+
std::tuple<uint64_t, DefaultHistType>,
149151
std::tuple<int64_t, float>,
150-
std::tuple<int64_t, double>>;
152+
std::tuple<uint64_t, float>,
153+
std::tuple<int64_t, double>,
154+
std::tuple<uint64_t, double>>;
151155

152156
} // namespace
153157

@@ -158,8 +162,8 @@ Bincount::Bincount() : dispatch_table("sample", "histogram")
158162

159163
std::tuple<sycl::event, sycl::event> Bincount::call(
160164
const dpctl::tensor::usm_ndarray &sample,
161-
const int64_t min,
162-
const int64_t max,
165+
const uint64_t min,
166+
const uint64_t max,
163167
const std::optional<const dpctl::tensor::usm_ndarray> &weights,
164168
dpctl::tensor::usm_ndarray &histogram,
165169
const std::vector<sycl::event> &depends)
@@ -182,8 +186,7 @@ std::tuple<sycl::event, sycl::event> Bincount::call(
182186
weights.has_value() ? weights.value().get_data() : nullptr;
183187

184188
auto ev = bincount_func(exec_q, sample.get_data(), min, max, weights_ptr,
185-
histogram.get_data(), histogram.get_shape(0),
186-
sample.get_shape(0), depends);
189+
histogram.get_data(), sample.get_size(), depends);
187190

188191
sycl::event args_ev;
189192
if (weights.has_value()) {

dpnp/backend/extensions/statistics/bincount.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,11 @@ struct Bincount
3939
{
4040
using FnT = sycl::event (*)(sycl::queue &,
4141
const void *,
42-
const int64_t,
43-
const int64_t,
42+
const uint64_t,
43+
const uint64_t,
4444
const void *,
4545
void *,
4646
const size_t,
47-
const size_t,
4847
const std::vector<sycl::event> &);
4948

5049
common::DispatchTable2<FnT> dispatch_table;
@@ -53,8 +52,8 @@ struct Bincount
5352

5453
std::tuple<sycl::event, sycl::event>
5554
call(const dpctl::tensor::usm_ndarray &input,
56-
const int64_t min,
57-
const int64_t max,
55+
const uint64_t min,
56+
const uint64_t max,
5857
const std::optional<const dpctl::tensor::usm_ndarray> &weights,
5958
dpctl::tensor::usm_ndarray &output,
6059
const std::vector<sycl::event> &depends);

dpnp/backend/extensions/statistics/histogram.cpp

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,9 @@ using CachedEdges = HistogramEdges<T, CachedData<const T, 1>>;
106106
template <typename T>
107107
using UncachedEdges = HistogramEdges<T, UncachedData<const T, 1>>;
108108

109-
template <typename T, typename BinsT, typename HistType = size_t>
109+
using DefaultHistType = int64_t;
110+
111+
template <typename T, typename BinsT, typename HistType = DefaultHistType>
110112
struct HistogramF
111113
{
112114
static sycl::event impl(sycl::queue &exec_q,
@@ -186,26 +188,27 @@ using HistogramF_ = HistogramF<SampleType, SampleType, HistType>;
186188

187189
} // namespace
188190

189-
using SupportedTypes = std::tuple<std::tuple<uint64_t, int64_t>,
190-
std::tuple<int64_t, int64_t>,
191-
std::tuple<uint64_t, float>,
192-
std::tuple<int64_t, float>,
193-
std::tuple<uint64_t, double>,
194-
std::tuple<int64_t, double>,
195-
std::tuple<uint64_t, std::complex<float>>,
196-
std::tuple<int64_t, std::complex<float>>,
197-
std::tuple<uint64_t, std::complex<double>>,
198-
std::tuple<int64_t, std::complex<double>>,
199-
std::tuple<float, int64_t>,
200-
std::tuple<double, int64_t>,
201-
std::tuple<float, float>,
202-
std::tuple<double, double>,
203-
std::tuple<float, std::complex<float>>,
204-
std::tuple<double, std::complex<double>>,
205-
std::tuple<std::complex<float>, int64_t>,
206-
std::tuple<std::complex<double>, int64_t>,
207-
std::tuple<std::complex<float>, float>,
208-
std::tuple<std::complex<double>, double>>;
191+
using SupportedTypes =
192+
std::tuple<std::tuple<uint64_t, DefaultHistType>,
193+
std::tuple<int64_t, DefaultHistType>,
194+
std::tuple<uint64_t, float>,
195+
std::tuple<int64_t, float>,
196+
std::tuple<uint64_t, double>,
197+
std::tuple<int64_t, double>,
198+
std::tuple<uint64_t, std::complex<float>>,
199+
std::tuple<int64_t, std::complex<float>>,
200+
std::tuple<uint64_t, std::complex<double>>,
201+
std::tuple<int64_t, std::complex<double>>,
202+
std::tuple<float, DefaultHistType>,
203+
std::tuple<double, DefaultHistType>,
204+
std::tuple<float, float>,
205+
std::tuple<double, double>,
206+
std::tuple<float, std::complex<float>>,
207+
std::tuple<double, std::complex<double>>,
208+
std::tuple<std::complex<float>, DefaultHistType>,
209+
std::tuple<std::complex<double>, DefaultHistType>,
210+
std::tuple<std::complex<float>, float>,
211+
std::tuple<std::complex<double>, double>>;
209212

210213
Histogram::Histogram() : dispatch_table("sample", "histogram")
211214
{

dpnp/dpnp_iface_histograms.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -293,8 +293,8 @@ def _bincount_run_native(
293293

294294
mem_ev, bc_ev = statistics_ext.bincount(
295295
x_usm,
296-
min_v,
297-
max_v,
296+
min_v.item(),
297+
max_v.item(),
298298
weights_usm,
299299
n_usm,
300300
depends=_manager.submitted_events,
@@ -313,6 +313,11 @@ def bincount(x, weights=None, minlength=0):
313313
314314
For full documentation refer to :obj:`numpy.bincount`.
315315
316+
Warning
317+
-------
318+
This function synchronizes in order to calculate binning edges.
319+
This may harm performance in some applications.
320+
316321
Parameters
317322
----------
318323
x : {dpnp.ndarray, usm_ndarray}
@@ -391,10 +396,8 @@ def bincount(x, weights=None, minlength=0):
391396

392397
if x_casted_dtype is None or ntype_casted is None: # pragma: no cover
393398
raise ValueError(
394-
f"function '{bincount}' does not support input types "
395-
f"({x.dtype}, {ntype}), "
396-
"and the inputs could not be coerced to any "
397-
"supported types"
399+
f"Input types ({x.dtype}, {ntype}) are not supported, "
400+
"and the inputs could not be coerced to any supported types"
398401
)
399402

400403
x_casted = dpnp.asarray(x, dtype=x_casted_dtype, order="C")
@@ -508,6 +511,11 @@ def histogram(a, bins=10, range=None, density=None, weights=None):
508511
509512
For full documentation refer to :obj:`numpy.histogram`.
510513
514+
Warning
515+
-------
516+
This function may synchronize in order to check a monotonically increasing
517+
array of bin edges. This may harm performance in some applications.
518+
511519
Parameters
512520
----------
513521
a : {dpnp.ndarray, usm_ndarray}
@@ -611,9 +619,8 @@ def histogram(a, bins=10, range=None, density=None, weights=None):
611619

612620
if a_bin_dtype is None or hist_dtype is None: # pragma: no cover
613621
raise ValueError(
614-
f"function '{histogram}' does not support input types "
615-
f"({a.dtype}, {bin_edges.dtype}, {ntype}), "
616-
"and the inputs could not be coerced to any "
622+
f"Input types ({a.dtype}, {bin_edges.dtype}, {ntype}) "
623+
"are not supported, and the inputs could not be coerced to any "
617624
"supported types"
618625
)
619626

@@ -675,6 +682,11 @@ def histogram_bin_edges(a, bins=10, range=None, weights=None):
675682
676683
For full documentation refer to :obj:`numpy.histogram_bin_edges`.
677684
685+
Warning
686+
-------
687+
This function may synchronize in order to check a monotonically increasing
688+
array of bin edges. This may harm performance in some applications.
689+
678690
Parameters
679691
----------
680692
a : {dpnp.ndarray, usm_ndarray}
@@ -760,6 +772,13 @@ def histogram2d(x, y, bins=10, range=None, density=None, weights=None):
760772
"""
761773
Compute the bi-dimensional histogram of two data samples.
762774
775+
For full documentation refer to :obj:`numpy.histogram2d`.
776+
777+
Warning
778+
-------
779+
This function may synchronize in order to check a monotonically increasing
780+
array of bin edges. This may harm performance in some applications.
781+
763782
Parameters
764783
----------
765784
x : {dpnp.ndarray, usm_ndarray} of shape (N,)
@@ -1088,6 +1107,11 @@ def histogramdd(sample, bins=10, range=None, density=None, weights=None):
10881107
10891108
For full documentation refer to :obj:`numpy.histogramdd`.
10901109
1110+
Warning
1111+
-------
1112+
This function may synchronize in order to check a monotonically increasing
1113+
array of bin edges. This may harm performance in some applications.
1114+
10911115
Parameters
10921116
----------
10931117
sample : {dpnp.ndarray, usm_ndarray}

dpnp/tests/test_histogram.py

Lines changed: 50 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212

1313
from .helper import (
1414
assert_dtype_allclose,
15+
generate_random_numpy_array,
1516
get_abs_array,
1617
get_all_dtypes,
18+
get_complex_dtypes,
19+
get_float_complex_dtypes,
1720
get_float_dtypes,
1821
get_integer_dtypes,
1922
has_support_aspect64,
@@ -532,34 +535,29 @@ def test_range(self, range, dtype):
532535

533536

534537
class TestBincount:
535-
@pytest.mark.parametrize("dtype", get_integer_dtypes())
536-
def test_rand_data(self, dtype):
537-
n = 100
538-
upper_bound = 10 if dtype != dpnp.bool_ else 1
539-
v = numpy.random.randint(0, upper_bound, size=n, dtype=dtype)
538+
@pytest.mark.parametrize("dt", get_integer_dtypes() + [numpy.bool_])
539+
def test_rand_data(self, dt):
540+
v = generate_random_numpy_array(100, dtype=dt, low=0)
540541
iv = dpnp.array(v)
541542

542-
if numpy.issubdtype(dtype, numpy.uint64):
543-
# discussed in numpy issue 17760
544-
assert_raises(TypeError, numpy.bincount, v)
545-
assert_raises(ValueError, dpnp.bincount, iv)
546-
else:
547-
expected_hist = numpy.bincount(v)
548-
result_hist = dpnp.bincount(iv)
549-
assert_array_equal(result_hist, expected_hist)
543+
if numpy.issubdtype(dt, numpy.uint64) and numpy_version() < "2.2.4":
544+
v = v.astype(numpy.int64)
550545

551-
@pytest.mark.parametrize("dtype", get_integer_dtypes())
552-
def test_arange_data(self, dtype):
553-
v = numpy.arange(100).astype(dtype)
546+
expected_hist = numpy.bincount(v)
547+
result_hist = dpnp.bincount(iv)
548+
assert_array_equal(result_hist, expected_hist)
549+
550+
@pytest.mark.parametrize("dt", get_integer_dtypes())
551+
def test_arange_data(self, dt):
552+
v = numpy.arange(100, dtype=dt)
554553
iv = dpnp.array(v)
555554

556-
if numpy.issubdtype(dtype, numpy.uint64):
557-
assert_raises(TypeError, numpy.bincount, v)
558-
assert_raises(ValueError, dpnp.bincount, iv)
559-
else:
560-
expected_hist = numpy.bincount(v)
561-
result_hist = dpnp.bincount(iv)
562-
assert_array_equal(result_hist, expected_hist)
555+
if numpy.issubdtype(dt, numpy.uint64) and numpy_version() < "2.2.4":
556+
v = v.astype(numpy.int64)
557+
558+
expected_hist = numpy.bincount(v)
559+
result_hist = dpnp.bincount(iv)
560+
assert_array_equal(result_hist, expected_hist)
563561

564562
@pytest.mark.parametrize("xp", [numpy, dpnp])
565563
def test_negative_values(self, xp):
@@ -581,11 +579,17 @@ def test_weights_another_sycl_queue(self):
581579
dpnp.bincount(v, weights=w)
582580

583581
@pytest.mark.parametrize("xp", [numpy, dpnp])
584-
def test_weights_unsupported_dtype(self, xp):
585-
v = dpnp.arange(5)
586-
w = dpnp.arange(5, dtype=dpnp.complex64)
587-
with assert_raises(ValueError):
588-
dpnp.bincount(v, weights=w)
582+
@pytest.mark.parametrize("dt", get_float_complex_dtypes())
583+
def test_data_unsupported_dtype(self, xp, dt):
584+
v = xp.arange(5, dtype=dt)
585+
assert_raises(TypeError, xp.bincount, v)
586+
587+
@pytest.mark.parametrize("xp", [numpy, dpnp])
588+
@pytest.mark.parametrize("dt", get_complex_dtypes())
589+
def test_weights_unsupported_dtype(self, xp, dt):
590+
v = xp.arange(5)
591+
w = xp.arange(5, dtype=dt)
592+
assert_raises((TypeError, ValueError), xp.bincount, v, weights=w)
589593

590594
@pytest.mark.parametrize(
591595
"bins_count",
@@ -606,11 +610,11 @@ def test_different_bins_amount(self, bins_count):
606610
)
607611
@pytest.mark.parametrize("minlength", [0, 1, 3, 5])
608612
def test_minlength(self, array, minlength):
609-
np_a = numpy.array(array)
610-
dpnp_a = dpnp.array(array)
613+
a = numpy.array(array)
614+
ia = dpnp.array(a)
611615

612-
expected = numpy.bincount(np_a, minlength=minlength)
613-
result = dpnp.bincount(dpnp_a, minlength=minlength)
616+
expected = numpy.bincount(a, minlength=minlength)
617+
result = dpnp.bincount(ia, minlength=minlength)
614618
assert_allclose(result, expected)
615619

616620
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
@@ -639,21 +643,23 @@ def test_minlength_none(self, xp):
639643
)
640644

641645
@pytest.mark.parametrize(
642-
"array", [[1, 2, 2, 1, 2, 4]], ids=["[1, 2, 2, 1, 2, 4]"]
646+
"weights",
647+
[None, [0.3, 0.5, 0, 0.7, 1.0, -0.6], [2, 2, 2, 2, 2, 2]],
648+
ids=["None", "float_data", "int_data"],
643649
)
644650
@pytest.mark.parametrize(
645-
"weights",
646-
[None, [0.3, 0.5, 0.2, 0.7, 1.0, -0.6], [2, 2, 2, 2, 2, 2]],
647-
ids=["None", "[0.3, 0.5, 0.2, 0.7, 1., -0.6]", "[2, 2, 2, 2, 2, 2]"],
651+
"dt", get_all_dtypes(no_none=True, no_complex=True)
648652
)
649-
def test_weights(self, array, weights):
650-
np_a = numpy.array(array)
651-
np_weights = numpy.array(weights) if weights is not None else weights
652-
dpnp_a = dpnp.array(array)
653-
dpnp_weights = dpnp.array(weights) if weights is not None else weights
654-
655-
expected = numpy.bincount(np_a, weights=np_weights)
656-
result = dpnp.bincount(dpnp_a, weights=dpnp_weights)
653+
def test_weights(self, weights, dt):
654+
a = numpy.array([1, 2, 2, 1, 2, 4])
655+
ia = dpnp.array(a)
656+
w = iw = None
657+
if weights is not None:
658+
w = numpy.array(weights, dtype=dt)
659+
iw = dpnp.array(w)
660+
661+
expected = numpy.bincount(a, weights=w)
662+
result = dpnp.bincount(ia, weights=iw)
657663
assert_allclose(result, expected)
658664

659665
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)