Skip to content

Commit 4c7859b

Browse files
Fix ExecutionPlacementError for dpnp.take_along_axis (#1702)
* Follow compute follows data to fill fancy_index * Update take_along_axis tests to cover the issue * Update test_take_along_axis
1 parent 0957ddd commit 4c7859b

File tree

3 files changed

+64
-5
lines changed

3 files changed

+64
-5
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,12 @@ def _build_along_axis_index(a, indices, axis):
111111
else:
112112
ind_shape = shape_ones[:dim] + (-1,) + shape_ones[dim + 1 :]
113113
fancy_index.append(
114-
dpnp.arange(n, dtype=indices.dtype).reshape(ind_shape)
114+
dpnp.arange(
115+
n,
116+
dtype=indices.dtype,
117+
usm_type=indices.usm_type,
118+
sycl_queue=indices.sycl_queue,
119+
).reshape(ind_shape)
115120
)
116121

117122
return tuple(fancy_index)

tests/test_sycl_queue.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1520,28 +1520,56 @@ def test_clip(device):
15201520
assert_sycl_queue_equal(x.sycl_queue, y.sycl_queue)
15211521

15221522

1523-
@pytest.mark.parametrize("func", ["take", "take_along_axis"])
15241523
@pytest.mark.parametrize(
15251524
"device",
15261525
valid_devices,
15271526
ids=[device.filter_string for device in valid_devices],
15281527
)
1529-
def test_take(func, device):
1528+
def test_take(device):
15301529
numpy_data = numpy.arange(5)
15311530
dpnp_data = dpnp.array(numpy_data, device=device)
15321531

15331532
dpnp_ind = dpnp.array([0, 2, 4], device=device)
15341533
np_ind = dpnp_ind.asnumpy()
15351534

1536-
result = getattr(dpnp, func)(dpnp_data, dpnp_ind, axis=None)
1537-
expected = getattr(numpy, func)(numpy_data, np_ind, axis=None)
1535+
result = dpnp.take(dpnp_data, dpnp_ind, axis=None)
1536+
expected = numpy.take(numpy_data, np_ind, axis=None)
15381537
assert_allclose(expected, result)
15391538

15401539
expected_queue = dpnp_data.get_array().sycl_queue
15411540
result_queue = result.get_array().sycl_queue
15421541
assert_sycl_queue_equal(result_queue, expected_queue)
15431542

15441543

1544+
@pytest.mark.parametrize(
1545+
"data, ind, axis",
1546+
[
1547+
(numpy.arange(6), numpy.array([0, 2, 4]), None),
1548+
(
1549+
numpy.arange(6).reshape((2, 3)),
1550+
numpy.array([0, 1]).reshape((2, 1)),
1551+
1,
1552+
),
1553+
],
1554+
)
1555+
@pytest.mark.parametrize(
1556+
"device",
1557+
valid_devices,
1558+
ids=[device.filter_string for device in valid_devices],
1559+
)
1560+
def test_take_along_axis(data, ind, axis, device):
1561+
dp_data = dpnp.array(data, device=device)
1562+
dp_ind = dpnp.array(ind, device=device)
1563+
1564+
result = dpnp.take_along_axis(dp_data, dp_ind, axis=axis)
1565+
expected = numpy.take_along_axis(data, ind, axis=axis)
1566+
assert_allclose(expected, result)
1567+
1568+
expected_queue = dp_data.get_array().sycl_queue
1569+
result_queue = result.get_array().sycl_queue
1570+
assert_sycl_queue_equal(result_queue, expected_queue)
1571+
1572+
15451573
@pytest.mark.parametrize(
15461574
"device",
15471575
valid_devices,

tests/test_usm_type.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,32 @@ def test_take(func, usm_type_x, usm_type_ind):
570570
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_ind])
571571

572572

573+
@pytest.mark.parametrize(
574+
"data, ind, axis",
575+
[
576+
(numpy.arange(6), numpy.array([0, 2, 4]), None),
577+
(
578+
numpy.arange(6).reshape((2, 3)),
579+
numpy.array([0, 1]).reshape((2, 1)),
580+
1,
581+
),
582+
],
583+
)
584+
@pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types)
585+
@pytest.mark.parametrize(
586+
"usm_type_ind", list_of_usm_types, ids=list_of_usm_types
587+
)
588+
def test_take_along_axis(data, ind, axis, usm_type_x, usm_type_ind):
589+
x = dp.array(data, usm_type=usm_type_x)
590+
ind = dp.array(ind, usm_type=usm_type_ind)
591+
592+
z = dp.take_along_axis(x, ind, axis=axis)
593+
594+
assert x.usm_type == usm_type_x
595+
assert ind.usm_type == usm_type_ind
596+
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_ind])
597+
598+
573599
@pytest.mark.parametrize(
574600
"data, is_empty",
575601
[

0 commit comments

Comments
 (0)