Skip to content

divide and comparisons allow a greater range of Python integer and integer array combinations #1771

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions dpctl/tensor/_elementwise_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
_acceptance_fn_negative,
_acceptance_fn_reciprocal,
_acceptance_fn_subtract,
_resolve_weak_types_comparisons,
_resolve_weak_types_all_py_ints,
)

# U01: ==== ABS (x)
Expand Down Expand Up @@ -661,6 +661,7 @@
_divide_docstring_,
binary_inplace_fn=ti._divide_inplace,
acceptance_fn=_acceptance_fn_divide,
weak_type_resolver=_resolve_weak_types_all_py_ints,
)
del _divide_docstring_

Expand Down Expand Up @@ -695,7 +696,7 @@
ti._equal_result_type,
ti._equal,
_equal_docstring_,
weak_type_resolver=_resolve_weak_types_comparisons,
weak_type_resolver=_resolve_weak_types_all_py_ints,
)
del _equal_docstring_

Expand Down Expand Up @@ -854,7 +855,7 @@
ti._greater_result_type,
ti._greater,
_greater_docstring_,
weak_type_resolver=_resolve_weak_types_comparisons,
weak_type_resolver=_resolve_weak_types_all_py_ints,
)
del _greater_docstring_

Expand Down Expand Up @@ -890,7 +891,7 @@
ti._greater_equal_result_type,
ti._greater_equal,
_greater_equal_docstring_,
weak_type_resolver=_resolve_weak_types_comparisons,
weak_type_resolver=_resolve_weak_types_all_py_ints,
)
del _greater_equal_docstring_

Expand Down Expand Up @@ -1041,7 +1042,7 @@
ti._less_result_type,
ti._less,
_less_docstring_,
weak_type_resolver=_resolve_weak_types_comparisons,
weak_type_resolver=_resolve_weak_types_all_py_ints,
)
del _less_docstring_

Expand Down Expand Up @@ -1077,7 +1078,7 @@
ti._less_equal_result_type,
ti._less_equal,
_less_equal_docstring_,
weak_type_resolver=_resolve_weak_types_comparisons,
weak_type_resolver=_resolve_weak_types_all_py_ints,
)
del _less_equal_docstring_

Expand Down Expand Up @@ -1552,7 +1553,7 @@
ti._not_equal_result_type,
ti._not_equal,
_not_equal_docstring_,
weak_type_resolver=_resolve_weak_types_comparisons,
weak_type_resolver=_resolve_weak_types_all_py_ints,
)
del _not_equal_docstring_

Expand Down
35 changes: 20 additions & 15 deletions dpctl/tensor/_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,10 +393,11 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
return o1_dtype, o2_dtype


def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev):
"Resolves weak data type per NEP-0050 for comparisons,"
"where result type is known to be `bool` and special behavior"
"is needed to handle mixed integer kinds"
def _resolve_weak_types_all_py_ints(o1_dtype, o2_dtype, dev):
"Resolves weak data type per NEP-0050 for comparisons and"
" divide, where result type is known and special behavior"
"is needed to handle mixed integer kinds and Python integers"
"without overflow"
if _is_weak_dtype(o1_dtype):
if _is_weak_dtype(o2_dtype):
raise ValueError
Expand All @@ -414,11 +415,13 @@ def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev):
)
return _to_device_supported_dtype(dpt.float64, dev), o2_dtype
else:
if isinstance(o1_dtype, WeakIntegralType):
if o2_dtype.kind == "u":
# Python scalar may be negative, assumes mixed int loops
# exist
return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype
if o1_kind_num == o2_kind_num and isinstance(
o1_dtype, WeakIntegralType
):
o1_val = o1_dtype.get()
o2_iinfo = dpt.iinfo(o2_dtype)
if (o1_val < o2_iinfo.min) or (o1_val > o2_iinfo.max):
return dpt.dtype(np.min_scalar_type(o1_val)), o2_dtype
return o2_dtype, o2_dtype
elif _is_weak_dtype(o2_dtype):
o1_kind_num = _strong_dtype_num_kind(o1_dtype)
Expand All @@ -435,11 +438,13 @@ def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev):
_to_device_supported_dtype(dpt.float64, dev),
)
else:
if isinstance(o2_dtype, WeakIntegralType):
if o1_dtype.kind == "u":
# Python scalar may be negative, assumes mixed int loops
# exist
return o1_dtype, dpt.dtype(ti.default_device_int_type(dev))
if o1_kind_num == o2_kind_num and isinstance(
o2_dtype, WeakIntegralType
):
o2_val = o2_dtype.get()
o1_iinfo = dpt.iinfo(o1_dtype)
if (o2_val < o1_iinfo.min) or (o2_val > o1_iinfo.max):
return o1_dtype, dpt.dtype(np.min_scalar_type(o2_val))
return o1_dtype, o1_dtype
else:
return o1_dtype, o2_dtype
Expand Down Expand Up @@ -834,7 +839,7 @@ def _default_accumulation_dtype_fp_types(inp_dt, q):
"_acceptance_fn_negative",
"_acceptance_fn_subtract",
"_resolve_weak_types",
"_resolve_weak_types_comparisons",
"_resolve_weak_types_all_py_ints",
"_weak_type_num_kind",
"_strong_dtype_num_kind",
"can_cast",
Expand Down
15 changes: 15 additions & 0 deletions dpctl/tests/elementwise/test_divide.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,18 @@ def test_divide_inplace_dtype_matrix(op1_dtype, op2_dtype):
else:
with pytest.raises(ValueError):
dpt.divide(ar1, ar2, out=ar2)


def test_divide_gh_1711():
"See https://github.com/IntelPython/dpctl/issues/1711"
get_queue_or_skip()

res = dpt.divide(-4, dpt.asarray(1, dtype="u4"))
assert isinstance(res, dpt.usm_ndarray)
assert res.dtype.kind == "f"
assert dpt.allclose(res, -4 / dpt.asarray(1, dtype="i4"))

res = dpt.divide(dpt.asarray(3, dtype="u4"), -2)
assert isinstance(res, dpt.usm_ndarray)
assert res.dtype.kind == "f"
assert dpt.allclose(res, dpt.asarray(3, dtype="i4") / -2)
14 changes: 14 additions & 0 deletions dpctl/tests/elementwise/test_greater.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,17 @@ def test_greater_mixed_integer_kinds():
# Python scalar
assert dpt.all(dpt.greater(x2, -1))
assert not dpt.any(dpt.greater(-1, x2))


def test_greater_very_large_py_int():
get_queue_or_skip()

py_int = dpt.iinfo(dpt.int64).max + 10

x = dpt.asarray(3, dtype="u8")
assert py_int > x
assert not dpt.greater(x, py_int)

x = dpt.asarray(py_int, dtype="u8")
assert x > -1
assert not dpt.greater(-1, x)
14 changes: 14 additions & 0 deletions dpctl/tests/elementwise/test_greater_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,3 +280,17 @@ def test_greater_equal_mixed_integer_kinds():
# Python scalar
assert dpt.all(dpt.greater_equal(x2, -1))
assert not dpt.any(dpt.greater_equal(-1, x2))


def test_greater_equal_very_large_py_int():
get_queue_or_skip()

py_int = dpt.iinfo(dpt.int64).max + 10

x = dpt.asarray(3, dtype="u8")
assert py_int >= x
assert not dpt.greater_equal(x, py_int)

x = dpt.asarray(py_int, dtype="u8")
assert x >= -1
assert not dpt.greater_equal(-1, x)
14 changes: 14 additions & 0 deletions dpctl/tests/elementwise/test_less.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,17 @@ def test_less_mixed_integer_kinds():
# Python scalar
assert not dpt.any(dpt.less(x2, -1))
assert dpt.all(dpt.less(-1, x2))


def test_less_very_large_py_int():
get_queue_or_skip()

py_int = dpt.iinfo(dpt.int64).max + 10

x = dpt.asarray(3, dtype="u8")
assert not py_int < x
assert dpt.less(x, py_int)

x = dpt.asarray(py_int, dtype="u8")
assert not x < -1
assert dpt.less(-1, x)
14 changes: 14 additions & 0 deletions dpctl/tests/elementwise/test_less_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,3 +280,17 @@ def test_less_equal_mixed_integer_kinds():
# Python scalar
assert not dpt.any(dpt.less_equal(x2, -1))
assert dpt.all(dpt.less_equal(-1, x2))


def test_less_equal_very_large_py_int():
get_queue_or_skip()

py_int = dpt.iinfo(dpt.int64).max + 10

x = dpt.asarray(3, dtype="u8")
assert not py_int <= x
assert dpt.less_equal(x, py_int)

x = dpt.asarray(py_int, dtype="u8")
assert not x <= -1
assert dpt.less_equal(-1, x)
Loading