@@ -11,20 +11,29 @@ def _check(hay_stack, needles, needles_np):
11
11
assert hay_stack .dtype == needles .dtype
12
12
assert hay_stack .ndim == 1
13
13
14
+ info_ = dpt .__array_namespace_info__ ()
15
+ default_dts_dev = info_ .default_dtypes (hay_stack .device )
16
+ index_dt = default_dts_dev ["indexing" ]
17
+
14
18
p_left = dpt .searchsorted (hay_stack , needles , side = "left" )
19
+ assert p_left .dtype == index_dt
15
20
16
21
hs_np = dpt .asnumpy (hay_stack )
17
22
ref_left = np .searchsorted (hs_np , needles_np , side = "left" )
18
23
assert dpt .all (p_left == dpt .asarray (ref_left ))
19
24
20
25
p_right = dpt .searchsorted (hay_stack , needles , side = "right" )
26
+ assert p_right .dtype == index_dt
27
+
21
28
ref_right = np .searchsorted (hs_np , needles_np , side = "right" )
22
29
assert dpt .all (p_right == dpt .asarray (ref_right ))
23
30
24
31
sorter = dpt .arange (hay_stack .size )
25
32
ps_left = dpt .searchsorted (hay_stack , needles , side = "left" , sorter = sorter )
33
+ assert ps_left .dtype == index_dt
26
34
assert dpt .all (ps_left == p_left )
27
35
ps_right = dpt .searchsorted (hay_stack , needles , side = "right" , sorter = sorter )
36
+ assert ps_right .dtype == index_dt
28
37
assert dpt .all (ps_right == p_right )
29
38
30
39
0 commit comments