Skip to content

Commit afa6980

Browse files
vtavanaantonwolfy
andauthored
fix CFD issue for dpnp.extract (#2172)
Co-authored-by: Anton <[email protected]>
1 parent 9a2b917 commit afa6980

File tree

3 files changed

+30
-6
lines changed

3 files changed

+30
-6
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,7 @@
5151
dpnp_putmask,
5252
)
5353
from .dpnp_array import dpnp_array
54-
from .dpnp_utils import (
55-
call_origin,
56-
get_usm_allocations,
57-
)
54+
from .dpnp_utils import call_origin, get_usm_allocations
5855

5956
__all__ = [
6057
"choose",
@@ -535,11 +532,12 @@ def extract(condition, a):
535532
"""
536533

537534
usm_a = dpnp.get_usm_ndarray(a)
535+
usm_type, exec_q = get_usm_allocations([usm_a, condition])
538536
usm_cond = dpnp.as_usm_ndarray(
539537
condition,
540538
dtype=dpnp.bool,
541-
usm_type=usm_a.usm_type,
542-
sycl_queue=usm_a.sycl_queue,
539+
usm_type=usm_type,
540+
sycl_queue=exec_q,
543541
)
544542

545543
if usm_cond.size != usm_a.size:

tests/test_sycl_queue.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,20 @@ def test_copy_operation(device):
406406
assert_sycl_queue_equal(y.sycl_queue, x.sycl_queue)
407407

408408

409+
@pytest.mark.parametrize(
410+
"device",
411+
valid_devices,
412+
ids=[device.filter_string for device in valid_devices],
413+
)
414+
def test_extract(device):
415+
x = dpnp.arange(3, device=device)
416+
y = dpnp.array([True, False, True], device=device)
417+
result = dpnp.extract(x, y)
418+
419+
assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue)
420+
assert_sycl_queue_equal(result.sycl_queue, y.sycl_queue)
421+
422+
409423
@pytest.mark.parametrize(
410424
"device",
411425
valid_devices,

tests/test_usm_type.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -812,6 +812,18 @@ def test_concat_stack(func, data1, data2, usm_type_x, usm_type_y):
812812
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y])
813813

814814

815+
@pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types)
816+
@pytest.mark.parametrize("usm_type_y", list_of_usm_types, ids=list_of_usm_types)
817+
def test_extract(usm_type_x, usm_type_y):
818+
x = dp.arange(3, usm_type=usm_type_x)
819+
y = dp.array([True, False, True], usm_type=usm_type_y)
820+
z = dp.extract(y, x)
821+
822+
assert x.usm_type == usm_type_x
823+
assert y.usm_type == usm_type_y
824+
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y])
825+
826+
815827
@pytest.mark.parametrize(
816828
"func,data1",
817829
[

0 commit comments

Comments
 (0)