Skip to content

Commit 233e4fc

Browse files
committed
fix a few issues
1 parent f6993d0 commit 233e4fc

File tree

2 files changed

+20
-26
lines changed

2 files changed

+20
-26
lines changed

.github/workflows/conda-package.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ env:
1313
# TODO: to add test_arraymanipulation.py back to the scope once crash on Windows is gone
1414
TEST_SCOPE: >-
1515
test_arraycreation.py
16+
test_amin_amax.py
1617
test_dot.py
1718
test_dparray.py
1819
test_copy.py

dpnp/dpnp_iface_statistics.py

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -363,9 +363,9 @@ def max(a, axis=None, out=None, keepdims=False, initial=None, where=True):
363363
364364
Limitations
365365
-----------
366-
Input array `a` is only supported as either :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
367-
Parameters `out`, `where`, and `initial` are supported only with their default values.
368-
Otherwise the function will be executed sequentially on CPU.
366+
Input and output arrays are only supported as either :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
367+
Parameters `where`, and `initial` are supported only with their default values.
368+
Otherwise ``NotImplementedError`` exception will be raised.
369369
Input array data types are limited by supported DPNP :ref:`Data types`.
370370
371371
See Also
@@ -400,16 +400,17 @@ def max(a, axis=None, out=None, keepdims=False, initial=None, where=True):
400400

401401
if initial is not None:
402402
raise NotImplementedError(
403-
"initial keyword arguemnts is only supported by its default value."
403+
"initial keyword arguemnt is only supported by its default value."
404404
)
405405
elif where is not True:
406406
raise NotImplementedError(
407-
"where keyword arguemnts is only supported by its default values."
407+
"where keyword arguemnt is only supported by its default value."
408408
)
409409
else:
410410
dpt_array = dpnp.get_usm_ndarray(a)
411411
if dpt_array.size == 0:
412412
# TODO: get rid of this if condition when dpctl supports it
413+
axis = (axis,) if isinstance(axis, int) else axis
413414
for i in range(a.ndim):
414415
if a.shape[i] == 0:
415416
if i not in axis:
@@ -431,21 +432,17 @@ def max(a, axis=None, out=None, keepdims=False, initial=None, where=True):
431432
raise ValueError(
432433
f"Output array of shape {result.shape} is needed, got {out.shape}."
433434
)
434-
elif out.dtype != result.dtype:
435-
raise ValueError(
436-
f"Output array of type {result.dtype} is needed, got {out.dtype}."
437-
)
438435
elif not isinstance(out, dpnp_array):
439436
if isinstance(out, dpt.usm_ndarray):
440-
out = dpnp.array(out)
437+
out = dpnp_array._create_from_usm_ndarray(out)
441438
else:
442-
raise ValueError(
443-
"An array must be any of supported type, but got {}".format(
439+
raise TypeError(
440+
"Output array must be any of supported type, but got {}".format(
444441
type(out)
445442
)
446443
)
447444

448-
dpnp.copyto(out, result)
445+
dpnp.copyto(out, result, casting="safe")
449446

450447
return out
451448

@@ -610,9 +607,9 @@ def min(a, axis=None, out=None, keepdims=False, initial=None, where=True):
610607
611608
Limitations
612609
-----------
613-
Input array `a` is only supported as either :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
614-
Parameters `out`, `where`, and `initial` are supported only with their default values.
615-
Otherwise the function will be executed sequentially on CPU.
610+
Input and output arrays are only supported as either :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
611+
Parameters `where`, and `initial` are supported only with their default values.
612+
Otherwise ``NotImplementedError`` exception will be raised.
616613
Input array data types are limited by supported DPNP :ref:`Data types`.
617614
618615
See Also
@@ -647,11 +644,11 @@ def min(a, axis=None, out=None, keepdims=False, initial=None, where=True):
647644

648645
if initial is not None:
649646
raise NotImplementedError(
650-
"initial keyword arguemnts is only supported by its default value."
647+
"initial keyword arguemnt is only supported by its default value."
651648
)
652649
elif where is not True:
653650
raise NotImplementedError(
654-
"where keyword arguemnts is only supported by its default values."
651+
"where keyword arguemnt is only supported by its default values."
655652
)
656653
else:
657654
dpt_array = dpnp.get_usm_ndarray(a)
@@ -678,21 +675,17 @@ def min(a, axis=None, out=None, keepdims=False, initial=None, where=True):
678675
raise ValueError(
679676
f"Output array of shape {result.shape} is needed, got {out.shape}."
680677
)
681-
elif out.dtype != result.dtype:
682-
raise ValueError(
683-
f"Output array of type {result.dtype} is needed, got {out.dtype}."
684-
)
685678
elif not isinstance(out, dpnp_array):
686679
if isinstance(out, dpt.usm_ndarray):
687-
out = dpnp.array(out)
680+
out = dpnp_array._create_from_usm_ndarray(out)
688681
else:
689-
raise ValueError(
690-
"An array must be any of supported type, but got {}".format(
682+
raise TypeError(
683+
"Output array must be any of supported type, but got {}".format(
691684
type(out)
692685
)
693686
)
694687

695-
dpnp.copyto(out, result)
688+
dpnp.copyto(out, result, casting="safe")
696689

697690
return out
698691

0 commit comments

Comments
 (0)