|
14 | 14 | # See the License for the specific language governing permissions and
|
15 | 15 | # limitations under the License.
|
16 | 16 |
|
| 17 | +from math import prod |
| 18 | + |
17 | 19 | import pytest
|
| 20 | +from numpy.testing import assert_raises_regex |
18 | 21 |
|
19 | 22 | import dpctl.tensor as dpt
|
20 | 23 | from dpctl.tensor._type_utils import _to_device_supported_dtype
|
21 | 24 | from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
|
| 25 | +from dpctl.utils import ExecutionPlacementError |
22 | 26 |
|
23 | 27 | _all_dtypes = [
|
24 | 28 | "?",
|
@@ -206,3 +210,152 @@ def test_diff_prepend_append_py_scalars(sh, axis):
|
206 | 210 | assert r.shape[axis] == arr.shape[axis] + 2 - n
|
207 | 211 | assert dpt.all(r[sl1] == 1)
|
208 | 212 | assert dpt.all(r[sl2] == -1)
|
| 213 | + |
| 214 | + |
| 215 | +def test_tensor_diff_append_prepend_arrays(): |
| 216 | + get_queue_or_skip() |
| 217 | + |
| 218 | + n = 1 |
| 219 | + axis = 0 |
| 220 | + |
| 221 | + sz = 5 |
| 222 | + arr = dpt.arange(sz, 2 * sz, dtype="i4") |
| 223 | + prepend = dpt.arange(sz, dtype="i4") |
| 224 | + append = dpt.arange(2 * sz, 3 * sz, dtype="i4") |
| 225 | + const_diff = 1 |
| 226 | + |
| 227 | + r = dpt.diff(arr, axis=axis, prepend=prepend, append=append) |
| 228 | + assert all(r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis) |
| 229 | + assert ( |
| 230 | + r.shape[axis] |
| 231 | + == arr.shape[axis] + prepend.shape[axis] + append.shape[axis] - n |
| 232 | + ) |
| 233 | + assert dpt.all(r == const_diff) |
| 234 | + |
| 235 | + r = dpt.diff(arr, axis=axis, prepend=prepend) |
| 236 | + assert all(r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis) |
| 237 | + assert r.shape[axis] == arr.shape[axis] + prepend.shape[axis] - n |
| 238 | + assert dpt.all(r == const_diff) |
| 239 | + |
| 240 | + r = dpt.diff(arr, axis=axis, append=append) |
| 241 | + assert all(r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis) |
| 242 | + assert r.shape[axis] == arr.shape[axis] + append.shape[axis] - n |
| 243 | + assert dpt.all(r == const_diff) |
| 244 | + |
| 245 | + sh = (3, 4, 5) |
| 246 | + sz = prod(sh) |
| 247 | + arr = dpt.reshape(dpt.arange(sz, 2 * sz, dtype="i4"), sh) |
| 248 | + prepend = dpt.reshape(dpt.arange(sz, dtype="i4"), sh) |
| 249 | + append = dpt.reshape(dpt.arange(2 * sz, 3 * sz, dtype="i4"), sh) |
| 250 | + const_diff = prod(sh[axis + 1 :]) |
| 251 | + |
| 252 | + r = dpt.diff(arr, axis=axis, prepend=prepend, append=append) |
| 253 | + assert all(r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis) |
| 254 | + assert ( |
| 255 | + r.shape[axis] |
| 256 | + == arr.shape[axis] + prepend.shape[axis] + append.shape[axis] - n |
| 257 | + ) |
| 258 | + assert dpt.all(r == const_diff) |
| 259 | + |
| 260 | + r = dpt.diff(arr, axis=axis, prepend=prepend) |
| 261 | + assert all(r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis) |
| 262 | + assert r.shape[axis] == arr.shape[axis] + prepend.shape[axis] - n |
| 263 | + assert dpt.all(r == const_diff) |
| 264 | + |
| 265 | + r = dpt.diff(arr, axis=axis, append=append) |
| 266 | + assert all(r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis) |
| 267 | + assert r.shape[axis] == arr.shape[axis] + append.shape[axis] - n |
| 268 | + assert dpt.all(r == const_diff) |
| 269 | + |
| 270 | + |
| 271 | +def test_diff_wrong_append_prepend_shape(): |
| 272 | + get_queue_or_skip() |
| 273 | + |
| 274 | + arr = dpt.ones((3, 4, 5), dtype="i4") |
| 275 | + arr_bad_sh = dpt.ones(2, dtype="i4") |
| 276 | + |
| 277 | + assert_raises_regex( |
| 278 | + ValueError, |
| 279 | + "`diff` argument `prepend` with shape.*is invalid" |
| 280 | + " for first input with shape.*", |
| 281 | + dpt.diff, |
| 282 | + arr, |
| 283 | + prepend=arr_bad_sh, |
| 284 | + append=arr_bad_sh, |
| 285 | + ) |
| 286 | + |
| 287 | + assert_raises_regex( |
| 288 | + ValueError, |
| 289 | + "`diff` argument `append` with shape.*is invalid" |
| 290 | + " for first input with shape.*", |
| 291 | + dpt.diff, |
| 292 | + arr, |
| 293 | + prepend=arr, |
| 294 | + append=arr_bad_sh, |
| 295 | + ) |
| 296 | + |
| 297 | + assert_raises_regex( |
| 298 | + ValueError, |
| 299 | + "`diff` argument `prepend` with shape.*is invalid" |
| 300 | + " for first input with shape.*", |
| 301 | + dpt.diff, |
| 302 | + arr, |
| 303 | + prepend=arr_bad_sh, |
| 304 | + ) |
| 305 | + |
| 306 | + assert_raises_regex( |
| 307 | + ValueError, |
| 308 | + "`diff` argument `append` with shape.*is invalid" |
| 309 | + " for first input with shape.*", |
| 310 | + dpt.diff, |
| 311 | + arr, |
| 312 | + append=arr_bad_sh, |
| 313 | + ) |
| 314 | + |
| 315 | + |
| 316 | +def test_diff_compute_follows_data(): |
| 317 | + q1 = get_queue_or_skip() |
| 318 | + q2 = get_queue_or_skip() |
| 319 | + q3 = get_queue_or_skip() |
| 320 | + |
| 321 | + ar1 = dpt.ones(1, dtype="i4", sycl_queue=q1) |
| 322 | + ar2 = dpt.ones(1, dtype="i4", sycl_queue=q2) |
| 323 | + ar3 = dpt.ones(1, dtype="i4", sycl_queue=q3) |
| 324 | + |
| 325 | + assert_raises_regex( |
| 326 | + ExecutionPlacementError, |
| 327 | + "Execution placement can not be unambiguously inferred from input " |
| 328 | + "arguments", |
| 329 | + dpt.diff, |
| 330 | + ar1, |
| 331 | + prepend=ar2, |
| 332 | + append=ar3, |
| 333 | + ) |
| 334 | + |
| 335 | + assert_raises_regex( |
| 336 | + ExecutionPlacementError, |
| 337 | + "Execution placement can not be unambiguously inferred from input " |
| 338 | + "arguments", |
| 339 | + dpt.diff, |
| 340 | + ar1, |
| 341 | + prepend=ar2, |
| 342 | + ) |
| 343 | + |
| 344 | + assert_raises_regex( |
| 345 | + ExecutionPlacementError, |
| 346 | + "Execution placement can not be unambiguously inferred from input " |
| 347 | + "arguments", |
| 348 | + dpt.diff, |
| 349 | + ar1, |
| 350 | + append=ar2, |
| 351 | + ) |
| 352 | + |
| 353 | + |
| 354 | +def test_diff_input_validation(): |
| 355 | + bad_in = dict() |
| 356 | + assert_raises_regex( |
| 357 | + TypeError, |
| 358 | + "Expecting dpctl.tensor.usm_ndarray type, got.*", |
| 359 | + dpt.diff, |
| 360 | + bad_in, |
| 361 | + ) |
0 commit comments