|
13 | 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 | 14 | # See the License for the specific language governing permissions and
|
15 | 15 | # limitations under the License.
|
| 16 | +import builtins |
16 | 17 | import operator
|
17 | 18 |
|
18 | 19 | import numpy as np
|
@@ -361,6 +362,96 @@ def copy(usm_ary, order="K"):
|
361 | 362 | return R
|
362 | 363 |
|
363 | 364 |
|
| 365 | +def _empty_like_orderK(X, dt, usm_type=None, dev=None): |
| 366 | + """Returns empty array like `x`, using order='K' |
| 367 | +
|
| 368 | + For an array `x` that was obtained by permutation of a contiguous |
| 369 | + array the returned array will have the same shape and the same |
| 370 | + strides as `x`. |
| 371 | + """ |
| 372 | + if not isinstance(X, dpt.usm_ndarray): |
| 373 | + raise TypeError(f"Expected usm_ndarray, got {type(X)}") |
| 374 | + if usm_type is None: |
| 375 | + usm_type = X.usm_type |
| 376 | + if dev is None: |
| 377 | + dev = X.device |
| 378 | + fl = X.flags |
| 379 | + if fl["C"] or X.size <= 1: |
| 380 | + return dpt.empty_like( |
| 381 | + X, dtype=dt, usm_type=usm_type, device=dev, order="C" |
| 382 | + ) |
| 383 | + elif fl["F"]: |
| 384 | + return dpt.empty_like( |
| 385 | + X, dtype=dt, usm_type=usm_type, device=dev, order="F" |
| 386 | + ) |
| 387 | + st = list(X.strides) |
| 388 | + perm = sorted( |
| 389 | + range(X.ndim), key=lambda d: builtins.abs(st[d]), reverse=True |
| 390 | + ) |
| 391 | + inv_perm = sorted(range(X.ndim), key=lambda i: perm[i]) |
| 392 | + st_sorted = [st[i] for i in perm] |
| 393 | + sh = X.shape |
| 394 | + sh_sorted = tuple(sh[i] for i in perm) |
| 395 | + R = dpt.empty(sh_sorted, dtype=dt, usm_type=usm_type, device=dev, order="C") |
| 396 | + if min(st_sorted) < 0: |
| 397 | + sl = tuple( |
| 398 | + slice(None, None, -1) |
| 399 | + if st_sorted[i] < 0 |
| 400 | + else slice(None, None, None) |
| 401 | + for i in range(X.ndim) |
| 402 | + ) |
| 403 | + R = R[sl] |
| 404 | + return dpt.permute_dims(R, inv_perm) |
| 405 | + |
| 406 | + |
| 407 | +def _empty_like_pair_orderK(X1, X2, dt, res_shape, usm_type, dev): |
| 408 | + if not isinstance(X1, dpt.usm_ndarray): |
| 409 | + raise TypeError(f"Expected usm_ndarray, got {type(X1)}") |
| 410 | + if not isinstance(X2, dpt.usm_ndarray): |
| 411 | + raise TypeError(f"Expected usm_ndarray, got {type(X2)}") |
| 412 | + nd1 = X1.ndim |
| 413 | + nd2 = X2.ndim |
| 414 | + if nd1 > nd2 and X1.shape == res_shape: |
| 415 | + return _empty_like_orderK(X1, dt, usm_type, dev) |
| 416 | + elif nd1 < nd2 and X2.shape == res_shape: |
| 417 | + return _empty_like_orderK(X2, dt, usm_type, dev) |
| 418 | + fl1 = X1.flags |
| 419 | + fl2 = X2.flags |
| 420 | + if fl1["C"] or fl2["C"]: |
| 421 | + return dpt.empty( |
| 422 | + res_shape, dtype=dt, usm_type=usm_type, device=dev, order="C" |
| 423 | + ) |
| 424 | + if fl1["F"] and fl2["F"]: |
| 425 | + return dpt.empty( |
| 426 | + res_shape, dtype=dt, usm_type=usm_type, device=dev, order="F" |
| 427 | + ) |
| 428 | + st1 = list(X1.strides) |
| 429 | + st2 = list(X2.strides) |
| 430 | + max_ndim = max(nd1, nd2) |
| 431 | + st1 += [0] * (max_ndim - len(st1)) |
| 432 | + st2 += [0] * (max_ndim - len(st2)) |
| 433 | + perm = sorted( |
| 434 | + range(max_ndim), |
| 435 | + key=lambda d: (builtins.abs(st1[d]), builtins.abs(st2[d])), |
| 436 | + reverse=True, |
| 437 | + ) |
| 438 | + inv_perm = sorted(range(max_ndim), key=lambda i: perm[i]) |
| 439 | + st1_sorted = [st1[i] for i in perm] |
| 440 | + st2_sorted = [st2[i] for i in perm] |
| 441 | + sh = res_shape |
| 442 | + sh_sorted = tuple(sh[i] for i in perm) |
| 443 | + R = dpt.empty(sh_sorted, dtype=dt, usm_type=usm_type, device=dev, order="C") |
| 444 | + if max(min(st1_sorted), min(st2_sorted)) < 0: |
| 445 | + sl = tuple( |
| 446 | + slice(None, None, -1) |
| 447 | + if (st1_sorted[i] < 0 and st2_sorted[i] < 0) |
| 448 | + else slice(None, None, None) |
| 449 | + for i in range(nd1) |
| 450 | + ) |
| 451 | + R = R[sl] |
| 452 | + return dpt.permute_dims(R, inv_perm) |
| 453 | + |
| 454 | + |
364 | 455 | def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
|
365 | 456 | """ astype(array, new_dtype, order="K", casting="unsafe", \
|
366 | 457 | copy=True)
|
@@ -432,26 +523,15 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
|
432 | 523 | "Unrecognized value of the order keyword. "
|
433 | 524 | "Recognized values are 'A', 'C', 'F', or 'K'"
|
434 | 525 | )
|
435 |
| - R = dpt.usm_ndarray( |
436 |
| - usm_ary.shape, |
437 |
| - dtype=target_dtype, |
438 |
| - buffer=usm_ary.usm_type, |
439 |
| - order=copy_order, |
440 |
| - buffer_ctor_kwargs={"queue": usm_ary.sycl_queue}, |
441 |
| - ) |
442 |
| - if order == "K" and (not c_contig and not f_contig): |
443 |
| - original_strides = usm_ary.strides |
444 |
| - ind = sorted( |
445 |
| - range(usm_ary.ndim), |
446 |
| - key=lambda i: abs(original_strides[i]), |
447 |
| - reverse=True, |
448 |
| - ) |
449 |
| - new_strides = tuple(R.strides[ind[i]] for i in ind) |
| 526 | + if order == "K": |
| 527 | + R = _empty_like_orderK(usm_ary, target_dtype) |
| 528 | + else: |
450 | 529 | R = dpt.usm_ndarray(
|
451 | 530 | usm_ary.shape,
|
452 | 531 | dtype=target_dtype,
|
453 |
| - buffer=R.usm_data, |
454 |
| - strides=new_strides, |
| 532 | + buffer=usm_ary.usm_type, |
| 533 | + order=copy_order, |
| 534 | + buffer_ctor_kwargs={"queue": usm_ary.sycl_queue}, |
455 | 535 | )
|
456 | 536 | _copy_from_usm_ndarray_to_usm_ndarray(R, usm_ary)
|
457 | 537 | return R
|
|
0 commit comments