Skip to content

Commit f92bfc4

Browse files
Test that return array from searchsorted has default indexing data type
1 parent 501ff54 commit f92bfc4

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

dpctl/tests/test_usm_ndarray_searchsorted.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,29 @@ def _check(hay_stack, needles, needles_np):
1111
assert hay_stack.dtype == needles.dtype
1212
assert hay_stack.ndim == 1
1313

14+
info_ = dpt.__array_namespace_info__()
15+
default_dts_dev = info_.default_dtypes(hay_stack.device)
16+
index_dt = default_dts_dev["indexing"]
17+
1418
p_left = dpt.searchsorted(hay_stack, needles, side="left")
19+
assert p_left.dtype == index_dt
1520

1621
hs_np = dpt.asnumpy(hay_stack)
1722
ref_left = np.searchsorted(hs_np, needles_np, side="left")
1823
assert dpt.all(p_left == dpt.asarray(ref_left))
1924

2025
p_right = dpt.searchsorted(hay_stack, needles, side="right")
26+
assert p_right.dtype == index_dt
27+
2128
ref_right = np.searchsorted(hs_np, needles_np, side="right")
2229
assert dpt.all(p_right == dpt.asarray(ref_right))
2330

2431
sorter = dpt.arange(hay_stack.size)
2532
ps_left = dpt.searchsorted(hay_stack, needles, side="left", sorter=sorter)
33+
assert ps_left.dtype == index_dt
2634
assert dpt.all(ps_left == p_left)
2735
ps_right = dpt.searchsorted(hay_stack, needles, side="right", sorter=sorter)
36+
assert ps_right.dtype == index_dt
2837
assert dpt.all(ps_right == p_right)
2938

3039

0 commit comments

Comments
 (0)