Skip to content

Commit d12fad5

Browse files
committed
Update sorting tests
1 parent 10c0523 commit d12fad5

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

dpnp/tests/third_party/cupy/sorting_tests/test_search.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,31 @@ def test_where_two_arrays(self, xp, cond_type, x_type, y_type):
328328
return xp.where(cond, x, y)
329329

330330

331+
@testing.with_requires("numpy>=2.0")
332+
@testing.parameterize(
333+
{"scalar_value": 1},
334+
{"scalar_value": 1.0},
335+
{"scalar_value": 1 + 2j},
336+
)
337+
class TestWhereArrayAndScalar:
338+
339+
@testing.for_all_dtypes()
340+
@testing.numpy_cupy_allclose(type_check=has_support_aspect64())
341+
def test_where_array_scalar(self, xp, dtype):
342+
cond = testing.shaped_random((2, 3, 4), xp, xp.bool_)
343+
x = testing.shaped_random((2, 3, 4), xp, dtype, seed=0)
344+
y = self.scalar_value
345+
return xp.where(cond, x, y)
346+
347+
@testing.for_all_dtypes()
348+
@testing.numpy_cupy_allclose(type_check=has_support_aspect64())
349+
def test_where_scalar_array(self, xp, dtype):
350+
cond = testing.shaped_random((2, 3, 4), xp, xp.bool_)
351+
x = self.scalar_value
352+
y = testing.shaped_random((2, 3, 4), xp, dtype, seed=0)
353+
return xp.where(cond, x, y)
354+
355+
331356
@testing.parameterize(
332357
{"cond_shape": (2, 3, 4)},
333358
{"cond_shape": (4,)},

0 commit comments

Comments
 (0)