Skip to content

Commit 2036b38

Browse files
committed
reductions now well-behaved for size-zero arrays
- comparison and search reductions will throw an error in this case - slips in change to align sum signature with array API spec
1 parent 1217550 commit 2036b38

File tree

1 file changed

+51
-65
lines changed

1 file changed

+51
-65
lines changed

dpctl/tensor/_reduction.py

Lines changed: 51 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def _default_reduction_dtype(inp_dt, q):
5252
return res_dt
5353

5454

55-
def sum(arr, axis=None, dtype=None, keepdims=False):
55+
def sum(x, axis=None, dtype=None, keepdims=False):
5656
"""sum(x, axis=None, dtype=None, keepdims=False)
5757
5858
Calculates the sum of the input array `x`.
@@ -101,28 +101,28 @@ def sum(arr, axis=None, dtype=None, keepdims=False):
101101
array has the data type as described in the `dtype` parameter
102102
description above.
103103
"""
104-
if not isinstance(arr, dpt.usm_ndarray):
105-
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(arr)}")
106-
nd = arr.ndim
104+
if not isinstance(x, dpt.usm_ndarray):
105+
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
106+
nd = x.ndim
107107
if axis is None:
108108
axis = tuple(range(nd))
109109
if not isinstance(axis, (tuple, list)):
110110
axis = (axis,)
111111
axis = normalize_axis_tuple(axis, nd, "axis")
112112
red_nd = len(axis)
113113
perm = [i for i in range(nd) if i not in axis] + list(axis)
114-
arr2 = dpt.permute_dims(arr, perm)
114+
arr2 = dpt.permute_dims(x, perm)
115115
res_shape = arr2.shape[: nd - red_nd]
116-
q = arr.sycl_queue
117-
inp_dt = arr.dtype
116+
q = x.sycl_queue
117+
inp_dt = x.dtype
118118
if dtype is None:
119119
res_dt = _default_reduction_dtype(inp_dt, q)
120120
else:
121121
res_dt = dpt.dtype(dtype)
122122
res_dt = _to_device_supported_dtype(res_dt, q.sycl_device)
123123

124-
res_usm_type = arr.usm_type
125-
if arr.size == 0:
124+
res_usm_type = x.usm_type
125+
if x.size == 0:
126126
if keepdims:
127127
res_shape = res_shape + (1,) * red_nd
128128
inv_perm = sorted(range(nd), key=lambda d: perm[d])
@@ -131,7 +131,7 @@ def sum(arr, axis=None, dtype=None, keepdims=False):
131131
res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
132132
)
133133
if red_nd == 0:
134-
return dpt.astype(arr, res_dt, copy=False)
134+
return dpt.astype(x, res_dt, copy=False)
135135

136136
host_tasks_list = []
137137
if ti._sum_over_axis_dtype_supported(inp_dt, res_dt, res_usm_type, q):
@@ -173,43 +173,35 @@ def sum(arr, axis=None, dtype=None, keepdims=False):
173173
return res
174174

175175

176-
def _same_dtype_reduction(x, axis, keepdims, func):
176+
def _comparison_over_axis(x, axis, keepdims, _reduction_fn):
177177
if not isinstance(x, dpt.usm_ndarray):
178178
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
179179

180180
nd = x.ndim
181181
if axis is None:
182-
red_nd = nd
183-
# case of a scalar
184-
if red_nd == 0:
185-
return dpt.copy(x)
186-
x_tmp = x
187-
res_shape = tuple()
188-
perm = list(range(nd))
189-
else:
190-
if not isinstance(axis, (tuple, list)):
191-
axis = (axis,)
192-
axis = normalize_axis_tuple(axis, nd, "axis")
193-
194-
red_nd = len(axis)
195-
# check for axis=()
196-
if red_nd == 0:
197-
return dpt.copy(x)
198-
perm = [i for i in range(nd) if i not in axis] + list(axis)
199-
x_tmp = dpt.permute_dims(x, perm)
200-
res_shape = x_tmp.shape[: nd - red_nd]
201-
182+
axis = tuple(range(nd))
183+
if not isinstance(axis, (tuple, list)):
184+
axis = (axis,)
185+
axis = normalize_axis_tuple(axis, nd, "axis")
186+
red_nd = len(axis)
187+
perm = [i for i in range(nd) if i not in axis] + list(axis)
188+
x_tmp = dpt.permute_dims(x, perm)
189+
res_shape = x_tmp.shape[: nd - red_nd]
202190
exec_q = x.sycl_queue
191+
res_dt = x.dtype
203192
res_usm_type = x.usm_type
204-
res_dtype = x.dtype
193+
if x.size == 0:
194+
raise ValueError("reduction does not support zero-size arrays")
195+
if red_nd == 0:
196+
return x
205197

206198
res = dpt.empty(
207199
res_shape,
208-
dtype=res_dtype,
200+
dtype=res_dt,
209201
usm_type=res_usm_type,
210202
sycl_queue=exec_q,
211203
)
212-
hev, _ = func(
204+
hev, _ = _reduction_fn(
213205
src=x_tmp,
214206
trailing_dims_to_reduce=red_nd,
215207
dst=res,
@@ -225,54 +217,48 @@ def _same_dtype_reduction(x, axis, keepdims, func):
225217

226218

227219
def max(x, axis=None, keepdims=False):
228-
return _same_dtype_reduction(x, axis, keepdims, ti._max_over_axis)
220+
return _comparison_over_axis(x, axis, keepdims, ti._max_over_axis)
229221

230222

231223
def min(x, axis=None, keepdims=False):
232-
return _same_dtype_reduction(x, axis, keepdims, ti._min_over_axis)
224+
return _comparison_over_axis(x, axis, keepdims, ti._min_over_axis)
233225

234226

235-
def _argmax_argmin_reduction(x, axis, keepdims, func):
227+
def _search_over_axis(x, axis, keepdims, _reduction_fn):
236228
if not isinstance(x, dpt.usm_ndarray):
237229
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
238230

239231
nd = x.ndim
240232
if axis is None:
241-
red_nd = nd
242-
# case of a scalar
243-
if red_nd == 0:
244-
return dpt.zeros(
245-
(), dtype="i8", usm_type=x.usm_type, sycl_queue=x.sycl_queue
246-
)
247-
x_tmp = x
248-
res_shape = tuple()
249-
perm = list(range(nd))
233+
axis = tuple(range(nd))
234+
elif isinstance(axis, int):
235+
axis = (axis,)
250236
else:
251-
if not isinstance(axis, (tuple, list)):
252-
axis = (axis,)
253-
axis = normalize_axis_tuple(axis, nd, "axis")
254-
255-
red_nd = len(axis)
256-
# check for axis=()
257-
if red_nd == 0:
258-
return dpt.zeros(
259-
(), dtype="i8", usm_type=x.usm_type, sycl_queue=x.sycl_queue
260-
)
261-
perm = [i for i in range(nd) if i not in axis] + list(axis)
262-
x_tmp = dpt.permute_dims(x, perm)
263-
res_shape = x_tmp.shape[: nd - red_nd]
264-
237+
raise TypeError(
238+
f"`axis` argument expected `int` or `None`, got {type(axis)}"
239+
)
240+
axis = normalize_axis_tuple(axis, nd, "axis")
241+
red_nd = len(axis)
242+
perm = [i for i in range(nd) if i not in axis] + list(axis)
243+
x_tmp = dpt.permute_dims(x, perm)
244+
res_shape = x_tmp.shape[: nd - red_nd]
265245
exec_q = x.sycl_queue
246+
res_dt = ti.default_device_index_type(exec_q.sycl_device)
266247
res_usm_type = x.usm_type
267-
res_dtype = dpt.int64
248+
if x.size == 0:
249+
raise ValueError("reduction does not support zero-size arrays")
250+
if red_nd == 0:
251+
return dpt.zeros(
252+
res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=exec_q
253+
)
268254

269255
res = dpt.empty(
270256
res_shape,
271-
dtype=res_dtype,
257+
dtype=res_dt,
272258
usm_type=res_usm_type,
273259
sycl_queue=exec_q,
274260
)
275-
hev, _ = func(
261+
hev, _ = _reduction_fn(
276262
src=x_tmp,
277263
trailing_dims_to_reduce=red_nd,
278264
dst=res,
@@ -288,8 +274,8 @@ def _argmax_argmin_reduction(x, axis, keepdims, func):
288274

289275

290276
def argmax(x, axis=None, keepdims=False):
291-
return _argmax_argmin_reduction(x, axis, keepdims, ti._argmax_over_axis)
277+
return _search_over_axis(x, axis, keepdims, ti._argmax_over_axis)
292278

293279

294280
def argmin(x, axis=None, keepdims=False):
295-
return _argmax_argmin_reduction(x, axis, keepdims, ti._argmin_over_axis)
281+
return _search_over_axis(x, axis, keepdims, ti._argmin_over_axis)

0 commit comments

Comments
 (0)