Skip to content

Commit e0b79ff

Browse files
committed
Added missing usm_type to tril() and triu() functions.
1 parent 7c15231 commit e0b79ff

File tree

1 file changed

+30
-6
lines changed

1 file changed

+30
-6
lines changed

dpctl/tensor/_ctors.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,19 +1247,31 @@ def tril(X, k=0):
12471247

12481248
if k >= shape[nd - 1] - 1:
12491249
res = dpt.empty(
1250-
X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue
1250+
X.shape,
1251+
dtype=X.dtype,
1252+
order=order,
1253+
usm_type=X.usm_type,
1254+
sycl_queue=X.sycl_queue,
12511255
)
12521256
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
12531257
src=X, dst=res, sycl_queue=X.sycl_queue
12541258
)
12551259
hev.wait()
12561260
elif k < -shape[nd - 2]:
12571261
res = dpt.zeros(
1258-
X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue
1262+
X.shape,
1263+
dtype=X.dtype,
1264+
order=order,
1265+
usm_type=X.usm_type,
1266+
sycl_queue=X.sycl_queue,
12591267
)
12601268
else:
12611269
res = dpt.empty(
1262-
X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue
1270+
X.shape,
1271+
dtype=X.dtype,
1272+
order=order,
1273+
usm_type=X.usm_type,
1274+
sycl_queue=X.sycl_queue,
12631275
)
12641276
hev, _ = ti._tril(src=X, dst=res, k=k, sycl_queue=X.sycl_queue)
12651277
hev.wait()
@@ -1290,19 +1302,31 @@ def triu(X, k=0):
12901302

12911303
if k > shape[nd - 1]:
12921304
res = dpt.zeros(
1293-
X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue
1305+
X.shape,
1306+
dtype=X.dtype,
1307+
order=order,
1308+
usm_type=X.usm_type,
1309+
sycl_queue=X.sycl_queue,
12941310
)
12951311
elif k <= -shape[nd - 2] + 1:
12961312
res = dpt.empty(
1297-
X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue
1313+
X.shape,
1314+
dtype=X.dtype,
1315+
order=order,
1316+
usm_type=X.usm_type,
1317+
sycl_queue=X.sycl_queue,
12981318
)
12991319
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
13001320
src=X, dst=res, sycl_queue=X.sycl_queue
13011321
)
13021322
hev.wait()
13031323
else:
13041324
res = dpt.empty(
1305-
X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue
1325+
X.shape,
1326+
dtype=X.dtype,
1327+
order=order,
1328+
usm_type=X.usm_type,
1329+
sycl_queue=X.sycl_queue,
13061330
)
13071331
hev, _ = ti._triu(src=X, dst=res, k=k, sycl_queue=X.sycl_queue)
13081332
hev.wait()

0 commit comments

Comments
 (0)