|
47 | 47 | import dpnp
|
48 | 48 |
|
49 | 49 | # pylint: disable=no-name-in-module
|
50 |
| -from .dpnp_algo import ( |
51 |
| - dpnp_correlate, |
52 |
| -) |
| 50 | +from .dpnp_algo import dpnp_correlate |
53 | 51 | from .dpnp_array import dpnp_array
|
54 |
| -from .dpnp_utils import ( |
55 |
| - call_origin, |
56 |
| - get_usm_allocations, |
57 |
| -) |
| 52 | +from .dpnp_utils import call_origin, get_usm_allocations |
58 | 53 | from .dpnp_utils.dpnp_utils_reduction import dpnp_wrap_reduction_call
|
59 |
| -from .dpnp_utils.dpnp_utils_statistics import ( |
60 |
| - dpnp_cov, |
61 |
| -) |
| 54 | +from .dpnp_utils.dpnp_utils_statistics import dpnp_cov |
62 | 55 |
|
63 | 56 | __all__ = [
|
64 | 57 | "amax",
|
@@ -276,60 +269,61 @@ def average(a, axis=None, weights=None, returned=False, *, keepdims=False):
|
276 | 269 | """
|
277 | 270 |
|
278 | 271 | dpnp.check_supported_arrays_type(a)
|
| 272 | + usm_type, exec_q = get_usm_allocations([a, weights]) |
| 273 | + |
279 | 274 | if weights is None:
|
280 | 275 | avg = dpnp.mean(a, axis=axis, keepdims=keepdims)
|
281 | 276 | scl = dpnp.asanyarray(
|
282 | 277 | avg.dtype.type(a.size / avg.size),
|
283 |
| - usm_type=a.usm_type, |
284 |
| - sycl_queue=a.sycl_queue, |
| 278 | + usm_type=usm_type, |
| 279 | + sycl_queue=exec_q, |
285 | 280 | )
|
286 | 281 | else:
|
287 |
| - if not isinstance(weights, (dpnp_array, dpt.usm_ndarray)): |
288 |
| - wgt = dpnp.asanyarray( |
289 |
| - weights, usm_type=a.usm_type, sycl_queue=a.sycl_queue |
| 282 | + if not dpnp.is_supported_array_type(weights): |
| 283 | + weights = dpnp.asarray( |
| 284 | + weights, usm_type=usm_type, sycl_queue=exec_q |
290 | 285 | )
|
291 |
| - else: |
292 |
| - get_usm_allocations([a, weights]) |
293 |
| - wgt = weights |
294 | 286 |
|
295 |
| - if not dpnp.issubdtype(a.dtype, dpnp.inexact): |
| 287 | + a_dtype = a.dtype |
| 288 | + if not dpnp.issubdtype(a_dtype, dpnp.inexact): |
296 | 289 | default_dtype = dpnp.default_float_type(a.device)
|
297 |
| - result_dtype = dpnp.result_type(a.dtype, wgt.dtype, default_dtype) |
| 290 | + res_dtype = dpnp.result_type(a_dtype, weights.dtype, default_dtype) |
298 | 291 | else:
|
299 |
| - result_dtype = dpnp.result_type(a.dtype, wgt.dtype) |
| 292 | + res_dtype = dpnp.result_type(a_dtype, weights.dtype) |
300 | 293 |
|
301 | 294 | # Sanity checks
|
302 |
| - if a.shape != wgt.shape: |
| 295 | + wgt_shape = weights.shape |
| 296 | + a_shape = a.shape |
| 297 | + if a_shape != wgt_shape: |
303 | 298 | if axis is None:
|
304 | 299 | raise TypeError(
|
305 | 300 | "Axis must be specified when shapes of input array and "
|
306 | 301 | "weights differ."
|
307 | 302 | )
|
308 |
| - if wgt.ndim != 1: |
| 303 | + if weights.ndim != 1: |
309 | 304 | raise TypeError(
|
310 | 305 | "1D weights expected when shapes of input array and "
|
311 | 306 | "weights differ."
|
312 | 307 | )
|
313 |
| - if wgt.shape[0] != a.shape[axis]: |
| 308 | + if wgt_shape[0] != a_shape[axis]: |
314 | 309 | raise ValueError(
|
315 | 310 | "Length of weights not compatible with specified axis."
|
316 | 311 | )
|
317 | 312 |
|
318 |
| - # setup wgt to broadcast along axis |
319 |
| - wgt = dpnp.broadcast_to(wgt, (a.ndim - 1) * (1,) + wgt.shape) |
320 |
| - wgt = wgt.swapaxes(-1, axis) |
| 313 | + # setup weights to broadcast along axis |
| 314 | + weights = dpnp.broadcast_to( |
| 315 | + weights, (a.ndim - 1) * (1,) + wgt_shape |
| 316 | + ) |
| 317 | + weights = weights.swapaxes(-1, axis) |
321 | 318 |
|
322 |
| - scl = wgt.sum(axis=axis, dtype=result_dtype, keepdims=keepdims) |
| 319 | + scl = weights.sum(axis=axis, dtype=res_dtype, keepdims=keepdims) |
323 | 320 | if dpnp.any(scl == 0.0):
|
324 | 321 | raise ZeroDivisionError("Weights sum to zero, can't be normalized")
|
325 | 322 |
|
326 |
| - # result_datatype |
327 |
| - avg = ( |
328 |
| - dpnp.multiply(a, wgt).sum( |
329 |
| - axis=axis, dtype=result_dtype, keepdims=keepdims |
330 |
| - ) |
331 |
| - / scl |
| 323 | + avg = dpnp.multiply(a, weights).sum( |
| 324 | + axis=axis, dtype=res_dtype, keepdims=keepdims |
332 | 325 | )
|
| 326 | + avg /= scl |
333 | 327 |
|
334 | 328 | if returned:
|
335 | 329 | if scl.shape != avg.shape:
|
@@ -556,7 +550,7 @@ def cov(
|
556 | 550 |
|
557 | 551 | """
|
558 | 552 |
|
559 |
| - if not isinstance(m, (dpnp_array, dpt.usm_ndarray)): |
| 553 | + if not dpnp.is_supported_array_type(m): |
560 | 554 | pass
|
561 | 555 | elif m.ndim > 2:
|
562 | 556 | pass
|
|
0 commit comments