Skip to content

Commit 69aa1ed

Browse files
Add test to exercise special case in copy from ndarray to usm_ndarray
The test performs set-item on conformably permutted ndarray and usm_ndarray Also made inputs in tests sensitive to violation (not all elements are set to one, but some are set to zero too).
1 parent c6a362d commit 69aa1ed

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,11 +1063,34 @@ def test_tofrom_numpy(shape, dtype, usm_type):
10631063
skip_if_dtype_not_supported(dtype, q)
10641064
Xusm = dpt.zeros(shape, dtype=dtype, usm_type=usm_type, sycl_queue=q)
10651065
Ynp = np.ones(shape, dtype=dtype)
1066+
Ynp[(0,) * len(shape)] = 0
10661067
ind = (slice(None, None, None),) * Ynp.ndim
10671068
Xusm[ind] = Ynp
10681069
assert np.array_equal(dpt.to_numpy(Xusm), Ynp)
10691070

10701071

1072+
@pytest.mark.parametrize(
1073+
"dtype",
1074+
_all_dtypes,
1075+
)
1076+
@pytest.mark.parametrize("usm_type", ["device", "shared", "host"])
1077+
def test_tofrom_numpy_permuted(dtype, usm_type):
1078+
shape = (3, 5, 7)
1079+
perm = (1, 2, 0)
1080+
q = get_queue_or_skip()
1081+
skip_if_dtype_not_supported(dtype, q)
1082+
Xusm = dpt.permute_dims(
1083+
dpt.zeros(shape, dtype=dtype, usm_type=usm_type, sycl_queue=q), perm
1084+
)
1085+
Ynp = np.transpose(np.ones(shape, dtype=dtype), perm)
1086+
Ynp[:, ::2, ::2] = 0
1087+
ind = (slice(None, None, None),) * Ynp.ndim
1088+
# even though Xusm and Ynp are strided, simple memcpy could be done.
1089+
# This test validates that it is being done correctly
1090+
Xusm[ind] = Ynp
1091+
assert np.array_equal(dpt.to_numpy(Xusm), Ynp)
1092+
1093+
10711094
@pytest.mark.parametrize(
10721095
"dtype",
10731096
_all_dtypes,

0 commit comments

Comments
 (0)