Skip to content

Commit 57b7c87

Browse files
committed
Adds tests for indexing array casting for indices and values
1 parent 9dbb0b6 commit 57b7c87

File tree

1 file changed

+104
-0
lines changed

1 file changed

+104
-0
lines changed

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1433,3 +1433,107 @@ def test_nonzero_dtype():
14331433
index_dt = dpt.dtype(ti.default_device_index_type(x.sycl_queue))
14341434
assert idx.dtype == index_dt
14351435
assert idy.dtype == index_dt
1436+
1437+
1438+
def test_take_empty_axes():
1439+
get_queue_or_skip()
1440+
1441+
x = dpt.ones((3, 0, 4, 5, 6), dtype="f4")
1442+
inds = dpt.ones(1, dtype="i4")
1443+
1444+
with pytest.raises(IndexError):
1445+
dpt.take(x, inds, axis=1)
1446+
1447+
inds = dpt.ones(0, dtype="i4")
1448+
r = dpt.take(x, inds, axis=1)
1449+
assert r.shape == x.shape
1450+
1451+
1452+
def test_put_empty_axes():
1453+
get_queue_or_skip()
1454+
1455+
x = dpt.ones((3, 0, 4, 5, 6), dtype="f4")
1456+
inds = dpt.ones(1, dtype="i4")
1457+
vals = dpt.zeros((3, 1, 4, 5, 6), dtype="f4")
1458+
1459+
with pytest.raises(IndexError):
1460+
dpt.put(x, inds, vals, axis=1)
1461+
1462+
inds = dpt.ones(0, dtype="i4")
1463+
vals = dpt.zeros_like(x)
1464+
1465+
with pytest.raises(ValueError):
1466+
dpt.put(x, inds, vals, axis=1)
1467+
1468+
1469+
def test_put_cast_vals():
1470+
get_queue_or_skip()
1471+
1472+
x = dpt.arange(10, dtype="i4")
1473+
inds = dpt.arange(7, 10, dtype="i4")
1474+
vals = dpt.zeros_like(inds, dtype="f4")
1475+
1476+
dpt.put(x, inds, vals)
1477+
assert dpt.all(x[7:10] == 0)
1478+
1479+
1480+
def test_advanced_integer_indexing_cast_vals():
1481+
get_queue_or_skip()
1482+
1483+
x = dpt.arange(10, dtype="i4")
1484+
inds = dpt.arange(7, 10, dtype="i4")
1485+
vals = dpt.zeros_like(inds, dtype="f4")
1486+
1487+
x[inds] = vals
1488+
assert dpt.all(x[7:10] == 0)
1489+
1490+
1491+
def test_advanced_integer_indexing_empty_axis():
1492+
get_queue_or_skip()
1493+
1494+
# getting
1495+
x = dpt.ones((3, 0, 4, 5, 6), dtype="f4")
1496+
inds = dpt.ones(1, dtype="i4")
1497+
with pytest.raises(IndexError):
1498+
x[:, inds, ...]
1499+
with pytest.raises(IndexError):
1500+
x[inds, inds, inds, ...]
1501+
1502+
# setting
1503+
with pytest.raises(IndexError):
1504+
x[:, inds, ...] = 2
1505+
with pytest.raises(IndexError):
1506+
x[inds, inds, inds, ...] = 2
1507+
1508+
# empty inds
1509+
inds = dpt.ones(0, dtype="i4")
1510+
assert x[:, inds, ...].shape == x.shape
1511+
assert x[inds, inds, inds, ...].shape == (0, 5, 6)
1512+
1513+
vals = dpt.zeros_like(x)
1514+
x[:, inds, ...] = vals
1515+
vals = dpt.zeros((0, 5, 6), dtype="f4")
1516+
x[inds, inds, inds, ...] = vals
1517+
1518+
1519+
def test_advanced_integer_indexing_cast_indices():
1520+
get_queue_or_skip()
1521+
1522+
inds0 = dpt.asarray([0, 1], dtype="i1")
1523+
for ind_dts in (("i1", "i2", "i4"), ("i1", "u4", "i4"), ("u1", "u2", "u8")):
1524+
x = dpt.ones((3, 4, 5, 6), dtype="i4")
1525+
inds0 = dpt.asarray([0, 1], dtype=ind_dts[0])
1526+
inds1 = dpt.astype(inds0, ind_dts[1])
1527+
x[inds0, inds1, ...] = 2
1528+
assert dpt.all(x[inds0, inds1, ...] == 2)
1529+
inds2 = dpt.astype(inds0, ind_dts[2])
1530+
x[inds0, inds1, ...] = 2
1531+
assert dpt.all(x[inds0, inds1, inds2, ...] == 2)
1532+
1533+
# fail when float would be required per type promotion
1534+
inds0 = dpt.asarray([0, 1], dtype="i1")
1535+
inds1 = dpt.astype(inds0, "u4")
1536+
inds2 = dpt.astype(inds0, "u8")
1537+
x = dpt.ones((3, 4, 5, 6), dtype="i4")
1538+
with pytest.raises(ValueError):
1539+
x[inds0, inds1, inds2, ...]

0 commit comments

Comments
 (0)