Skip to content

Fix bad order=K code logic in tensor.asarray #1351

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions dpctl/tensor/_copy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import dpctl.tensor as dpt
import dpctl.tensor._tensor_impl as ti
import dpctl.utils
from dpctl.tensor._ctors import _get_dtype
from dpctl.tensor._data_types import _get_dtype
from dpctl.tensor._device import normalize_queue_device

__doc__ = (
Expand Down Expand Up @@ -354,11 +354,11 @@ def _empty_like_orderK(X, dt, usm_type=None, dev=None):
range(X.ndim), key=lambda d: builtins.abs(st[d]), reverse=True
)
inv_perm = sorted(range(X.ndim), key=lambda i: perm[i])
st_sorted = [st[i] for i in perm]
sh = X.shape
sh_sorted = tuple(sh[i] for i in perm)
R = dpt.empty(sh_sorted, dtype=dt, usm_type=usm_type, device=dev, order="C")
if min(st_sorted) < 0:
if min(st) < 0:
st_sorted = [st[i] for i in perm]
sl = tuple(
slice(None, None, -1)
if st_sorted[i] < 0
Expand Down
43 changes: 3 additions & 40 deletions dpctl/tensor/_ctors.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import dpctl.tensor as dpt
import dpctl.tensor._tensor_impl as ti
import dpctl.utils
from dpctl.tensor._copy_utils import _empty_like_orderK
from dpctl.tensor._data_types import _get_dtype
from dpctl.tensor._device import normalize_queue_device
from dpctl.tensor._usmarray import _is_object_with_buffer_protocol

Expand All @@ -32,24 +34,6 @@
_host_set = frozenset([None])


def _get_dtype(dtype, sycl_obj, ref_type=None):
if dtype is None:
if ref_type in [None, float] or np.issubdtype(ref_type, np.floating):
dtype = ti.default_device_fp_type(sycl_obj)
return dpt.dtype(dtype)
if ref_type in [bool, np.bool_]:
dtype = ti.default_device_bool_type(sycl_obj)
return dpt.dtype(dtype)
if ref_type is int or np.issubdtype(ref_type, np.integer):
dtype = ti.default_device_int_type(sycl_obj)
return dpt.dtype(dtype)
if ref_type is complex or np.issubdtype(ref_type, np.complexfloating):
dtype = ti.default_device_complex_type(sycl_obj)
return dpt.dtype(dtype)
raise TypeError(f"Reference type {ref_type} not recognized.")
return dpt.dtype(dtype)


def _array_info_dispatch(obj):
if isinstance(obj, dpt.usm_ndarray):
return obj.shape, obj.dtype, frozenset([obj.sycl_queue])
Expand Down Expand Up @@ -162,28 +146,7 @@ def _asarray_from_usm_ndarray(
order = "C" if c_contig else "F"
if order == "K":
_ensure_native_dtype_device_support(dtype, copy_q.sycl_device)
# new USM allocation
res = dpt.usm_ndarray(
usm_ndary.shape,
dtype=dtype,
buffer=usm_type,
order="C",
buffer_ctor_kwargs={"queue": copy_q},
)
original_strides = usm_ndary.strides
ind = sorted(
range(usm_ndary.ndim),
key=lambda i: abs(original_strides[i]),
reverse=True,
)
new_strides = tuple(res.strides[ind[i]] for i in ind)
# reuse previously made USM allocation
res = dpt.usm_ndarray(
usm_ndary.shape,
dtype=res.dtype,
buffer=res.usm_data,
strides=new_strides,
)
res = _empty_like_orderK(usm_ndary, dtype, usm_type, copy_q)
else:
_ensure_native_dtype_device_support(dtype, copy_q.sycl_device)
res = dpt.usm_ndarray(
Expand Down
44 changes: 44 additions & 0 deletions dpctl/tensor/_data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from numpy import bool_ as np_bool_
from numpy import complexfloating as np_complexfloating
from numpy import dtype
from numpy import floating as np_floating
from numpy import integer as np_integer
from numpy import issubdtype as np_issubdtype

from dpctl.tensor._tensor_impl import (
default_device_bool_type as ti_default_device_bool_type,
)
from dpctl.tensor._tensor_impl import (
default_device_complex_type as ti_default_device_complex_type,
)
from dpctl.tensor._tensor_impl import (
default_device_fp_type as ti_default_device_fp_type,
)
from dpctl.tensor._tensor_impl import (
default_device_int_type as ti_default_device_int_type,
)

bool = dtype("bool")
int8 = dtype("int8")
Expand Down Expand Up @@ -74,6 +92,32 @@ def isdtype(dtype_, kind):
raise TypeError(f"Unsupported data type kind: {kind}")


def _get_dtype(inp_dt, sycl_obj, ref_type=None):
"""
Type inference utility to construct data type
object with defaults based on reference type.

_get_dtype is used by dpctl.tensor.asarray
to infer data type of the output array from the
input sequence.
"""
if inp_dt is None:
if ref_type in [None, float] or np_issubdtype(ref_type, np_floating):
fp_dt = ti_default_device_fp_type(sycl_obj)
return dtype(fp_dt)
if ref_type in [bool, np_bool_]:
bool_dt = ti_default_device_bool_type(sycl_obj)
return dtype(bool_dt)
if ref_type is int or np_issubdtype(ref_type, np_integer):
int_dt = ti_default_device_int_type(sycl_obj)
return dtype(int_dt)
if ref_type is complex or np_issubdtype(ref_type, np_complexfloating):
cfp_dt = ti_default_device_complex_type(sycl_obj)
return dtype(cfp_dt)
raise TypeError(f"Reference type {ref_type} not recognized.")
return dtype(inp_dt)


__all__ = [
"dtype",
"isdtype",
Expand Down
12 changes: 12 additions & 0 deletions dpctl/tests/test_tensor_asarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,3 +383,15 @@ def test_ulonglong_gh_1167():
assert x.dtype == dpt.uint64
x = dpt.asarray(9223372036854775808, dtype="u8")
assert x.dtype == dpt.uint64


def test_orderK_gh_1350():
get_queue_or_skip()
a = dpt.empty((2, 3, 4), dtype="u1")
b = dpt.permute_dims(a, (2, 0, 1))
c = dpt.asarray(b, copy=True, order="K")

assert c.shape == b.shape
assert c.strides == b.strides
assert c._element_offset == 0
assert not c._pointer == b._pointer