Skip to content

Commit 80eae6e

Browse files
committed
Corrected order="K" support in copy
1 parent a3c00bc commit 80eae6e

File tree

1 file changed

+59
-72
lines changed

1 file changed

+59
-72
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 59 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -290,78 +290,6 @@ def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
290290
_copy_same_shape(dst, src_same_shape)
291291

292292

293-
def copy(usm_ary, order="K"):
294-
"""copy(ary, order="K")
295-
296-
Creates a copy of given instance of :class:`dpctl.tensor.usm_ndarray`.
297-
298-
Args:
299-
ary (usm_ndarray):
300-
Input array.
301-
order ({"C", "F", "A", "K"}, optional):
302-
Controls the memory layout of the output array.
303-
Returns:
304-
usm_ndarray:
305-
A copy of the input array.
306-
307-
Memory layout of the copy is controlled by `order` keyword,
308-
following NumPy's conventions. The `order` keywords can be
309-
one of the following:
310-
311-
- "C": C-contiguous memory layout
312-
- "F": Fortran-contiguous memory layout
313-
- "A": Fortran-contiguous if the input array is also Fortran-contiguous,
314-
otherwise C-contiguous
315-
- "K": match the layout of `usm_ary` as closely as possible.
316-
317-
"""
318-
if not isinstance(usm_ary, dpt.usm_ndarray):
319-
return TypeError(
320-
f"Expected object of type dpt.usm_ndarray, got {type(usm_ary)}"
321-
)
322-
copy_order = "C"
323-
if order == "C":
324-
pass
325-
elif order == "F":
326-
copy_order = order
327-
elif order == "A":
328-
if usm_ary.flags.f_contiguous:
329-
copy_order = "F"
330-
elif order == "K":
331-
if usm_ary.flags.f_contiguous:
332-
copy_order = "F"
333-
else:
334-
raise ValueError(
335-
"Unrecognized value of the order keyword. "
336-
"Recognized values are 'A', 'C', 'F', or 'K'"
337-
)
338-
c_contig = usm_ary.flags.c_contiguous
339-
f_contig = usm_ary.flags.f_contiguous
340-
R = dpt.usm_ndarray(
341-
usm_ary.shape,
342-
dtype=usm_ary.dtype,
343-
buffer=usm_ary.usm_type,
344-
order=copy_order,
345-
buffer_ctor_kwargs={"queue": usm_ary.sycl_queue},
346-
)
347-
if order == "K" and (not c_contig and not f_contig):
348-
original_strides = usm_ary.strides
349-
ind = sorted(
350-
range(usm_ary.ndim),
351-
key=lambda i: abs(original_strides[i]),
352-
reverse=True,
353-
)
354-
new_strides = tuple(R.strides[ind[i]] for i in ind)
355-
R = dpt.usm_ndarray(
356-
usm_ary.shape,
357-
dtype=usm_ary.dtype,
358-
buffer=R.usm_data,
359-
strides=new_strides,
360-
)
361-
_copy_same_shape(R, usm_ary)
362-
return R
363-
364-
365293
def _empty_like_orderK(X, dt, usm_type=None, dev=None):
366294
"""Returns empty array like `x`, using order='K'
367295
@@ -452,6 +380,65 @@ def _empty_like_pair_orderK(X1, X2, dt, res_shape, usm_type, dev):
452380
return dpt.permute_dims(R, inv_perm)
453381

454382

383+
def copy(usm_ary, order="K"):
384+
"""copy(ary, order="K")
385+
386+
Creates a copy of given instance of :class:`dpctl.tensor.usm_ndarray`.
387+
388+
Args:
389+
ary (usm_ndarray):
390+
Input array.
391+
order ({"C", "F", "A", "K"}, optional):
392+
Controls the memory layout of the output array.
393+
Returns:
394+
usm_ndarray:
395+
A copy of the input array.
396+
397+
Memory layout of the copy is controlled by `order` keyword,
398+
following NumPy's conventions. The `order` keywords can be
399+
one of the following:
400+
401+
- "C": C-contiguous memory layout
402+
- "F": Fortran-contiguous memory layout
403+
- "A": Fortran-contiguous if the input array is also Fortran-contiguous,
404+
otherwise C-contiguous
405+
- "K": match the layout of `usm_ary` as closely as possible.
406+
407+
"""
408+
if not isinstance(usm_ary, dpt.usm_ndarray):
409+
return TypeError(
410+
f"Expected object of type dpt.usm_ndarray, got {type(usm_ary)}"
411+
)
412+
copy_order = "C"
413+
if order == "C":
414+
pass
415+
elif order == "F":
416+
copy_order = order
417+
elif order == "A":
418+
if usm_ary.flags.f_contiguous:
419+
copy_order = "F"
420+
elif order == "K":
421+
if usm_ary.flags.f_contiguous:
422+
copy_order = "F"
423+
else:
424+
raise ValueError(
425+
"Unrecognized value of the order keyword. "
426+
"Recognized values are 'A', 'C', 'F', or 'K'"
427+
)
428+
if order == "K":
429+
R = _empty_like_orderK(usm_ary, usm_ary.dtype)
430+
else:
431+
R = dpt.usm_ndarray(
432+
usm_ary.shape,
433+
dtype=usm_ary.dtype,
434+
buffer=usm_ary.usm_type,
435+
order=copy_order,
436+
buffer_ctor_kwargs={"queue": usm_ary.sycl_queue},
437+
)
438+
_copy_same_shape(R, usm_ary)
439+
return R
440+
441+
455442
def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
456443
""" astype(array, new_dtype, order="K", casting="unsafe", \
457444
copy=True)

0 commit comments

Comments
 (0)