Skip to content

Commit aca8aa0

Browse files
committed
Add tests for take with out keyword
1 parent 1b38641 commit aca8aa0

File tree

1 file changed

+72
-0
lines changed

1 file changed

+72
-0
lines changed

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1890,3 +1890,75 @@ def test_put_along_axis_uint64_indices():
18901890
dpt.put_along_axis(x, inds, dpt.asarray(2, dtype=x.dtype), axis=1)
18911891
expected = dpt.tile(dpt.asarray([0, 2], dtype="i4"), (2, 5))
18921892
assert dpt.all(expected == x)
1893+
1894+
1895+
@pytest.mark.parametrize(
1896+
"data_dt",
1897+
_all_dtypes,
1898+
)
1899+
@pytest.mark.parametrize("order", ["C", "F"])
1900+
def test_take_out(data_dt, order):
1901+
q = get_queue_or_skip()
1902+
skip_if_dtype_not_supported(data_dt, q)
1903+
1904+
axis = 0
1905+
x = dpt.reshape(_make_3d(data_dt, q), (9, 3), order=order)
1906+
ind = dpt.arange(2, dtype="i8", sycl_queue=q)
1907+
out_sh = x.shape[:axis] + ind.shape + x.shape[axis + 1 :]
1908+
out = dpt.empty(out_sh, dtype=data_dt, sycl_queue=q)
1909+
1910+
expected = dpt.take(x, ind, axis=axis)
1911+
1912+
dpt.take(x, ind, axis=axis, out=out)
1913+
1914+
assert dpt.all(out == expected)
1915+
1916+
1917+
@pytest.mark.parametrize(
1918+
"data_dt",
1919+
_all_dtypes,
1920+
)
1921+
@pytest.mark.parametrize("order", ["C", "F"])
1922+
def test_take_out_overlap(data_dt, order):
1923+
q = get_queue_or_skip()
1924+
skip_if_dtype_not_supported(data_dt, q)
1925+
1926+
axis = 0
1927+
x = dpt.reshape(_make_3d(data_dt, q), (9, 3), order=order)
1928+
ind = dpt.arange(2, dtype="i8", sycl_queue=q)
1929+
out = x[x.shape[axis] - ind.shape[axis] : x.shape[axis], :]
1930+
1931+
expected = dpt.take(x, ind, axis=axis)
1932+
1933+
dpt.take(x, ind, axis=axis, out=out)
1934+
1935+
assert dpt.all(out == expected)
1936+
assert dpt.all(x[x.shape[0] - ind.shape[0] : x.shape[0], :] == out)
1937+
1938+
1939+
def test_take_out_errors():
1940+
q1 = get_queue_or_skip()
1941+
q2 = get_queue_or_skip()
1942+
1943+
x = dpt.arange(10, dtype="i4", sycl_queue=q1)
1944+
ind = dpt.arange(2, dtype="i4", sycl_queue=q1)
1945+
1946+
with pytest.raises(TypeError):
1947+
dpt.take(x, ind, out=dict())
1948+
1949+
out_read_only = dpt.empty(ind.shape, dtype=x.dtype, sycl_queue=q1)
1950+
out_read_only.flags["W"] = False
1951+
with pytest.raises(ValueError):
1952+
dpt.take(x, ind, out=out_read_only)
1953+
1954+
out_bad_shape = dpt.empty(0, dtype=x.dtype, sycl_queue=q1)
1955+
with pytest.raises(ValueError):
1956+
dpt.take(x, ind, out=out_bad_shape)
1957+
1958+
out_bad_dt = dpt.empty(ind.shape, dtype="i8", sycl_queue=q1)
1959+
with pytest.raises(ValueError):
1960+
dpt.take(x, ind, out=out_bad_dt)
1961+
1962+
out_bad_q = dpt.empty(ind.shape, dtype=x.dtype, sycl_queue=q2)
1963+
with pytest.raises(dpctl.utils.ExecutionPlacementError):
1964+
dpt.take(x, ind, out=out_bad_q)

0 commit comments

Comments
 (0)