@@ -44,12 +44,15 @@ def _reduction_over_axis(
44
44
nd = x .ndim
45
45
if axis is None :
46
46
axis = tuple (range (nd ))
47
- if not isinstance (axis , (tuple , list )):
48
- axis = (axis ,)
49
- axis = normalize_axis_tuple (axis , nd , "axis" )
47
+ perm = list (axis )
48
+ arr = x
49
+ else :
50
+ if not isinstance (axis , (tuple , list )):
51
+ axis = (axis ,)
52
+ axis = normalize_axis_tuple (axis , nd , "axis" )
53
+ perm = [i for i in range (nd ) if i not in axis ] + list (axis )
54
+ arr = dpt .permute_dims (x , perm )
50
55
red_nd = len (axis )
51
- perm = [i for i in range (nd ) if i not in axis ] + list (axis )
52
- arr = dpt .permute_dims (x , perm )
53
56
res_shape = arr .shape [: nd - red_nd ]
54
57
q = x .sycl_queue
55
58
inp_dt = x .dtype
@@ -89,7 +92,7 @@ def _reduction_over_axis(
89
92
)
90
93
if res_dt != out .dtype :
91
94
raise ValueError (
92
- f"Output array of type { res_dt } is needed, " f" got { out .dtype } "
95
+ f"Output array of type { res_dt } is needed, got { out .dtype } "
93
96
)
94
97
if dpctl .utils .get_execution_queue ((q , out .sycl_queue )) is None :
95
98
raise ExecutionPlacementError (
@@ -441,14 +444,17 @@ def _comparison_over_axis(x, axis, keepdims, out, _reduction_fn):
441
444
nd = x .ndim
442
445
if axis is None :
443
446
axis = tuple (range (nd ))
444
- if not isinstance (axis , (tuple , list )):
445
- axis = (axis ,)
446
- if any ([x .shape [i ] == 0 for i in axis ]):
447
- raise ValueError ("reduction cannot be performed over zero-size axes" )
448
- axis = normalize_axis_tuple (axis , nd , "axis" )
447
+ perm = list (axis )
448
+ x_tmp = x
449
+ else :
450
+ if not isinstance (axis , (tuple , list )):
451
+ axis = (axis ,)
452
+ axis = normalize_axis_tuple (axis , nd , "axis" )
453
+ perm = [i for i in range (nd ) if i not in axis ] + list (axis )
454
+ x_tmp = dpt .permute_dims (x , perm )
449
455
red_nd = len (axis )
450
- perm = [ i for i in range (nd ) if i not in axis ] + list ( axis )
451
- x_tmp = dpt . permute_dims ( x , perm )
456
+ if any ([ x_tmp . shape [ i ] == 0 for i in range (- red_nd , 0 )]):
457
+ raise ValueError ( "reduction cannot be performed over zero-size axes" )
452
458
res_shape = x_tmp .shape [: nd - red_nd ]
453
459
exec_q = x .sycl_queue
454
460
res_dt = x .dtype
@@ -476,7 +482,7 @@ def _comparison_over_axis(x, axis, keepdims, out, _reduction_fn):
476
482
)
477
483
if res_dt != out .dtype :
478
484
raise ValueError (
479
- f"Output array of type { res_dt } is needed, " f" got { out .dtype } "
485
+ f"Output array of type { res_dt } is needed, got { out .dtype } "
480
486
)
481
487
if dpctl .utils .get_execution_queue ((exec_q , out .sycl_queue )) is None :
482
488
raise ExecutionPlacementError (
@@ -602,18 +608,22 @@ def _search_over_axis(x, axis, keepdims, out, _reduction_fn):
602
608
nd = x .ndim
603
609
if axis is None :
604
610
axis = tuple (range (nd ))
605
- elif isinstance (axis , int ):
606
- axis = ( axis ,)
611
+ perm = list (axis )
612
+ x_tmp = x
607
613
else :
608
- raise TypeError (
609
- f"`axis` argument expected `int` or `None`, got { type (axis )} "
610
- )
614
+ if isinstance (axis , int ):
615
+ axis = (axis ,)
616
+ else :
617
+ raise TypeError (
618
+ f"`axis` argument expected `int` or `None`, got { type (axis )} "
619
+ )
620
+ axis = normalize_axis_tuple (axis , nd , "axis" )
621
+ perm = [i for i in range (nd ) if i not in axis ] + list (axis )
622
+ x_tmp = dpt .permute_dims (x , perm )
611
623
axis = normalize_axis_tuple (axis , nd , "axis" )
612
- if any ([x .shape [i ] == 0 for i in axis ]):
613
- raise ValueError ("reduction cannot be performed over zero-size axes" )
614
624
red_nd = len (axis )
615
- perm = [ i for i in range (nd ) if i not in axis ] + list ( axis )
616
- x_tmp = dpt . permute_dims ( x , perm )
625
+ if any ([ x_tmp . shape [ i ] == 0 for i in range (- red_nd , 0 )]):
626
+ raise ValueError ( "reduction cannot be performed over zero-size axes" )
617
627
res_shape = x_tmp .shape [: nd - red_nd ]
618
628
exec_q = x .sycl_queue
619
629
res_dt = ti .default_device_index_type (exec_q .sycl_device )
@@ -641,7 +651,7 @@ def _search_over_axis(x, axis, keepdims, out, _reduction_fn):
641
651
)
642
652
if res_dt != out .dtype :
643
653
raise ValueError (
644
- f"Output array of type { res_dt } is needed, " f" got { out .dtype } "
654
+ f"Output array of type { res_dt } is needed, got { out .dtype } "
645
655
)
646
656
if dpctl .utils .get_execution_queue ((exec_q , out .sycl_queue )) is None :
647
657
raise ExecutionPlacementError (
0 commit comments