@@ -1247,19 +1247,31 @@ def tril(X, k=0):
1247
1247
1248
1248
if k >= shape [nd - 1 ] - 1 :
1249
1249
res = dpt .empty (
1250
- X .shape , dtype = X .dtype , order = order , sycl_queue = X .sycl_queue
1250
+ X .shape ,
1251
+ dtype = X .dtype ,
1252
+ order = order ,
1253
+ usm_type = X .usm_type ,
1254
+ sycl_queue = X .sycl_queue ,
1251
1255
)
1252
1256
hev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
1253
1257
src = X , dst = res , sycl_queue = X .sycl_queue
1254
1258
)
1255
1259
hev .wait ()
1256
1260
elif k < - shape [nd - 2 ]:
1257
1261
res = dpt .zeros (
1258
- X .shape , dtype = X .dtype , order = order , sycl_queue = X .sycl_queue
1262
+ X .shape ,
1263
+ dtype = X .dtype ,
1264
+ order = order ,
1265
+ usm_type = X .usm_type ,
1266
+ sycl_queue = X .sycl_queue ,
1259
1267
)
1260
1268
else :
1261
1269
res = dpt .empty (
1262
- X .shape , dtype = X .dtype , order = order , sycl_queue = X .sycl_queue
1270
+ X .shape ,
1271
+ dtype = X .dtype ,
1272
+ order = order ,
1273
+ usm_type = X .usm_type ,
1274
+ sycl_queue = X .sycl_queue ,
1263
1275
)
1264
1276
hev , _ = ti ._tril (src = X , dst = res , k = k , sycl_queue = X .sycl_queue )
1265
1277
hev .wait ()
@@ -1290,19 +1302,31 @@ def triu(X, k=0):
1290
1302
1291
1303
if k > shape [nd - 1 ]:
1292
1304
res = dpt .zeros (
1293
- X .shape , dtype = X .dtype , order = order , sycl_queue = X .sycl_queue
1305
+ X .shape ,
1306
+ dtype = X .dtype ,
1307
+ order = order ,
1308
+ usm_type = X .usm_type ,
1309
+ sycl_queue = X .sycl_queue ,
1294
1310
)
1295
1311
elif k <= - shape [nd - 2 ] + 1 :
1296
1312
res = dpt .empty (
1297
- X .shape , dtype = X .dtype , order = order , sycl_queue = X .sycl_queue
1313
+ X .shape ,
1314
+ dtype = X .dtype ,
1315
+ order = order ,
1316
+ usm_type = X .usm_type ,
1317
+ sycl_queue = X .sycl_queue ,
1298
1318
)
1299
1319
hev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
1300
1320
src = X , dst = res , sycl_queue = X .sycl_queue
1301
1321
)
1302
1322
hev .wait ()
1303
1323
else :
1304
1324
res = dpt .empty (
1305
- X .shape , dtype = X .dtype , order = order , sycl_queue = X .sycl_queue
1325
+ X .shape ,
1326
+ dtype = X .dtype ,
1327
+ order = order ,
1328
+ usm_type = X .usm_type ,
1329
+ sycl_queue = X .sycl_queue ,
1306
1330
)
1307
1331
hev , _ = ti ._triu (src = X , dst = res , k = k , sycl_queue = X .sycl_queue )
1308
1332
hev .wait ()
0 commit comments