Skip to content

Commit 852f4b1

Browse files
authored
dpctl.tensor.where output preserves memory order of inputs (#1342)
* Where result now keeps order of operands - Now when operands are cast, stride simplification can still be performed on non-C contiguous inputs - Implements _empty_like_triple_orderK to allocate output of where * Adds test for correct order="K" behavior in where * Adjusted logic in _empty_like_triple_orderK - Now calls _empty_like_pair_orderK when two arrays are of equal shape and larger than the third * Changes to order "K" stride sorting - Dimensions of size 1 are effectively disregarded in sorting * Fixed typo in _empty_like_orderK
1 parent b008b8b commit 852f4b1

File tree

3 files changed

+126
-25
lines changed

3 files changed

+126
-25
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,9 @@ def _empty_like_orderK(X, dt, usm_type=None, dev=None):
351351
)
352352
st = list(X.strides)
353353
perm = sorted(
354-
range(X.ndim), key=lambda d: builtins.abs(st[d]), reverse=True
354+
range(X.ndim),
355+
key=lambda d: builtins.abs(st[d]) if X.shape[d] > 1 else 0,
356+
reverse=True,
355357
)
356358
inv_perm = sorted(range(X.ndim), key=lambda i: perm[i])
357359
sh = X.shape
@@ -395,9 +397,14 @@ def _empty_like_pair_orderK(X1, X2, dt, res_shape, usm_type, dev):
395397
max_ndim = max(nd1, nd2)
396398
st1 += [0] * (max_ndim - len(st1))
397399
st2 += [0] * (max_ndim - len(st2))
400+
sh1 = list(X1.shape) + [0] * (max_ndim - nd1)
401+
sh2 = list(X2.shape) + [0] * (max_ndim - nd2)
398402
perm = sorted(
399403
range(max_ndim),
400-
key=lambda d: (builtins.abs(st1[d]), builtins.abs(st2[d])),
404+
key=lambda d: (
405+
builtins.abs(st1[d]) if sh1[d] > 1 else 0,
406+
builtins.abs(st2[d]) if sh2[d] > 1 else 0,
407+
),
401408
reverse=True,
402409
)
403410
inv_perm = sorted(range(max_ndim), key=lambda i: perm[i])
@@ -417,6 +424,74 @@ def _empty_like_pair_orderK(X1, X2, dt, res_shape, usm_type, dev):
417424
return dpt.permute_dims(R, inv_perm)
418425

419426

427+
def _empty_like_triple_orderK(X1, X2, X3, dt, res_shape, usm_type, dev):
428+
if not isinstance(X1, dpt.usm_ndarray):
429+
raise TypeError(f"Expected usm_ndarray, got {type(X1)}")
430+
if not isinstance(X2, dpt.usm_ndarray):
431+
raise TypeError(f"Expected usm_ndarray, got {type(X2)}")
432+
if not isinstance(X3, dpt.usm_ndarray):
433+
raise TypeError(f"Expected usm_ndarray, got {type(X3)}")
434+
nd1 = X1.ndim
435+
nd2 = X2.ndim
436+
nd3 = X3.ndim
437+
if X1.shape == res_shape and X2.shape == res_shape and len(res_shape) > nd3:
438+
return _empty_like_pair_orderK(X1, X2, dt, res_shape, usm_type, dev)
439+
elif (
440+
X2.shape == res_shape and X3.shape == res_shape and len(res_shape) > nd1
441+
):
442+
return _empty_like_pair_orderK(X2, X3, dt, res_shape, usm_type, dev)
443+
elif (
444+
X1.shape == res_shape and X3.shape == res_shape and len(res_shape) > nd2
445+
):
446+
return _empty_like_pair_orderK(X1, X3, dt, res_shape, usm_type, dev)
447+
fl1 = X1.flags
448+
fl2 = X2.flags
449+
fl3 = X3.flags
450+
if fl1["C"] or fl2["C"] or fl3["C"]:
451+
return dpt.empty(
452+
res_shape, dtype=dt, usm_type=usm_type, device=dev, order="C"
453+
)
454+
if fl1["F"] and fl2["F"] and fl3["F"]:
455+
return dpt.empty(
456+
res_shape, dtype=dt, usm_type=usm_type, device=dev, order="F"
457+
)
458+
st1 = list(X1.strides)
459+
st2 = list(X2.strides)
460+
st3 = list(X3.strides)
461+
max_ndim = max(nd1, nd2, nd3)
462+
st1 += [0] * (max_ndim - len(st1))
463+
st2 += [0] * (max_ndim - len(st2))
464+
st3 += [0] * (max_ndim - len(st3))
465+
sh1 = list(X1.shape) + [0] * (max_ndim - nd1)
466+
sh2 = list(X2.shape) + [0] * (max_ndim - nd2)
467+
sh3 = list(X3.shape) + [0] * (max_ndim - nd3)
468+
perm = sorted(
469+
range(max_ndim),
470+
key=lambda d: (
471+
builtins.abs(st1[d]) if sh1[d] > 1 else 0,
472+
builtins.abs(st2[d]) if sh2[d] > 1 else 0,
473+
builtins.abs(st3[d]) if sh3[d] > 1 else 0,
474+
),
475+
reverse=True,
476+
)
477+
inv_perm = sorted(range(max_ndim), key=lambda i: perm[i])
478+
st1_sorted = [st1[i] for i in perm]
479+
st2_sorted = [st2[i] for i in perm]
480+
st3_sorted = [st3[i] for i in perm]
481+
sh = res_shape
482+
sh_sorted = tuple(sh[i] for i in perm)
483+
R = dpt.empty(sh_sorted, dtype=dt, usm_type=usm_type, device=dev, order="C")
484+
if max(min(st1_sorted), min(st2_sorted), min(st3_sorted)) < 0:
485+
sl = tuple(
486+
slice(None, None, -1)
487+
if (st1_sorted[i] < 0 and st2_sorted[i] < 0 and st3_sorted[i] < 0)
488+
else slice(None, None, None)
489+
for i in range(nd1)
490+
)
491+
R = R[sl]
492+
return dpt.permute_dims(R, inv_perm)
493+
494+
420495
def copy(usm_ary, order="K"):
421496
"""copy(ary, order="K")
422497

dpctl/tensor/_search_functions.py

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import dpctl.tensor._tensor_impl as ti
2020
from dpctl.tensor._manipulation_functions import _broadcast_shapes
2121

22+
from ._copy_utils import _empty_like_orderK, _empty_like_triple_orderK
2223
from ._type_utils import _all_data_types, _can_cast
2324

2425

@@ -121,7 +122,7 @@ def where(condition, x1, x2):
121122
deps = []
122123
wait_list = []
123124
if x1_dtype != dst_dtype:
124-
_x1 = dpt.empty_like(x1, dtype=dst_dtype)
125+
_x1 = _empty_like_orderK(x1, dst_dtype)
125126
ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray(
126127
src=x1, dst=_x1, sycl_queue=exec_q
127128
)
@@ -130,39 +131,22 @@ def where(condition, x1, x2):
130131
wait_list.append(ht_copy1_ev)
131132

132133
if x2_dtype != dst_dtype:
133-
_x2 = dpt.empty_like(x2, dtype=dst_dtype)
134+
_x2 = _empty_like_orderK(x2, dst_dtype)
134135
ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray(
135136
src=x2, dst=_x2, sycl_queue=exec_q
136137
)
137138
x2 = _x2
138139
deps.append(copy2_ev)
139140
wait_list.append(ht_copy2_ev)
140141

142+
dst = _empty_like_triple_orderK(
143+
condition, x1, x2, dst_dtype, res_shape, dst_usm_type, exec_q
144+
)
145+
141146
condition = dpt.broadcast_to(condition, res_shape)
142147
x1 = dpt.broadcast_to(x1, res_shape)
143148
x2 = dpt.broadcast_to(x2, res_shape)
144149

145-
# dst is F-contiguous when all inputs are F contiguous
146-
# otherwise, defaults to C-contiguous
147-
if all(
148-
(
149-
condition.flags.fnc,
150-
x1.flags.fnc,
151-
x2.flags.fnc,
152-
)
153-
):
154-
order = "F"
155-
else:
156-
order = "C"
157-
158-
dst = dpt.empty(
159-
res_shape,
160-
dtype=dst_dtype,
161-
order=order,
162-
usm_type=dst_usm_type,
163-
sycl_queue=exec_q,
164-
)
165-
166150
hev, _ = ti._where(
167151
condition=condition,
168152
x1=x1,

dpctl/tests/test_usm_ndarray_search_functions.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,3 +370,45 @@ def test_where_compute_follows_data():
370370
dpt.where(dpt.empty((1,), dtype="i4", sycl_queue=q3), x1, x2)
371371
with pytest.raises(ExecutionPlacementError):
372372
dpt.where(x1, x1, x2)
373+
374+
375+
def test_where_order():
376+
get_queue_or_skip()
377+
378+
test_sh = (
379+
20,
380+
20,
381+
)
382+
test_sh2 = tuple(2 * dim for dim in test_sh)
383+
n = test_sh[-1]
384+
385+
for dt1, dt2 in zip(["i4", "i4", "f4"], ["i4", "f4", "i4"]):
386+
ar1 = dpt.zeros(test_sh, dtype=dt1, order="C")
387+
ar2 = dpt.ones(test_sh, dtype=dt2, order="C")
388+
condition = dpt.zeros(test_sh, dtype="?", order="C")
389+
res = dpt.where(condition, ar1, ar2)
390+
assert res.flags.c_contiguous
391+
392+
ar1 = dpt.ones(test_sh, dtype=dt1, order="F")
393+
ar2 = dpt.ones(test_sh, dtype=dt2, order="F")
394+
condition = dpt.zeros(test_sh, dtype="?", order="F")
395+
res = dpt.where(condition, ar1, ar2)
396+
assert res.flags.f_contiguous
397+
398+
ar1 = dpt.ones(test_sh2, dtype=dt1, order="C")[:20, ::-2]
399+
ar2 = dpt.ones(test_sh2, dtype=dt2, order="C")[:20, ::-2]
400+
condition = dpt.zeros(test_sh2, dtype="?", order="C")[:20, ::-2]
401+
res = dpt.where(condition, ar1, ar2)
402+
assert res.strides == (n, -1)
403+
404+
ar1 = dpt.ones(test_sh2, dtype=dt1, order="C")[:20, ::-2].mT
405+
ar2 = dpt.ones(test_sh2, dtype=dt2, order="C")[:20, ::-2].mT
406+
condition = dpt.zeros(test_sh2, dtype="?", order="C")[:20, ::-2].mT
407+
res = dpt.where(condition, ar1, ar2)
408+
assert res.strides == (-1, n)
409+
410+
ar1 = dpt.ones(n, dtype=dt1, order="C")
411+
ar2 = dpt.broadcast_to(dpt.ones(n, dtype=dt2, order="C"), test_sh)
412+
condition = dpt.zeros(n, dtype="?", order="C")
413+
res = dpt.where(condition, ar1, ar2)
414+
assert res.strides == (20, 1)

0 commit comments

Comments
 (0)