Skip to content

Output of searchsorted must always have default indexing data type #1598

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions dpctl/tensor/_searchsorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

from ._copy_utils import _empty_like_orderK
from ._ctors import empty
from ._data_types import int32, int64
from ._tensor_impl import _copy_usm_ndarray_into_usm_ndarray as ti_copy
from ._tensor_impl import _take as ti_take
from ._tensor_impl import (
default_device_index_type as ti_default_device_index_type,
)
from ._tensor_sorting_impl import _searchsorted_left, _searchsorted_right
from ._type_utils import iinfo, isdtype, result_type
from ._type_utils import isdtype, result_type
from ._usmarray import usm_ndarray


Expand Down Expand Up @@ -141,9 +143,9 @@ def searchsorted(
x2 = x2_buf

dst_usm_type = du.get_coerced_usm_type([x1.usm_type, x2.usm_type])
dst_dt = int32 if x2.size <= iinfo(int32).max else int64
index_dt = ti_default_device_index_type(q)

dst = _empty_like_orderK(x2, dst_dt, usm_type=dst_usm_type)
dst = _empty_like_orderK(x2, index_dt, usm_type=dst_usm_type)

if side == "left":
ht_ev, _ = _searchsorted_left(
Expand Down
9 changes: 9 additions & 0 deletions dpctl/tests/test_usm_ndarray_searchsorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,29 @@ def _check(hay_stack, needles, needles_np):
assert hay_stack.dtype == needles.dtype
assert hay_stack.ndim == 1

info_ = dpt.__array_namespace_info__()
default_dts_dev = info_.default_dtypes(hay_stack.device)
index_dt = default_dts_dev["indexing"]

p_left = dpt.searchsorted(hay_stack, needles, side="left")
assert p_left.dtype == index_dt

hs_np = dpt.asnumpy(hay_stack)
ref_left = np.searchsorted(hs_np, needles_np, side="left")
assert dpt.all(p_left == dpt.asarray(ref_left))

p_right = dpt.searchsorted(hay_stack, needles, side="right")
assert p_right.dtype == index_dt

ref_right = np.searchsorted(hs_np, needles_np, side="right")
assert dpt.all(p_right == dpt.asarray(ref_right))

sorter = dpt.arange(hay_stack.size)
ps_left = dpt.searchsorted(hay_stack, needles, side="left", sorter=sorter)
assert ps_left.dtype == index_dt
assert dpt.all(ps_left == p_left)
ps_right = dpt.searchsorted(hay_stack, needles, side="right", sorter=sorter)
assert ps_right.dtype == index_dt
assert dpt.all(ps_right == p_right)


Expand Down