@@ -59,6 +59,7 @@ def _ravel_check_a_and_weights(a, weights):
59
59
60
60
# ensure that `a` array has supported type
61
61
dpnp .check_supported_arrays_type (a )
62
+ usm_type = a .usm_type
62
63
63
64
# ensure that the array is a "subtractable" dtype
64
65
if a .dtype == dpnp .bool :
@@ -73,6 +74,7 @@ def _ravel_check_a_and_weights(a, weights):
73
74
if weights is not None :
74
75
# check that `weights` array has supported type
75
76
dpnp .check_supported_arrays_type (weights )
77
+ usm_type = dpu .get_coerced_usm_type ([usm_type , weights .usm_type ])
76
78
77
79
# check that arrays have the same allocation queue
78
80
if dpu .get_execution_queue ([a .sycl_queue , weights .sycl_queue ]) is None :
@@ -84,7 +86,7 @@ def _ravel_check_a_and_weights(a, weights):
84
86
raise ValueError ("weights should have the same shape as a." )
85
87
weights = weights .ravel ()
86
88
a = a .ravel ()
87
- return a , weights
89
+ return a , weights , usm_type
88
90
89
91
90
92
def _get_outer_edges (a , range ):
@@ -124,12 +126,13 @@ def _get_outer_edges(a, range):
124
126
return first_edge , last_edge
125
127
126
128
127
- def _get_bin_edges (a , bins , range ):
129
+ def _get_bin_edges (a , bins , range , usm_type ):
128
130
"""Computes the bins used internally by `histogram`."""
129
131
130
132
# parse the overloaded bins argument
131
133
n_equal_bins = None
132
134
bin_edges = None
135
+ sycl_queue = a .sycl_queue
133
136
134
137
if isinstance (bins , str ):
135
138
raise NotImplementedError ("only integer and array bins are implemented" )
@@ -154,7 +157,7 @@ def _get_bin_edges(a, bins, range):
154
157
bin_edges = bins
155
158
else :
156
159
bin_edges = dpnp .asarray (
157
- bins , sycl_queue = a . sycl_queue , usm_type = a . usm_type
160
+ bins , sycl_queue = sycl_queue , usm_type = usm_type
158
161
)
159
162
160
163
if dpnp .any (bin_edges [:- 1 ] > bin_edges [1 :]):
@@ -172,7 +175,7 @@ def _get_bin_edges(a, bins, range):
172
175
bin_type = dpnp .result_type (first_edge , last_edge , a )
173
176
if dpnp .issubdtype (bin_type , dpnp .integer ):
174
177
bin_type = dpnp .result_type (
175
- bin_type , dpnp .default_float_type (sycl_queue = a . sycl_queue ), a
178
+ bin_type , dpnp .default_float_type (sycl_queue = sycl_queue ), a
176
179
)
177
180
178
181
# bin edges must be computed
@@ -182,8 +185,8 @@ def _get_bin_edges(a, bins, range):
182
185
n_equal_bins + 1 ,
183
186
endpoint = True ,
184
187
dtype = bin_type ,
185
- sycl_queue = a . sycl_queue ,
186
- usm_type = a . usm_type ,
188
+ sycl_queue = sycl_queue ,
189
+ usm_type = usm_type ,
187
190
)
188
191
return bin_edges , (first_edge , last_edge , n_equal_bins )
189
192
return bin_edges , None
@@ -285,9 +288,9 @@ def histogram(a, bins=10, range=None, density=None, weights=None):
285
288
286
289
"""
287
290
288
- a , weights = _ravel_check_a_and_weights (a , weights )
291
+ a , weights , usm_type = _ravel_check_a_and_weights (a , weights )
289
292
290
- bin_edges , uniform_bins = _get_bin_edges (a , bins , range )
293
+ bin_edges , uniform_bins = _get_bin_edges (a , bins , range , usm_type )
291
294
292
295
# Histogram is an integer or a float array depending on the weights.
293
296
if weights is None :
@@ -320,7 +323,9 @@ def histogram(a, bins=10, range=None, density=None, weights=None):
320
323
sa = dpnp .sort (a [i : i + block_size ])
321
324
cum_n += _search_sorted_inclusive (sa , bin_edges )
322
325
else :
323
- zero = dpnp .zeros (1 , dtype = ntype )
326
+ zero = dpnp .zeros (
327
+ 1 , dtype = ntype , sycl_queue = a .sycl_queue , usm_type = a .usm_type
328
+ )
324
329
for i in _range (0 , len (a ), block_size ):
325
330
tmp_a = a [i : i + block_size ]
326
331
tmp_w = weights [i : i + block_size ]
0 commit comments