Skip to content

Commit 9629b43

Browse files
Add tests of internal functions to improve coveage of _copy_utils
1 parent a9a261e commit 9629b43

File tree

1 file changed

+88
-0
lines changed

1 file changed

+88
-0
lines changed

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1592,3 +1592,91 @@ def test_take_along_axis_validation():
15921592
ind2 = dpt.zeros(1, dtype=ind_dt, sycl_queue=q2)
15931593
with pytest.raises(ExecutionPlacementError):
15941594
dpt.take_along_axis(x, ind2)
1595+
1596+
1597+
def check__extract_impl_validation(fn):
1598+
x = dpt.ones(10)
1599+
ind = dpt.ones(10, dtype="?")
1600+
with pytest.raises(TypeError):
1601+
fn(list(), ind)
1602+
with pytest.raises(TypeError):
1603+
fn(x, list())
1604+
q2 = dpctl.SyclQueue(x.sycl_device, property="enable_profiling")
1605+
ind2 = dpt.ones(10, dtype="?", sycl_queue=q2)
1606+
with pytest.raises(ExecutionPlacementError):
1607+
fn(x, ind2)
1608+
with pytest.raises(ValueError):
1609+
fn(x, ind, 1)
1610+
1611+
1612+
def check__nonzero_impl_validation(fn):
1613+
with pytest.raises(TypeError):
1614+
fn(list())
1615+
1616+
1617+
def check__take_multi_index(fn):
1618+
x = dpt.ones(10)
1619+
x_dev = x.sycl_device
1620+
info_ = dpt.__array_namespace_info__()
1621+
def_dtypes = info_.default_dtypes(device=x_dev)
1622+
ind_dt = def_dtypes["indexing"]
1623+
ind = dpt.arange(10, dtype=ind_dt)
1624+
with pytest.raises(TypeError):
1625+
fn(list(), tuple(), 1)
1626+
with pytest.raises(ValueError):
1627+
fn(x, (ind,), 0, mode=2)
1628+
with pytest.raises(ValueError):
1629+
fn(x, (None,), 1)
1630+
with pytest.raises(IndexError):
1631+
fn(x, (x,), 1)
1632+
q2 = dpctl.SyclQueue(x.sycl_device, property="enable_profiling")
1633+
ind2 = dpt.arange(10, dtype=ind_dt, sycl_queue=q2)
1634+
with pytest.raises(ExecutionPlacementError):
1635+
fn(x, (ind2,), 0)
1636+
m = dpt.ones((10, 10))
1637+
ind_1 = dpt.arange(10, dtype="i8")
1638+
ind_2 = dpt.arange(10, dtype="u8")
1639+
with pytest.raises(ValueError):
1640+
fn(m, (ind_1, ind_2), 0)
1641+
1642+
1643+
def check__place_impl_validation(fn):
1644+
with pytest.raises(TypeError):
1645+
fn(list(), list(), list())
1646+
x = dpt.ones(10)
1647+
with pytest.raises(TypeError):
1648+
fn(x, list(), list())
1649+
q2 = dpctl.SyclQueue(x.sycl_device, property="enable_profiling")
1650+
mask2 = dpt.ones(10, dtype="?", sycl_queue=q2)
1651+
with pytest.raises(ExecutionPlacementError):
1652+
fn(x, mask2, 1)
1653+
mask = dpt.ones(x.shape, dtype="?")
1654+
with pytest.raises(ValueError):
1655+
fn(x, mask, x, 1)
1656+
1657+
1658+
def check__put_multi_index_validation(fn):
1659+
with pytest.raises(TypeError):
1660+
fn(list(), list(), 0, list())
1661+
x = dpt.ones(10)
1662+
inds = dpt.arange(10, dtype="i8")
1663+
vals = dpt.zeros(10)
1664+
# test inds which is not a tuple/list
1665+
fn(x, inds, 0, vals)
1666+
x2 = dpt.ones((5, 5))
1667+
ind1 = dpt.arange(5, dtype="i8")
1668+
ind2 = dpt.arange(5, dtype="u8")
1669+
with pytest.raises(ValueError):
1670+
fn(x2, (ind1, ind2), 0, x2)
1671+
1672+
1673+
def test__copy_utils():
1674+
import dpctl.tensor._copy_utils as cu
1675+
1676+
get_queue_or_skip()
1677+
1678+
check__extract_impl_validation(cu._extract_impl)
1679+
check__nonzero_impl_validation(cu._nonzero_impl)
1680+
check__take_multi_index(cu._take_multi_index)
1681+
check__place_impl_validation(cu._place_impl)
1682+
check__put_multi_index_validation(cu._put_multi_index)

0 commit comments

Comments
 (0)