Skip to content

Commit ac489d9

Browse files
authored
Merge branch 'master' into dpnp_tensordot
2 parents 4f85cd7 + 4c7859b commit ac489d9

File tree

3 files changed

+64
-5
lines changed

3 files changed

+64
-5
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,12 @@ def _build_along_axis_index(a, indices, axis):
111111
else:
112112
ind_shape = shape_ones[:dim] + (-1,) + shape_ones[dim + 1 :]
113113
fancy_index.append(
114-
dpnp.arange(n, dtype=indices.dtype).reshape(ind_shape)
114+
dpnp.arange(
115+
n,
116+
dtype=indices.dtype,
117+
usm_type=indices.usm_type,
118+
sycl_queue=indices.sycl_queue,
119+
).reshape(ind_shape)
115120
)
116121

117122
return tuple(fancy_index)

tests/test_sycl_queue.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1525,28 +1525,56 @@ def test_clip(device):
15251525
assert_sycl_queue_equal(x.sycl_queue, y.sycl_queue)
15261526

15271527

1528-
@pytest.mark.parametrize("func", ["take", "take_along_axis"])
15291528
@pytest.mark.parametrize(
15301529
"device",
15311530
valid_devices,
15321531
ids=[device.filter_string for device in valid_devices],
15331532
)
1534-
def test_take(func, device):
1533+
def test_take(device):
15351534
numpy_data = numpy.arange(5)
15361535
dpnp_data = dpnp.array(numpy_data, device=device)
15371536

15381537
dpnp_ind = dpnp.array([0, 2, 4], device=device)
15391538
np_ind = dpnp_ind.asnumpy()
15401539

1541-
result = getattr(dpnp, func)(dpnp_data, dpnp_ind, axis=None)
1542-
expected = getattr(numpy, func)(numpy_data, np_ind, axis=None)
1540+
result = dpnp.take(dpnp_data, dpnp_ind, axis=None)
1541+
expected = numpy.take(numpy_data, np_ind, axis=None)
15431542
assert_allclose(expected, result)
15441543

15451544
expected_queue = dpnp_data.get_array().sycl_queue
15461545
result_queue = result.get_array().sycl_queue
15471546
assert_sycl_queue_equal(result_queue, expected_queue)
15481547

15491548

1549+
@pytest.mark.parametrize(
1550+
"data, ind, axis",
1551+
[
1552+
(numpy.arange(6), numpy.array([0, 2, 4]), None),
1553+
(
1554+
numpy.arange(6).reshape((2, 3)),
1555+
numpy.array([0, 1]).reshape((2, 1)),
1556+
1,
1557+
),
1558+
],
1559+
)
1560+
@pytest.mark.parametrize(
1561+
"device",
1562+
valid_devices,
1563+
ids=[device.filter_string for device in valid_devices],
1564+
)
1565+
def test_take_along_axis(data, ind, axis, device):
1566+
dp_data = dpnp.array(data, device=device)
1567+
dp_ind = dpnp.array(ind, device=device)
1568+
1569+
result = dpnp.take_along_axis(dp_data, dp_ind, axis=axis)
1570+
expected = numpy.take_along_axis(data, ind, axis=axis)
1571+
assert_allclose(expected, result)
1572+
1573+
expected_queue = dp_data.get_array().sycl_queue
1574+
result_queue = result.get_array().sycl_queue
1575+
assert_sycl_queue_equal(result_queue, expected_queue)
1576+
1577+
15501578
@pytest.mark.parametrize(
15511579
"device",
15521580
valid_devices,

tests/test_usm_type.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,32 @@ def test_take(func, usm_type_x, usm_type_ind):
575575
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_ind])
576576

577577

578+
@pytest.mark.parametrize(
579+
"data, ind, axis",
580+
[
581+
(numpy.arange(6), numpy.array([0, 2, 4]), None),
582+
(
583+
numpy.arange(6).reshape((2, 3)),
584+
numpy.array([0, 1]).reshape((2, 1)),
585+
1,
586+
),
587+
],
588+
)
589+
@pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types)
590+
@pytest.mark.parametrize(
591+
"usm_type_ind", list_of_usm_types, ids=list_of_usm_types
592+
)
593+
def test_take_along_axis(data, ind, axis, usm_type_x, usm_type_ind):
594+
x = dp.array(data, usm_type=usm_type_x)
595+
ind = dp.array(ind, usm_type=usm_type_ind)
596+
597+
z = dp.take_along_axis(x, ind, axis=axis)
598+
599+
assert x.usm_type == usm_type_x
600+
assert ind.usm_type == usm_type_ind
601+
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_ind])
602+
603+
578604
@pytest.mark.parametrize(
579605
"data, is_empty",
580606
[

0 commit comments

Comments
 (0)