|
29 | 29 | default_device_index_type,
|
30 | 30 | mask_positions,
|
31 | 31 | )
|
32 |
| -from ._tensor_sorting_impl import _argsort_ascending, _sort_ascending |
| 32 | +from ._tensor_sorting_impl import ( |
| 33 | + _argsort_ascending, |
| 34 | + _searchsorted_left, |
| 35 | + _sort_ascending, |
| 36 | +) |
33 | 37 |
|
34 | 38 | __all__ = [
|
35 | 39 | "unique_values",
|
@@ -365,7 +369,7 @@ def unique_inverse(x):
|
365 | 369 | unique_vals = dpt.empty(
|
366 | 370 | n_uniques, dtype=x.dtype, usm_type=x_usm_type, sycl_queue=exec_q
|
367 | 371 | )
|
368 |
| - ht_ev, _ = _extract( |
| 372 | + ht_ev, uv_ev = _extract( |
369 | 373 | src=s,
|
370 | 374 | cumsum=cumsum,
|
371 | 375 | axis_start=0,
|
@@ -403,8 +407,21 @@ def unique_inverse(x):
|
403 | 407 | depends=[set_ev, extr_ev],
|
404 | 408 | )
|
405 | 409 | host_tasks.append(ht_ev)
|
| 410 | + |
| 411 | + inv_dt = dpt.int64 if x.size > dpt.iinfo(dpt.int32).max else dpt.int32 |
| 412 | + inv = dpt.empty_like(x, dtype=inv_dt, order="C") |
| 413 | + ht_ev, _ = _searchsorted_left( |
| 414 | + hay=unique_vals, |
| 415 | + needles=x, |
| 416 | + positions=inv, |
| 417 | + sycl_queue=exec_q, |
| 418 | + depends=[ |
| 419 | + uv_ev, |
| 420 | + ], |
| 421 | + ) |
| 422 | + host_tasks.append(ht_ev) |
| 423 | + |
406 | 424 | dpctl.SyclEvent.wait_for(host_tasks)
|
407 |
| - inv = dpt.searchsorted(unique_vals, x) |
408 | 425 | return UniqueInverseResult(unique_vals, inv)
|
409 | 426 |
|
410 | 427 |
|
@@ -532,7 +549,7 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
|
532 | 549 | unique_vals = dpt.empty(
|
533 | 550 | n_uniques, dtype=x.dtype, usm_type=x_usm_type, sycl_queue=exec_q
|
534 | 551 | )
|
535 |
| - ht_ev, _ = _extract( |
| 552 | + ht_ev, uv_ev = _extract( |
536 | 553 | src=s,
|
537 | 554 | cumsum=cumsum,
|
538 | 555 | axis_start=0,
|
@@ -562,15 +579,29 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
|
562 | 579 | )
|
563 | 580 | host_tasks.append(ht_ev)
|
564 | 581 | _counts = dpt.empty_like(cum_unique_counts[1:])
|
565 |
| - ht_ev, _ = _subtract( |
| 582 | + ht_ev, sub_ev = _subtract( |
566 | 583 | src1=cum_unique_counts[1:],
|
567 | 584 | src2=cum_unique_counts[:-1],
|
568 | 585 | dst=_counts,
|
569 | 586 | sycl_queue=exec_q,
|
570 | 587 | depends=[set_ev, extr_ev],
|
571 | 588 | )
|
572 | 589 | host_tasks.append(ht_ev)
|
573 |
| - inv = dpt.searchsorted(unique_vals, x) |
| 590 | + |
| 591 | + inv_dt = dpt.int64 if x.size > dpt.iinfo(dpt.int32).max else dpt.int32 |
| 592 | + inv = dpt.empty_like(x, dtype=inv_dt, order="C") |
| 593 | + ht_ev, _ = _searchsorted_left( |
| 594 | + hay=unique_vals, |
| 595 | + needles=x, |
| 596 | + positions=inv, |
| 597 | + sycl_queue=exec_q, |
| 598 | + depends=[ |
| 599 | + uv_ev, |
| 600 | + ], |
| 601 | + ) |
| 602 | + host_tasks.append(ht_ev) |
| 603 | + |
| 604 | + dpctl.SyclEvent.wait_for(host_tasks) |
574 | 605 | return UniqueAllResult(
|
575 | 606 | unique_vals,
|
576 | 607 | sorting_ids[cum_unique_counts[:-1]],
|
|
0 commit comments