Skip to content

Commit 5472b28

Browse files
vtavanaantonwolfy
andcommitted
fix CFD issue for dpnp.extract (#2172)
Co-authored-by: Anton <[email protected]>
1 parent 18a8eb5 commit 5472b28

File tree

3 files changed

+30
-6
lines changed

3 files changed

+30
-6
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,7 @@
5151
dpnp_putmask,
5252
)
5353
from .dpnp_array import dpnp_array
54-
from .dpnp_utils import (
55-
call_origin,
56-
get_usm_allocations,
57-
)
54+
from .dpnp_utils import call_origin, get_usm_allocations
5855

5956
__all__ = [
6057
"choose",
@@ -585,11 +582,12 @@ def extract(condition, a):
585582
"""
586583

587584
usm_a = dpnp.get_usm_ndarray(a)
585+
usm_type, exec_q = get_usm_allocations([usm_a, condition])
588586
usm_cond = dpnp.as_usm_ndarray(
589587
condition,
590588
dtype=dpnp.bool,
591-
usm_type=usm_a.usm_type,
592-
sycl_queue=usm_a.sycl_queue,
589+
usm_type=usm_type,
590+
sycl_queue=exec_q,
593591
)
594592

595593
if usm_cond.size != usm_a.size:

tests/test_sycl_queue.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,20 @@ def test_copy_operation(device):
406406
assert_sycl_queue_equal(y.sycl_queue, x.sycl_queue)
407407

408408

409+
@pytest.mark.parametrize(
410+
"device",
411+
valid_devices,
412+
ids=[device.filter_string for device in valid_devices],
413+
)
414+
def test_extract(device):
415+
x = dpnp.arange(3, device=device)
416+
y = dpnp.array([True, False, True], device=device)
417+
result = dpnp.extract(x, y)
418+
419+
assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue)
420+
assert_sycl_queue_equal(result.sycl_queue, y.sycl_queue)
421+
422+
409423
@pytest.mark.parametrize(
410424
"device",
411425
valid_devices,

tests/test_usm_type.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,18 @@ def test_concat_stack(func, data1, data2, usm_type_x, usm_type_y):
769769
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y])
770770

771771

772+
@pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types)
773+
@pytest.mark.parametrize("usm_type_y", list_of_usm_types, ids=list_of_usm_types)
774+
def test_extract(usm_type_x, usm_type_y):
775+
x = dp.arange(3, usm_type=usm_type_x)
776+
y = dp.array([True, False, True], usm_type=usm_type_y)
777+
z = dp.extract(y, x)
778+
779+
assert x.usm_type == usm_type_x
780+
assert y.usm_type == usm_type_y
781+
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y])
782+
783+
772784
@pytest.mark.parametrize(
773785
"func,data1",
774786
[

0 commit comments

Comments
 (0)