@@ -298,16 +298,12 @@ def histogram(a, bins=10, range=None, density=None, weights=None):
298
298
else :
299
299
ntype = weights .dtype
300
300
301
- # We set a block size, as this allows us to iterate over chunks when
302
- # computing histograms, to minimize memory usage.
303
- block_size = 65536
304
-
305
301
# The fast path uses bincount, but that only works for certain types
306
302
# of weight
307
303
# simple_weights = (
308
304
# weights is None or
309
- # np .can_cast(weights.dtype, np .double) or
310
- # np .can_cast(weights.dtype, complex)
305
+ # dpnp .can_cast(weights.dtype, dpnp .double) or
306
+ # dpnp .can_cast(weights.dtype, complex)
311
307
# )
312
308
# TODO: implement a fast path
313
309
simple_weights = False
@@ -317,24 +313,19 @@ def histogram(a, bins=10, range=None, density=None, weights=None):
317
313
pass
318
314
else :
319
315
# Compute via cumulative histogram
320
- cum_n = dpnp .zeros_like (bin_edges , dtype = ntype )
321
316
if weights is None :
322
- for i in _range (0 , len (a ), block_size ):
323
- sa = dpnp .sort (a [i : i + block_size ])
324
- cum_n += _search_sorted_inclusive (sa , bin_edges )
317
+ sa = dpnp .sort (a )
318
+ cum_n = _search_sorted_inclusive (sa , bin_edges )
325
319
else :
326
320
zero = dpnp .zeros (
327
321
1 , dtype = ntype , sycl_queue = a .sycl_queue , usm_type = a .usm_type
328
322
)
329
- for i in _range (0 , len (a ), block_size ):
330
- tmp_a = a [i : i + block_size ]
331
- tmp_w = weights [i : i + block_size ]
332
- sorting_index = dpnp .argsort (tmp_a )
333
- sa = tmp_a [sorting_index ]
334
- sw = tmp_w [sorting_index ]
335
- cw = dpnp .concatenate ((zero , sw .cumsum (dtype = ntype )))
336
- bin_index = _search_sorted_inclusive (sa , bin_edges )
337
- cum_n += cw [bin_index ]
323
+ sorting_index = dpnp .argsort (a )
324
+ sa = a [sorting_index ]
325
+ sw = weights [sorting_index ]
326
+ cw = dpnp .concatenate ((zero , sw .cumsum (dtype = ntype )))
327
+ bin_index = _search_sorted_inclusive (sa , bin_edges )
328
+ cum_n = cw [bin_index ]
338
329
339
330
n = dpnp .diff (cum_n )
340
331
0 commit comments