@@ -80,6 +80,8 @@ def unique_values(x: dpt.usm_ndarray) -> dpt.usm_ndarray:
80
80
fx = x
81
81
else :
82
82
fx = dpt .reshape (x , (x .size ,), order = "C" )
83
+ if fx .size == 0 :
84
+ return fx
83
85
s = dpt .empty_like (fx , order = "C" )
84
86
host_tasks = []
85
87
if fx .flags .c_contiguous :
@@ -114,7 +116,7 @@ def unique_values(x: dpt.usm_ndarray) -> dpt.usm_ndarray:
114
116
fill_value = True , dst = unique_mask [0 ], sycl_queue = exec_q
115
117
)
116
118
host_tasks .append (ht_ev )
117
- cumsum = dpt .empty (s .shape , dtype = dpt .int64 )
119
+ cumsum = dpt .empty (s .shape , dtype = dpt .int64 , sycl_queue = exec_q )
118
120
# synchronizing call
119
121
n_uniques = mask_positions (
120
122
unique_mask , cumsum , sycl_queue = exec_q , depends = [one_ev , uneq_ev ]
@@ -163,10 +165,14 @@ def unique_counts(x: dpt.usm_ndarray) -> UniqueCountsResult:
163
165
raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
164
166
array_api_dev = x .device
165
167
exec_q = array_api_dev .sycl_queue
168
+ x_usm_type = x .usm_type
166
169
if x .ndim == 1 :
167
170
fx = x
168
171
else :
169
172
fx = dpt .reshape (x , (x .size ,), order = "C" )
173
+ ind_dt = default_device_index_type (exec_q )
174
+ if fx .size == 0 :
175
+ return UniqueCountsResult (fx , dpt .empty_like (fx , dtype = ind_dt ))
170
176
s = dpt .empty_like (fx , order = "C" )
171
177
host_tasks = []
172
178
if fx .flags .c_contiguous :
@@ -201,17 +207,21 @@ def unique_counts(x: dpt.usm_ndarray) -> UniqueCountsResult:
201
207
fill_value = True , dst = unique_mask [0 ], sycl_queue = exec_q
202
208
)
203
209
host_tasks .append (ht_ev )
204
- ind_dt = default_device_index_type (exec_q )
205
- cumsum = dpt .empty (unique_mask .shape , dtype = dpt .int64 )
210
+ cumsum = dpt .empty (unique_mask .shape , dtype = dpt .int64 , sycl_queue = exec_q )
206
211
# synchronizing call
207
212
n_uniques = mask_positions (
208
213
unique_mask , cumsum , sycl_queue = exec_q , depends = [one_ev , uneq_ev ]
209
214
)
210
215
if n_uniques == fx .size :
211
216
dpctl .SyclEvent .wait_for (host_tasks )
212
- return UniqueCountsResult (s , dpt .ones (n_uniques , dtype = ind_dt ))
217
+ return UniqueCountsResult (
218
+ s ,
219
+ dpt .ones (
220
+ n_uniques , dtype = ind_dt , usm_type = x_usm_type , sycl_queue = exec_q
221
+ ),
222
+ )
213
223
unique_vals = dpt .empty (
214
- n_uniques , dtype = x .dtype , usm_type = x . usm_type , sycl_queue = exec_q
224
+ n_uniques , dtype = x .dtype , usm_type = x_usm_type , sycl_queue = exec_q
215
225
)
216
226
# populate unique values
217
227
ht_ev , _ = _extract (
@@ -224,7 +234,7 @@ def unique_counts(x: dpt.usm_ndarray) -> UniqueCountsResult:
224
234
)
225
235
host_tasks .append (ht_ev )
226
236
unique_counts = dpt .empty (
227
- n_uniques + 1 , dtype = ind_dt , usm_type = x . usm_type , sycl_queue = exec_q
237
+ n_uniques + 1 , dtype = ind_dt , usm_type = x_usm_type , sycl_queue = exec_q
228
238
)
229
239
idx = dpt .empty (x .size , dtype = ind_dt , sycl_queue = exec_q )
230
240
ht_ev , id_ev = _linspace_step (start = 0 , dt = 1 , dst = idx , sycl_queue = exec_q )
@@ -281,13 +291,16 @@ def unique_inverse(x):
281
291
raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
282
292
array_api_dev = x .device
283
293
exec_q = array_api_dev .sycl_queue
294
+ x_usm_type = x .usm_type
284
295
if x .ndim == 1 :
285
296
fx = x
286
297
else :
287
298
fx = dpt .reshape (x , (x .size ,), order = "C" )
288
299
ind_dt = default_device_index_type (exec_q )
289
300
sorting_ids = dpt .empty_like (fx , dtype = ind_dt , order = "C" )
290
301
unsorting_ids = dpt .empty_like (sorting_ids , dtype = ind_dt , order = "C" )
302
+ if fx .size == 0 :
303
+ return UniqueInverseResult (fx , unsorting_ids )
291
304
host_tasks = []
292
305
if fx .flags .c_contiguous :
293
306
ht_ev , sort_ev = _argsort_ascending (
@@ -341,7 +354,7 @@ def unique_inverse(x):
341
354
fill_value = True , dst = unique_mask [0 ], sycl_queue = exec_q
342
355
)
343
356
host_tasks .append (ht_ev )
344
- cumsum = dpt .empty (unique_mask .shape , dtype = dpt .int64 )
357
+ cumsum = dpt .empty (unique_mask .shape , dtype = dpt .int64 , sycl_queue = exec_q )
345
358
# synchronizing call
346
359
n_uniques = mask_positions (
347
360
unique_mask , cumsum , sycl_queue = exec_q , depends = [uneq_ev , one_ev ]
@@ -350,7 +363,7 @@ def unique_inverse(x):
350
363
dpctl .SyclEvent .wait_for (host_tasks )
351
364
return UniqueInverseResult (s , unsorting_ids )
352
365
unique_vals = dpt .empty (
353
- n_uniques , dtype = x .dtype , usm_type = x . usm_type , sycl_queue = exec_q
366
+ n_uniques , dtype = x .dtype , usm_type = x_usm_type , sycl_queue = exec_q
354
367
)
355
368
ht_ev , _ = _extract (
356
369
src = s ,
@@ -362,7 +375,7 @@ def unique_inverse(x):
362
375
)
363
376
host_tasks .append (ht_ev )
364
377
cum_unique_counts = dpt .empty (
365
- n_uniques + 1 , dtype = ind_dt , usm_type = x . usm_type , sycl_queue = exec_q
378
+ n_uniques + 1 , dtype = ind_dt , usm_type = x_usm_type , sycl_queue = exec_q
366
379
)
367
380
idx = dpt .empty (x .size , dtype = ind_dt , sycl_queue = exec_q )
368
381
ht_ev , id_ev = _linspace_step (start = 0 , dt = 1 , dst = idx , sycl_queue = exec_q )
@@ -442,13 +455,20 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
442
455
raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
443
456
array_api_dev = x .device
444
457
exec_q = array_api_dev .sycl_queue
458
+ x_usm_type = x .usm_type
445
459
if x .ndim == 1 :
446
460
fx = x
447
461
else :
448
- fx = dpt .reshape (x , (x .size ,), order = "C" , copy = False )
462
+ fx = dpt .reshape (x , (x .size ,), order = "C" )
449
463
ind_dt = default_device_index_type (exec_q )
450
464
sorting_ids = dpt .empty_like (fx , dtype = ind_dt , order = "C" )
451
465
unsorting_ids = dpt .empty_like (sorting_ids , dtype = ind_dt , order = "C" )
466
+ if fx .size == 0 :
467
+ # original array contains no data
468
+ # so it can be safely returned as values
469
+ return UniqueAllResult (
470
+ fx , sorting_ids , unsorting_ids , dpt .empty_like (fx , dtype = ind_dt )
471
+ )
452
472
host_tasks = []
453
473
if fx .flags .c_contiguous :
454
474
ht_ev , sort_ev = _argsort_ascending (
@@ -502,16 +522,24 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
502
522
fill_value = True , dst = unique_mask [0 ], sycl_queue = exec_q
503
523
)
504
524
host_tasks .append (ht_ev )
505
- cumsum = dpt .empty (unique_mask .shape , dtype = dpt .int64 )
525
+ cumsum = dpt .empty (unique_mask .shape , dtype = dpt .int64 , sycl_queue = exec_q )
506
526
# synchronizing call
507
527
n_uniques = mask_positions (
508
528
unique_mask , cumsum , sycl_queue = exec_q , depends = [uneq_ev , one_ev ]
509
529
)
510
530
if n_uniques == fx .size :
511
531
dpctl .SyclEvent .wait_for (host_tasks )
512
- return UniqueInverseResult (s , unsorting_ids )
532
+ _counts = dpt .ones (
533
+ n_uniques , dtype = ind_dt , usm_type = x_usm_type , sycl_queue = exec_q
534
+ )
535
+ return UniqueAllResult (
536
+ s ,
537
+ sorting_ids ,
538
+ unsorting_ids ,
539
+ _counts ,
540
+ )
513
541
unique_vals = dpt .empty (
514
- n_uniques , dtype = x .dtype , usm_type = x . usm_type , sycl_queue = exec_q
542
+ n_uniques , dtype = x .dtype , usm_type = x_usm_type , sycl_queue = exec_q
515
543
)
516
544
ht_ev , _ = _extract (
517
545
src = s ,
@@ -523,7 +551,7 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
523
551
)
524
552
host_tasks .append (ht_ev )
525
553
cum_unique_counts = dpt .empty (
526
- n_uniques + 1 , dtype = ind_dt , usm_type = x . usm_type , sycl_queue = exec_q
554
+ n_uniques + 1 , dtype = ind_dt , usm_type = x_usm_type , sycl_queue = exec_q
527
555
)
528
556
idx = dpt .empty (x .size , dtype = ind_dt , sycl_queue = exec_q )
529
557
ht_ev , id_ev = _linspace_step (start = 0 , dt = 1 , dst = idx , sycl_queue = exec_q )
0 commit comments