Skip to content

Commit 440872d

Browse files
Merge remote-tracking branch 'origin/master' into enable-operators
2 parents 3ddf51c + a1dd350 commit 440872d

File tree

4 files changed

+52
-1
lines changed

4 files changed

+52
-1
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ if (WIN32)
5454
set(_clang_prefix "/clang:")
5555
endif()
5656
set_source_files_properties(
57+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
58+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
5759
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp
5860
PROPERTIES COMPILE_OPTIONS "${_clang_prefix}-fno-fast-math")
5961
target_compile_options(${python_module_name} PRIVATE -fno-sycl-id-queries-fit-in-int)

dpctl/tensor/_dlpack.pyx

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ from .._backend cimport (
3232
DPCTLSyclDeviceRef,
3333
DPCTLSyclUSMRef,
3434
)
35-
from ._usmarray cimport usm_ndarray
35+
from ._usmarray cimport USM_ARRAY_C_CONTIGUOUS, usm_ndarray
3636

3737
from platform import system as sys_platform
3838

@@ -158,9 +158,11 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
158158
cdef int64_t *shape_strides_ptr = NULL
159159
cdef int i = 0
160160
cdef int device_id = -1
161+
cdef int flags = 0
161162
cdef char *base_ptr = NULL
162163
cdef Py_ssize_t element_offset = 0
163164
cdef Py_ssize_t byte_offset = 0
165+
cdef Py_ssize_t si = 1
164166

165167
ary_base = usm_ary.get_base()
166168
ary_sycl_queue = usm_ary.get_sycl_queue()
@@ -223,9 +225,17 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
223225
for i in range(nd):
224226
shape_strides_ptr[i] = shape_ptr[i]
225227
strides_ptr = usm_ary.get_strides()
228+
flags = usm_ary.flags_
226229
if strides_ptr:
227230
for i in range(nd):
228231
shape_strides_ptr[nd + i] = strides_ptr[i]
232+
else:
233+
if not (flags & USM_ARRAY_C_CONTIGUOUS):
234+
si = 1
235+
for i in range(0, nd):
236+
shape_strides_ptr[nd + i] = si
237+
si = si * shape_ptr[i]
238+
strides_ptr = <Py_ssize_t *>&shape_strides_ptr[nd]
229239

230240
ary_dt = usm_ary.dtype
231241
ary_dtk = ary_dt.kind

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1442,6 +1442,30 @@ def test_full_dtype_inference():
14421442
assert np.issubdtype(dpt.full(10, 0.3 - 2j, dtype=rdt).dtype, np.floating)
14431443

14441444

1445+
@pytest.mark.parametrize("dt", ["f2", "f4", "f8"])
1446+
def test_full_special_fp(dt):
1447+
"""See gh-1314"""
1448+
q = get_queue_or_skip()
1449+
skip_if_dtype_not_supported(dt, q)
1450+
1451+
ar = dpt.full(10, fill_value=dpt.nan)
1452+
err_msg = f"Failed for fill_value=dpt.nan and dtype {dt}"
1453+
assert dpt.isnan(ar[0]), err_msg
1454+
1455+
ar = dpt.full(10, fill_value=dpt.inf)
1456+
err_msg = f"Failed for fill_value=dpt.inf and dtype {dt}"
1457+
assert dpt.isinf(ar[0]) and dpt.greater(ar[0], 0), err_msg
1458+
1459+
ar = dpt.full(10, fill_value=-dpt.inf)
1460+
err_msg = f"Failed for fill_value=-dpt.inf and dtype {dt}"
1461+
assert dpt.isinf(ar[0]) and dpt.less(ar[0], 0), err_msg
1462+
1463+
ar = dpt.full(10, fill_value=dpt.pi)
1464+
err_msg = f"Failed for fill_value=dpt.pi and dtype {dt}"
1465+
check = abs(float(ar[0]) - dpt.pi) < 16 * dpt.finfo(ar.dtype).eps
1466+
assert check, err_msg
1467+
1468+
14451469
def test_full_fill_array():
14461470
q = get_queue_or_skip()
14471471

dpctl/tests/test_usm_ndarray_dlpack.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,18 @@ def __dlpack__(self):
178178

179179
with pytest.raises(TypeError):
180180
dpt.from_dlpack(DummyWithMethod())
181+
182+
183+
def test_from_dlpack_fortran_contig_array_roundtripping():
184+
"""Based on examples from issue gh-1241"""
185+
n0, n1 = 3, 5
186+
try:
187+
ar1d = dpt.arange(n0 * n1, dtype="i4")
188+
except dpctl.SyclDeviceCreationError:
189+
pytest.skip("No default device available")
190+
ar2d_c = dpt.reshape(ar1d, (n0, n1), order="C")
191+
ar2d_f = dpt.asarray(ar2d_c, order="F")
192+
ar2d_r = dpt.from_dlpack(ar2d_f)
193+
194+
assert dpt.all(dpt.equal(ar2d_f, ar2d_r))
195+
assert dpt.all(dpt.equal(ar2d_c, ar2d_r))

0 commit comments

Comments
 (0)