Skip to content

Commit 0e69998

Browse files
Add tests of internal functions to improve coveage of _copy_utils
1 parent a9a261e commit 0e69998

File tree

1 file changed

+91
-0
lines changed

1 file changed

+91
-0
lines changed

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1592,3 +1592,94 @@ def test_take_along_axis_validation():
15921592
ind2 = dpt.zeros(1, dtype=ind_dt, sycl_queue=q2)
15931593
with pytest.raises(ExecutionPlacementError):
15941594
dpt.take_along_axis(x, ind2)
1595+
1596+
1597+
def check__extract_impl_validation(fn):
1598+
x = dpt.ones(10)
1599+
ind = dpt.ones(10, dtype="?")
1600+
with pytest.raises(TypeError):
1601+
fn(list(), ind)
1602+
with pytest.raises(TypeError):
1603+
fn(x, list())
1604+
q2 = dpctl.SyclQueue(x.sycl_device, property="enable_profiling")
1605+
ind2 = dpt.ones(10, dtype="?", sycl_queue=q2)
1606+
with pytest.raises(ExecutionPlacementError):
1607+
fn(x, ind2)
1608+
with pytest.raises(ValueError):
1609+
fn(x, ind, 1)
1610+
1611+
1612+
def check__nonzero_impl_validation(fn):
1613+
with pytest.raises(TypeError):
1614+
fn(list())
1615+
1616+
1617+
def check__take_multi_index(fn):
1618+
x = dpt.ones(10)
1619+
x_dev = x.sycl_device
1620+
info_ = dpt.__array_namespace_info__()
1621+
def_dtypes = info_.default_dtypes(device=x_dev)
1622+
ind_dt = def_dtypes["indexing"]
1623+
ind = dpt.arange(10, dtype=ind_dt)
1624+
with pytest.raises(TypeError):
1625+
fn(list(), tuple(), 1)
1626+
with pytest.raises(ValueError):
1627+
fn(x, (ind,), 0, mode=2)
1628+
with pytest.raises(ValueError):
1629+
fn(x, (None,), 1)
1630+
with pytest.raises(IndexError):
1631+
fn(x, (x,), 1)
1632+
q2 = dpctl.SyclQueue(x.sycl_device, property="enable_profiling")
1633+
ind2 = dpt.arange(10, dtype=ind_dt, sycl_queue=q2)
1634+
with pytest.raises(ExecutionPlacementError):
1635+
fn(x, (ind2,), 0)
1636+
m = dpt.ones((10, 10))
1637+
ind_1 = dpt.arange(10, dtype="i8")
1638+
ind_2 = dpt.arange(10, dtype="u8")
1639+
with pytest.raises(ValueError):
1640+
fn(m, (ind_1, ind_2), 0)
1641+
1642+
1643+
def check__place_impl_validation(fn):
1644+
with pytest.raises(TypeError):
1645+
fn(list(), list(), list())
1646+
x = dpt.ones(10)
1647+
with pytest.raises(TypeError):
1648+
fn(x, list(), list())
1649+
q2 = dpctl.SyclQueue(x.sycl_device, property="enable_profiling")
1650+
mask2 = dpt.ones(10, dtype="?", sycl_queue=q2)
1651+
with pytest.raises(ExecutionPlacementError):
1652+
fn(x, mask2, 1)
1653+
x2 = dpt.ones((5, 5))
1654+
mask2 = dpt.ones((5, 5), dtype="?")
1655+
with pytest.raises(ValueError):
1656+
fn(x2, mask2, x2, axis=1)
1657+
1658+
1659+
def check__put_multi_index_validation(fn):
1660+
with pytest.raises(TypeError):
1661+
fn(list(), list(), 0, list())
1662+
x = dpt.ones(10)
1663+
inds = dpt.arange(10, dtype="i8")
1664+
vals = dpt.zeros(10)
1665+
# test inds which is not a tuple/list
1666+
fn(x, inds, 0, vals)
1667+
x2 = dpt.ones((5, 5))
1668+
ind1 = dpt.arange(5, dtype="i8")
1669+
ind2 = dpt.arange(5, dtype="u8")
1670+
with pytest.raises(ValueError):
1671+
fn(x2, (ind1, ind2), 0, x2)
1672+
with pytest.raises(TypeError):
1673+
fn(x2, (ind1, list()), 0, x2)
1674+
1675+
1676+
def test__copy_utils():
1677+
import dpctl.tensor._copy_utils as cu
1678+
1679+
get_queue_or_skip()
1680+
1681+
check__extract_impl_validation(cu._extract_impl)
1682+
check__nonzero_impl_validation(cu._nonzero_impl)
1683+
check__take_multi_index(cu._take_multi_index)
1684+
check__place_impl_validation(cu._place_impl)
1685+
check__put_multi_index_validation(cu._put_multi_index)

0 commit comments

Comments
 (0)