@@ -363,9 +363,9 @@ def max(a, axis=None, out=None, keepdims=False, initial=None, where=True):
363
363
364
364
Limitations
365
365
-----------
366
- Input array `a` is only supported as either :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
367
- Parameters `out`, ` where`, and `initial` are supported only with their default values.
368
- Otherwise the function will be executed sequentially on CPU .
366
+ Input and output arrays are only supported as either :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
367
+ Parameters `where`, and `initial` are supported only with their default values.
368
+ Otherwise ``NotImplementedError`` exception will be raised .
369
369
Input array data types are limited by supported DPNP :ref:`Data types`.
370
370
371
371
See Also
@@ -400,16 +400,17 @@ def max(a, axis=None, out=None, keepdims=False, initial=None, where=True):
400
400
401
401
if initial is not None :
402
402
raise NotImplementedError (
403
- "initial keyword arguemnts is only supported by its default value."
403
+ "initial keyword arguemnt is only supported by its default value."
404
404
)
405
405
elif where is not True :
406
406
raise NotImplementedError (
407
- "where keyword arguemnts is only supported by its default values ."
407
+ "where keyword arguemnt is only supported by its default value ."
408
408
)
409
409
else :
410
410
dpt_array = dpnp .get_usm_ndarray (a )
411
411
if dpt_array .size == 0 :
412
412
# TODO: get rid of this if condition when dpctl supports it
413
+ axis = (axis ,) if isinstance (axis , int ) else axis
413
414
for i in range (a .ndim ):
414
415
if a .shape [i ] == 0 :
415
416
if i not in axis :
@@ -431,21 +432,17 @@ def max(a, axis=None, out=None, keepdims=False, initial=None, where=True):
431
432
raise ValueError (
432
433
f"Output array of shape { result .shape } is needed, got { out .shape } ."
433
434
)
434
- elif out .dtype != result .dtype :
435
- raise ValueError (
436
- f"Output array of type { result .dtype } is needed, got { out .dtype } ."
437
- )
438
435
elif not isinstance (out , dpnp_array ):
439
436
if isinstance (out , dpt .usm_ndarray ):
440
- out = dpnp . array (out )
437
+ out = dpnp_array . _create_from_usm_ndarray (out )
441
438
else :
442
- raise ValueError (
443
- "An array must be any of supported type, but got {}" .format (
439
+ raise TypeError (
440
+ "Output array must be any of supported type, but got {}" .format (
444
441
type (out )
445
442
)
446
443
)
447
444
448
- dpnp .copyto (out , result )
445
+ dpnp .copyto (out , result , casting = "safe" )
449
446
450
447
return out
451
448
@@ -610,9 +607,9 @@ def min(a, axis=None, out=None, keepdims=False, initial=None, where=True):
610
607
611
608
Limitations
612
609
-----------
613
- Input array `a` is only supported as either :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
614
- Parameters `out`, ` where`, and `initial` are supported only with their default values.
615
- Otherwise the function will be executed sequentially on CPU .
610
+ Input and output arrays are only supported as either :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
611
+ Parameters `where`, and `initial` are supported only with their default values.
612
+ Otherwise ``NotImplementedError`` exception will be raised .
616
613
Input array data types are limited by supported DPNP :ref:`Data types`.
617
614
618
615
See Also
@@ -647,11 +644,11 @@ def min(a, axis=None, out=None, keepdims=False, initial=None, where=True):
647
644
648
645
if initial is not None :
649
646
raise NotImplementedError (
650
- "initial keyword arguemnts is only supported by its default value."
647
+ "initial keyword arguemnt is only supported by its default value."
651
648
)
652
649
elif where is not True :
653
650
raise NotImplementedError (
654
- "where keyword arguemnts is only supported by its default values."
651
+ "where keyword arguemnt is only supported by its default values."
655
652
)
656
653
else :
657
654
dpt_array = dpnp .get_usm_ndarray (a )
@@ -678,21 +675,17 @@ def min(a, axis=None, out=None, keepdims=False, initial=None, where=True):
678
675
raise ValueError (
679
676
f"Output array of shape { result .shape } is needed, got { out .shape } ."
680
677
)
681
- elif out .dtype != result .dtype :
682
- raise ValueError (
683
- f"Output array of type { result .dtype } is needed, got { out .dtype } ."
684
- )
685
678
elif not isinstance (out , dpnp_array ):
686
679
if isinstance (out , dpt .usm_ndarray ):
687
- out = dpnp . array (out )
680
+ out = dpnp_array . _create_from_usm_ndarray (out )
688
681
else :
689
- raise ValueError (
690
- "An array must be any of supported type, but got {}" .format (
682
+ raise TypeError (
683
+ "Output array must be any of supported type, but got {}" .format (
691
684
type (out )
692
685
)
693
686
)
694
687
695
- dpnp .copyto (out , result )
688
+ dpnp .copyto (out , result , casting = "safe" )
696
689
697
690
return out
698
691
0 commit comments