Skip to content

Commit 33e437a

Browse files
committed
Make diff test for appended Python scalars more thorough and more efficient
Also adds a correctness check
1 parent dbed59a commit 33e437a

File tree

1 file changed

+47
-36
lines changed

1 file changed

+47
-36
lines changed

dpctl/tests/test_tensor_diff.py

Lines changed: 47 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -159,39 +159,50 @@ def test_diff_no_op():
159159
def test_diff_prepend_append_py_scalars(sh, axis):
160160
get_queue_or_skip()
161161

162-
arrs = [
163-
dpt.ones(sh, dtype="?"),
164-
dpt.ones(sh, dtype="i4"),
165-
dpt.ones(sh, dtype="f4"),
166-
dpt.ones(sh, dtype="c8"),
167-
]
168-
169-
py_zeros = [
170-
False,
171-
0,
172-
0.0,
173-
complex(0, 0),
174-
]
175-
176-
py_ones = [
177-
True,
178-
1,
179-
1.0,
180-
complex(1, 0),
181-
]
182-
183-
for zero, one, arr in zip(py_zeros, py_ones, arrs):
184-
n = 1
185-
r = dpt.diff(arr, n=n, axis=axis, prepend=zero, append=one)
186-
assert isinstance(r, dpt.usm_ndarray)
187-
assert all(
188-
r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis
189-
)
190-
assert r.shape[axis] == arr.shape[axis] + 2 - n
191-
192-
r = dpt.diff(arr, n=n, axis=axis, prepend=zero)
193-
assert isinstance(r, dpt.usm_ndarray)
194-
assert all(
195-
r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis
196-
)
197-
assert r.shape[axis] == arr.shape[axis] + 1 - n
162+
n = 1
163+
164+
arr = dpt.ones(sh, dtype="i4")
165+
zero = 0
166+
167+
# first and last elements along axis
168+
# will be checked for correctness
169+
sl1 = [slice(None)] * arr.ndim
170+
sl1[axis] = slice(1)
171+
sl1 = tuple(sl1)
172+
173+
sl2 = [slice(None)] * arr.ndim
174+
sl2[axis] = slice(-1, None, None)
175+
sl2 = tuple(sl2)
176+
177+
r = dpt.diff(arr, axis=axis, prepend=zero, append=zero)
178+
assert isinstance(r, dpt.usm_ndarray)
179+
assert all(r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis)
180+
assert r.shape[axis] == arr.shape[axis] + 2 - n
181+
assert dpt.all(r[sl1] == 1)
182+
assert dpt.all(r[sl2] == -1)
183+
184+
r = dpt.diff(arr, axis=axis, prepend=zero)
185+
assert isinstance(r, dpt.usm_ndarray)
186+
assert all(r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis)
187+
assert r.shape[axis] == arr.shape[axis] + 1 - n
188+
assert dpt.all(r[sl1] == 1)
189+
190+
r = dpt.diff(arr, axis=axis, append=zero)
191+
assert isinstance(r, dpt.usm_ndarray)
192+
assert all(r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis)
193+
assert r.shape[axis] == arr.shape[axis] + 1 - n
194+
assert dpt.all(r[sl2] == -1)
195+
196+
r = dpt.diff(arr, axis=axis, prepend=dpt.asarray(zero), append=zero)
197+
assert isinstance(r, dpt.usm_ndarray)
198+
assert all(r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis)
199+
assert r.shape[axis] == arr.shape[axis] + 2 - n
200+
assert dpt.all(r[sl1] == 1)
201+
assert dpt.all(r[sl2] == -1)
202+
203+
r = dpt.diff(arr, axis=axis, prepend=zero, append=dpt.asarray(zero))
204+
assert isinstance(r, dpt.usm_ndarray)
205+
assert all(r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis)
206+
assert r.shape[axis] == arr.shape[axis] + 2 - n
207+
assert dpt.all(r[sl1] == 1)
208+
assert dpt.all(r[sl2] == -1)

0 commit comments

Comments
 (0)