Skip to content

Commit ea67b5a

Browse files
committed
diff input validation tests
1 parent 3496df2 commit ea67b5a

File tree

1 file changed

+153
-0
lines changed

1 file changed

+153
-0
lines changed

dpctl/tests/test_tensor_diff.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,15 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
from math import prod
18+
1719
import pytest
20+
from numpy.testing import assert_raises_regex
1821

1922
import dpctl.tensor as dpt
2023
from dpctl.tensor._type_utils import _to_device_supported_dtype
2124
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
25+
from dpctl.utils import ExecutionPlacementError
2226

2327
_all_dtypes = [
2428
"?",
@@ -206,3 +210,152 @@ def test_diff_prepend_append_py_scalars(sh, axis):
206210
assert r.shape[axis] == arr.shape[axis] + 2 - n
207211
assert dpt.all(r[sl1] == 1)
208212
assert dpt.all(r[sl2] == -1)
213+
214+
215+
def test_tensor_diff_append_prepend_arrays():
216+
get_queue_or_skip()
217+
218+
n = 1
219+
axis = 0
220+
221+
sz = 5
222+
arr = dpt.arange(sz, 2 * sz, dtype="i4")
223+
prepend = dpt.arange(sz, dtype="i4")
224+
append = dpt.arange(2 * sz, 3 * sz, dtype="i4")
225+
const_diff = 1
226+
227+
r = dpt.diff(arr, axis=axis, prepend=prepend, append=append)
228+
assert all(r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis)
229+
assert (
230+
r.shape[axis]
231+
== arr.shape[axis] + prepend.shape[axis] + append.shape[axis] - n
232+
)
233+
assert dpt.all(r == const_diff)
234+
235+
r = dpt.diff(arr, axis=axis, prepend=prepend)
236+
assert all(r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis)
237+
assert r.shape[axis] == arr.shape[axis] + prepend.shape[axis] - n
238+
assert dpt.all(r == const_diff)
239+
240+
r = dpt.diff(arr, axis=axis, append=append)
241+
assert all(r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis)
242+
assert r.shape[axis] == arr.shape[axis] + append.shape[axis] - n
243+
assert dpt.all(r == const_diff)
244+
245+
sh = (3, 4, 5)
246+
sz = prod(sh)
247+
arr = dpt.reshape(dpt.arange(sz, 2 * sz, dtype="i4"), sh)
248+
prepend = dpt.reshape(dpt.arange(sz, dtype="i4"), sh)
249+
append = dpt.reshape(dpt.arange(2 * sz, 3 * sz, dtype="i4"), sh)
250+
const_diff = prod(sh[axis + 1 :])
251+
252+
r = dpt.diff(arr, axis=axis, prepend=prepend, append=append)
253+
assert all(r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis)
254+
assert (
255+
r.shape[axis]
256+
== arr.shape[axis] + prepend.shape[axis] + append.shape[axis] - n
257+
)
258+
assert dpt.all(r == const_diff)
259+
260+
r = dpt.diff(arr, axis=axis, prepend=prepend)
261+
assert all(r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis)
262+
assert r.shape[axis] == arr.shape[axis] + prepend.shape[axis] - n
263+
assert dpt.all(r == const_diff)
264+
265+
r = dpt.diff(arr, axis=axis, append=append)
266+
assert all(r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis)
267+
assert r.shape[axis] == arr.shape[axis] + append.shape[axis] - n
268+
assert dpt.all(r == const_diff)
269+
270+
271+
def test_diff_wrong_append_prepend_shape():
272+
get_queue_or_skip()
273+
274+
arr = dpt.ones((3, 4, 5), dtype="i4")
275+
arr_bad_sh = dpt.ones(2, dtype="i4")
276+
277+
assert_raises_regex(
278+
ValueError,
279+
"`diff` argument `prepend` with shape.*is invalid"
280+
" for first input with shape.*",
281+
dpt.diff,
282+
arr,
283+
prepend=arr_bad_sh,
284+
append=arr_bad_sh,
285+
)
286+
287+
assert_raises_regex(
288+
ValueError,
289+
"`diff` argument `append` with shape.*is invalid"
290+
" for first input with shape.*",
291+
dpt.diff,
292+
arr,
293+
prepend=arr,
294+
append=arr_bad_sh,
295+
)
296+
297+
assert_raises_regex(
298+
ValueError,
299+
"`diff` argument `prepend` with shape.*is invalid"
300+
" for first input with shape.*",
301+
dpt.diff,
302+
arr,
303+
prepend=arr_bad_sh,
304+
)
305+
306+
assert_raises_regex(
307+
ValueError,
308+
"`diff` argument `append` with shape.*is invalid"
309+
" for first input with shape.*",
310+
dpt.diff,
311+
arr,
312+
append=arr_bad_sh,
313+
)
314+
315+
316+
def test_diff_compute_follows_data():
317+
q1 = get_queue_or_skip()
318+
q2 = get_queue_or_skip()
319+
q3 = get_queue_or_skip()
320+
321+
ar1 = dpt.ones(1, dtype="i4", sycl_queue=q1)
322+
ar2 = dpt.ones(1, dtype="i4", sycl_queue=q2)
323+
ar3 = dpt.ones(1, dtype="i4", sycl_queue=q3)
324+
325+
assert_raises_regex(
326+
ExecutionPlacementError,
327+
"Execution placement can not be unambiguously inferred from input "
328+
"arguments",
329+
dpt.diff,
330+
ar1,
331+
prepend=ar2,
332+
append=ar3,
333+
)
334+
335+
assert_raises_regex(
336+
ExecutionPlacementError,
337+
"Execution placement can not be unambiguously inferred from input "
338+
"arguments",
339+
dpt.diff,
340+
ar1,
341+
prepend=ar2,
342+
)
343+
344+
assert_raises_regex(
345+
ExecutionPlacementError,
346+
"Execution placement can not be unambiguously inferred from input "
347+
"arguments",
348+
dpt.diff,
349+
ar1,
350+
append=ar2,
351+
)
352+
353+
354+
def test_diff_input_validation():
355+
bad_in = dict()
356+
assert_raises_regex(
357+
TypeError,
358+
"Expecting dpctl.tensor.usm_ndarray type, got.*",
359+
dpt.diff,
360+
bad_in,
361+
)

0 commit comments

Comments
 (0)