Skip to content

Fix gh 1241 #1319

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion dpctl/tensor/_dlpack.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ from .._backend cimport (
DPCTLSyclDeviceRef,
DPCTLSyclUSMRef,
)
from ._usmarray cimport usm_ndarray
from ._usmarray cimport USM_ARRAY_C_CONTIGUOUS, usm_ndarray

from platform import system as sys_platform

Expand Down Expand Up @@ -158,9 +158,11 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
cdef int64_t *shape_strides_ptr = NULL
cdef int i = 0
cdef int device_id = -1
cdef int flags = 0
cdef char *base_ptr = NULL
cdef Py_ssize_t element_offset = 0
cdef Py_ssize_t byte_offset = 0
cdef Py_ssize_t si = 1

ary_base = usm_ary.get_base()
ary_sycl_queue = usm_ary.get_sycl_queue()
Expand Down Expand Up @@ -223,9 +225,17 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
for i in range(nd):
shape_strides_ptr[i] = shape_ptr[i]
strides_ptr = usm_ary.get_strides()
flags = usm_ary.flags_
if strides_ptr:
for i in range(nd):
shape_strides_ptr[nd + i] = strides_ptr[i]
else:
if not (flags & USM_ARRAY_C_CONTIGUOUS):
si = 1
for i in range(0, nd):
shape_strides_ptr[nd + i] = si
si = si * shape_ptr[i]
strides_ptr = <Py_ssize_t *>&shape_strides_ptr[nd]

ary_dt = usm_ary.dtype
ary_dtk = ary_dt.kind
Expand Down
15 changes: 15 additions & 0 deletions dpctl/tests/test_usm_ndarray_dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,18 @@ def __dlpack__(self):

with pytest.raises(TypeError):
dpt.from_dlpack(DummyWithMethod())


def test_from_dlpack_fortran_contig_array_roundtripping():
"""Based on examples from issue gh-1241"""
n0, n1 = 3, 5
try:
ar1d = dpt.arange(n0 * n1, dtype="i4")
except dpctl.SyclDeviceCreationError:
pytest.skip("No default device available")
ar2d_c = dpt.reshape(ar1d, (n0, n1), order="C")
ar2d_f = dpt.asarray(ar2d_c, order="F")
ar2d_r = dpt.from_dlpack(ar2d_f)

assert dpt.all(dpt.equal(ar2d_f, ar2d_r))
assert dpt.all(dpt.equal(ar2d_c, ar2d_r))