Skip to content

Commit ec594cb

Browse files
committed
Replaces _resolve_weak_types_comparisons with _resolve_weak_types_all_py_ints
This new weak type resolver checks if the scalar is outside of the range of the strong data type and if so, returns the minimum scalar type for the value.
1 parent 7b64374 commit ec594cb

File tree

2 files changed

+22
-20
lines changed

2 files changed

+22
-20
lines changed

dpctl/tensor/_elementwise_funcs.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
_acceptance_fn_negative,
2323
_acceptance_fn_reciprocal,
2424
_acceptance_fn_subtract,
25-
_resolve_weak_types_comparisons,
25+
_resolve_weak_types_all_py_ints,
2626
)
2727

2828
# U01: ==== ABS (x)
@@ -661,6 +661,7 @@
661661
_divide_docstring_,
662662
binary_inplace_fn=ti._divide_inplace,
663663
acceptance_fn=_acceptance_fn_divide,
664+
weak_type_resolver=_resolve_weak_types_all_py_ints,
664665
)
665666
del _divide_docstring_
666667

@@ -695,7 +696,7 @@
695696
ti._equal_result_type,
696697
ti._equal,
697698
_equal_docstring_,
698-
weak_type_resolver=_resolve_weak_types_comparisons,
699+
weak_type_resolver=_resolve_weak_types_all_py_ints,
699700
)
700701
del _equal_docstring_
701702

@@ -854,7 +855,7 @@
854855
ti._greater_result_type,
855856
ti._greater,
856857
_greater_docstring_,
857-
weak_type_resolver=_resolve_weak_types_comparisons,
858+
weak_type_resolver=_resolve_weak_types_all_py_ints,
858859
)
859860
del _greater_docstring_
860861

@@ -890,7 +891,7 @@
890891
ti._greater_equal_result_type,
891892
ti._greater_equal,
892893
_greater_equal_docstring_,
893-
weak_type_resolver=_resolve_weak_types_comparisons,
894+
weak_type_resolver=_resolve_weak_types_all_py_ints,
894895
)
895896
del _greater_equal_docstring_
896897

@@ -1041,7 +1042,7 @@
10411042
ti._less_result_type,
10421043
ti._less,
10431044
_less_docstring_,
1044-
weak_type_resolver=_resolve_weak_types_comparisons,
1045+
weak_type_resolver=_resolve_weak_types_all_py_ints,
10451046
)
10461047
del _less_docstring_
10471048

@@ -1077,7 +1078,7 @@
10771078
ti._less_equal_result_type,
10781079
ti._less_equal,
10791080
_less_equal_docstring_,
1080-
weak_type_resolver=_resolve_weak_types_comparisons,
1081+
weak_type_resolver=_resolve_weak_types_all_py_ints,
10811082
)
10821083
del _less_equal_docstring_
10831084

@@ -1552,7 +1553,7 @@
15521553
ti._not_equal_result_type,
15531554
ti._not_equal,
15541555
_not_equal_docstring_,
1555-
weak_type_resolver=_resolve_weak_types_comparisons,
1556+
weak_type_resolver=_resolve_weak_types_all_py_ints,
15561557
)
15571558
del _not_equal_docstring_
15581559

dpctl/tensor/_type_utils.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -393,10 +393,11 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
393393
return o1_dtype, o2_dtype
394394

395395

396-
def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev):
397-
"Resolves weak data type per NEP-0050 for comparisons,"
398-
"where result type is known to be `bool` and special behavior"
399-
"is needed to handle mixed integer kinds"
396+
def _resolve_weak_types_all_py_ints(o1_dtype, o2_dtype, dev):
397+
"Resolves weak data type per NEP-0050 for comparisons and"
398+
" divide, where result type is known and special behavior"
399+
"is needed to handle mixed integer kinds and Python integers"
400+
"without overflow"
400401
if _is_weak_dtype(o1_dtype):
401402
if _is_weak_dtype(o2_dtype):
402403
raise ValueError
@@ -415,10 +416,10 @@ def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev):
415416
return _to_device_supported_dtype(dpt.float64, dev), o2_dtype
416417
else:
417418
if isinstance(o1_dtype, WeakIntegralType):
418-
if o2_dtype.kind == "u":
419-
# Python scalar may be negative, assumes mixed int loops
420-
# exist
421-
return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype
419+
o1_val = o1_dtype.get()
420+
o2_iinfo = dpt.iinfo(o2_dtype)
421+
if (o1_val < o2_iinfo.min) or (o1_val > o2_iinfo.max):
422+
return dpt.dtype(np.min_scalar_type(o1_val)), o2_dtype
422423
return o2_dtype, o2_dtype
423424
elif _is_weak_dtype(o2_dtype):
424425
o1_kind_num = _strong_dtype_num_kind(o1_dtype)
@@ -436,10 +437,10 @@ def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev):
436437
)
437438
else:
438439
if isinstance(o2_dtype, WeakIntegralType):
439-
if o1_dtype.kind == "u":
440-
# Python scalar may be negative, assumes mixed int loops
441-
# exist
442-
return o1_dtype, dpt.dtype(ti.default_device_int_type(dev))
440+
o2_val = o2_dtype.get()
441+
o1_iinfo = dpt.iinfo(o1_dtype)
442+
if (o2_val < o1_iinfo.min) or (o2_val > o1_iinfo.max):
443+
return o1_dtype, dpt.dtype(np.min_scalar_type(o2_val))
443444
return o1_dtype, o1_dtype
444445
else:
445446
return o1_dtype, o2_dtype
@@ -834,7 +835,7 @@ def _default_accumulation_dtype_fp_types(inp_dt, q):
834835
"_acceptance_fn_negative",
835836
"_acceptance_fn_subtract",
836837
"_resolve_weak_types",
837-
"_resolve_weak_types_comparisons",
838+
"_resolve_weak_types_all_py_ints",
838839
"_weak_type_num_kind",
839840
"_strong_dtype_num_kind",
840841
"can_cast",

0 commit comments

Comments
 (0)