Skip to content

Commit b67e5f6

Browse files
authored
Backport gh-2172 (#2215)
This PR backports of #2172 from development branch to `maintenance/0.16.x`.
1 parent 309a418 commit b67e5f6

File tree

4 files changed

+31
-6
lines changed

4 files changed

+31
-6
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ This is a bug-fix release.
1515

1616
### Fixed
1717

18+
* Resolved an issue with Compute Follows Data inconsistency in `dpnp.extract` function [#2172](https://github.com/IntelPython/dpnp/pull/2172)
1819
* Resolved a compilation error when building with DPC++ 2025.1 compiler [#2211](https://github.com/IntelPython/dpnp/pull/2211)
1920

2021

dpnp/dpnp_iface_indexing.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,7 @@
5151
dpnp_putmask,
5252
)
5353
from .dpnp_array import dpnp_array
54-
from .dpnp_utils import (
55-
call_origin,
56-
get_usm_allocations,
57-
)
54+
from .dpnp_utils import call_origin, get_usm_allocations
5855

5956
__all__ = [
6057
"choose",
@@ -585,11 +582,12 @@ def extract(condition, a):
585582
"""
586583

587584
usm_a = dpnp.get_usm_ndarray(a)
585+
usm_type, exec_q = get_usm_allocations([usm_a, condition])
588586
usm_cond = dpnp.as_usm_ndarray(
589587
condition,
590588
dtype=dpnp.bool,
591-
usm_type=usm_a.usm_type,
592-
sycl_queue=usm_a.sycl_queue,
589+
usm_type=usm_type,
590+
sycl_queue=exec_q,
593591
)
594592

595593
if usm_cond.size != usm_a.size:

tests/test_sycl_queue.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,20 @@ def test_copy_operation(device):
406406
assert_sycl_queue_equal(y.sycl_queue, x.sycl_queue)
407407

408408

409+
@pytest.mark.parametrize(
410+
"device",
411+
valid_devices,
412+
ids=[device.filter_string for device in valid_devices],
413+
)
414+
def test_extract(device):
415+
x = dpnp.arange(3, device=device)
416+
y = dpnp.array([True, False, True], device=device)
417+
result = dpnp.extract(x, y)
418+
419+
assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue)
420+
assert_sycl_queue_equal(result.sycl_queue, y.sycl_queue)
421+
422+
409423
@pytest.mark.parametrize(
410424
"device",
411425
valid_devices,

tests/test_usm_type.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,18 @@ def test_concat_stack(func, data1, data2, usm_type_x, usm_type_y):
769769
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y])
770770

771771

772+
@pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types)
773+
@pytest.mark.parametrize("usm_type_y", list_of_usm_types, ids=list_of_usm_types)
774+
def test_extract(usm_type_x, usm_type_y):
775+
x = dp.arange(3, usm_type=usm_type_x)
776+
y = dp.array([True, False, True], usm_type=usm_type_y)
777+
z = dp.extract(y, x)
778+
779+
assert x.usm_type == usm_type_x
780+
assert y.usm_type == usm_type_y
781+
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y])
782+
783+
772784
@pytest.mark.parametrize(
773785
"func,data1",
774786
[

0 commit comments

Comments
 (0)