Skip to content

Commit 4ed654f

Browse files
Merge pull request #1062 from IntelPython/fix_tril_triu_usm_type
Added missing usm_type to tril() and triu() functions.
2 parents 7c15231 + ec8509f commit 4ed654f

File tree

2 files changed

+52
-6
lines changed

2 files changed

+52
-6
lines changed

dpctl/tensor/_ctors.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,19 +1247,31 @@ def tril(X, k=0):
12471247

12481248
if k >= shape[nd - 1] - 1:
12491249
res = dpt.empty(
1250-
X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue
1250+
X.shape,
1251+
dtype=X.dtype,
1252+
order=order,
1253+
usm_type=X.usm_type,
1254+
sycl_queue=X.sycl_queue,
12511255
)
12521256
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
12531257
src=X, dst=res, sycl_queue=X.sycl_queue
12541258
)
12551259
hev.wait()
12561260
elif k < -shape[nd - 2]:
12571261
res = dpt.zeros(
1258-
X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue
1262+
X.shape,
1263+
dtype=X.dtype,
1264+
order=order,
1265+
usm_type=X.usm_type,
1266+
sycl_queue=X.sycl_queue,
12591267
)
12601268
else:
12611269
res = dpt.empty(
1262-
X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue
1270+
X.shape,
1271+
dtype=X.dtype,
1272+
order=order,
1273+
usm_type=X.usm_type,
1274+
sycl_queue=X.sycl_queue,
12631275
)
12641276
hev, _ = ti._tril(src=X, dst=res, k=k, sycl_queue=X.sycl_queue)
12651277
hev.wait()
@@ -1290,19 +1302,31 @@ def triu(X, k=0):
12901302

12911303
if k > shape[nd - 1]:
12921304
res = dpt.zeros(
1293-
X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue
1305+
X.shape,
1306+
dtype=X.dtype,
1307+
order=order,
1308+
usm_type=X.usm_type,
1309+
sycl_queue=X.sycl_queue,
12941310
)
12951311
elif k <= -shape[nd - 2] + 1:
12961312
res = dpt.empty(
1297-
X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue
1313+
X.shape,
1314+
dtype=X.dtype,
1315+
order=order,
1316+
usm_type=X.usm_type,
1317+
sycl_queue=X.sycl_queue,
12981318
)
12991319
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
13001320
src=X, dst=res, sycl_queue=X.sycl_queue
13011321
)
13021322
hev.wait()
13031323
else:
13041324
res = dpt.empty(
1305-
X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue
1325+
X.shape,
1326+
dtype=X.dtype,
1327+
order=order,
1328+
usm_type=X.usm_type,
1329+
sycl_queue=X.sycl_queue,
13061330
)
13071331
hev, _ = ti._triu(src=X, dst=res, k=k, sycl_queue=X.sycl_queue)
13081332
hev.wait()

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1503,6 +1503,28 @@ def test_triu(dtype):
15031503
assert np.array_equal(Ynp, dpt.asnumpy(Y))
15041504

15051505

1506+
@pytest.mark.parametrize("tri_fn", [dpt.tril, dpt.triu])
1507+
@pytest.mark.parametrize("usm_type", ["device", "shared", "host"])
1508+
def test_tri_usm_type(tri_fn, usm_type):
1509+
q = get_queue_or_skip()
1510+
dtype = dpt.uint16
1511+
1512+
shape = (2, 3, 4, 5, 5)
1513+
size = np.prod(shape)
1514+
X = dpt.reshape(
1515+
dpt.arange(size, dtype=dtype, usm_type=usm_type, sycl_queue=q), shape
1516+
)
1517+
Y = tri_fn(X) # main execution branch
1518+
assert Y.usm_type == X.usm_type
1519+
assert Y.sycl_queue == q
1520+
Y = tri_fn(X, k=-6) # special case of Y == X
1521+
assert Y.usm_type == X.usm_type
1522+
assert Y.sycl_queue == q
1523+
Y = tri_fn(X, k=6) # special case of Y == 0
1524+
assert Y.usm_type == X.usm_type
1525+
assert Y.sycl_queue == q
1526+
1527+
15061528
def test_tril_slice():
15071529
q = get_queue_or_skip()
15081530

0 commit comments

Comments
 (0)