|
19 | 19 | import dpctl
|
20 | 20 | import dpctl.tensor as dpt
|
21 | 21 | import dpctl.tensor._tensor_impl as ti
|
| 22 | +import dpctl.tensor._tensor_reductions_impl as tri |
22 | 23 |
|
23 | 24 | from ._type_utils import _to_device_supported_dtype
|
24 | 25 |
|
@@ -220,8 +221,8 @@ def sum(x, axis=None, dtype=None, keepdims=False):
|
220 | 221 | axis,
|
221 | 222 | dtype,
|
222 | 223 | keepdims,
|
223 |
| - ti._sum_over_axis, |
224 |
| - ti._sum_over_axis_dtype_supported, |
| 224 | + tri._sum_over_axis, |
| 225 | + tri._sum_over_axis_dtype_supported, |
225 | 226 | _default_reduction_dtype,
|
226 | 227 | _identity=0,
|
227 | 228 | )
|
@@ -281,8 +282,8 @@ def prod(x, axis=None, dtype=None, keepdims=False):
|
281 | 282 | axis,
|
282 | 283 | dtype,
|
283 | 284 | keepdims,
|
284 |
| - ti._prod_over_axis, |
285 |
| - ti._prod_over_axis_dtype_supported, |
| 285 | + tri._prod_over_axis, |
| 286 | + tri._prod_over_axis_dtype_supported, |
286 | 287 | _default_reduction_dtype,
|
287 | 288 | _identity=1,
|
288 | 289 | )
|
@@ -335,8 +336,8 @@ def logsumexp(x, axis=None, dtype=None, keepdims=False):
|
335 | 336 | axis,
|
336 | 337 | dtype,
|
337 | 338 | keepdims,
|
338 |
| - ti._logsumexp_over_axis, |
339 |
| - lambda inp_dt, res_dt, *_: ti._logsumexp_over_axis_dtype_supported( |
| 339 | + tri._logsumexp_over_axis, |
| 340 | + lambda inp_dt, res_dt, *_: tri._logsumexp_over_axis_dtype_supported( |
340 | 341 | inp_dt, res_dt
|
341 | 342 | ),
|
342 | 343 | _default_reduction_dtype_fp_types,
|
@@ -391,8 +392,8 @@ def reduce_hypot(x, axis=None, dtype=None, keepdims=False):
|
391 | 392 | axis,
|
392 | 393 | dtype,
|
393 | 394 | keepdims,
|
394 |
| - ti._hypot_over_axis, |
395 |
| - lambda inp_dt, res_dt, *_: ti._hypot_over_axis_dtype_supported( |
| 395 | + tri._hypot_over_axis, |
| 396 | + lambda inp_dt, res_dt, *_: tri._hypot_over_axis_dtype_supported( |
396 | 397 | inp_dt, res_dt
|
397 | 398 | ),
|
398 | 399 | _default_reduction_dtype_fp_types,
|
@@ -468,7 +469,7 @@ def max(x, axis=None, keepdims=False):
|
468 | 469 | entire array, a zero-dimensional array is returned. The returned
|
469 | 470 | array has the same data type as `x`.
|
470 | 471 | """
|
471 |
| - return _comparison_over_axis(x, axis, keepdims, ti._max_over_axis) |
| 472 | + return _comparison_over_axis(x, axis, keepdims, tri._max_over_axis) |
472 | 473 |
|
473 | 474 |
|
474 | 475 | def min(x, axis=None, keepdims=False):
|
@@ -496,7 +497,7 @@ def min(x, axis=None, keepdims=False):
|
496 | 497 | entire array, a zero-dimensional array is returned. The returned
|
497 | 498 | array has the same data type as `x`.
|
498 | 499 | """
|
499 |
| - return _comparison_over_axis(x, axis, keepdims, ti._min_over_axis) |
| 500 | + return _comparison_over_axis(x, axis, keepdims, tri._min_over_axis) |
500 | 501 |
|
501 | 502 |
|
502 | 503 | def _search_over_axis(x, axis, keepdims, _reduction_fn):
|
@@ -577,7 +578,7 @@ def argmax(x, axis=None, keepdims=False):
|
577 | 578 | zero-dimensional array is returned. The returned array has the
|
578 | 579 | default array index data type for the device of `x`.
|
579 | 580 | """
|
580 |
| - return _search_over_axis(x, axis, keepdims, ti._argmax_over_axis) |
| 581 | + return _search_over_axis(x, axis, keepdims, tri._argmax_over_axis) |
581 | 582 |
|
582 | 583 |
|
583 | 584 | def argmin(x, axis=None, keepdims=False):
|
@@ -609,4 +610,4 @@ def argmin(x, axis=None, keepdims=False):
|
609 | 610 | zero-dimensional array is returned. The returned array has the
|
610 | 611 | default array index data type for the device of `x`.
|
611 | 612 | """
|
612 |
| - return _search_over_axis(x, axis, keepdims, ti._argmin_over_axis) |
| 613 | + return _search_over_axis(x, axis, keepdims, tri._argmin_over_axis) |
0 commit comments