Skip to content

Commit aaf444e

Browse files
authored
Fixes element-wise comparisons of mixed signed-unsigned integer inputs (#1650)
* Align `not_equal` type map to other comparisons Overloads for combinations of complex and real valued floats are unnecessary, as floats can be safely cast to complex * Comparisons with signed and unsigned integers behave correctly Comparions between signed and unsigned integer data previously did not work correctly in some cases, as signed integers could be promoted to uint64 if one input was uint64 Additionally, `-1 < x` for some `x` with unsigned integer data type would always fail, as the -1 would initialize an array of `x.dtype` which would always underflow, leading to undefined behavior These problems were addressed by adding signed and unsigned 64-bit integer combinations to the type maps for the comparisons, and adding constexpr branches to the comparisons between mixed-kind integers * Apply suggested docstring for get_array_dtype_scalar_type_resolver_function * Refactor `isinstance` checks for any weak types into `_is_weak_dtype` utility function Per suggestion by @oleksandr-pavlyk * Added tests for comparing unsigned integer arrays to negative integer arrays and Python scalars
1 parent 379b939 commit aaf444e

File tree

15 files changed

+340
-76
lines changed

15 files changed

+340
-76
lines changed

dpctl/tensor/_clip.py

Lines changed: 7 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,9 @@
3434
from dpctl.utils import ExecutionPlacementError
3535

3636
from ._type_utils import (
37-
WeakBooleanType,
3837
WeakComplexType,
39-
WeakFloatingType,
4038
WeakIntegralType,
39+
_is_weak_dtype,
4140
_strong_dtype_num_kind,
4241
_weak_type_num_kind,
4342
)
@@ -47,29 +46,10 @@ def _resolve_one_strong_two_weak_types(st_dtype, dtype1, dtype2, dev):
4746
"Resolves weak data types per NEP-0050,"
4847
"where the second and third arguments are"
4948
"permitted to be weak types"
50-
if isinstance(
51-
st_dtype,
52-
(
53-
WeakBooleanType,
54-
WeakIntegralType,
55-
WeakFloatingType,
56-
WeakComplexType,
57-
),
58-
):
49+
if _is_weak_dtype(st_dtype):
5950
raise ValueError
60-
if isinstance(
61-
dtype1,
62-
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
63-
):
64-
if isinstance(
65-
dtype2,
66-
(
67-
WeakBooleanType,
68-
WeakIntegralType,
69-
WeakFloatingType,
70-
WeakComplexType,
71-
),
72-
):
51+
if _is_weak_dtype(dtype1):
52+
if _is_weak_dtype(dtype2):
7353
kind_num1 = _weak_type_num_kind(dtype1)
7454
kind_num2 = _weak_type_num_kind(dtype2)
7555
st_kind_num = _strong_dtype_num_kind(st_dtype)
@@ -120,10 +100,7 @@ def _resolve_one_strong_two_weak_types(st_dtype, dtype1, dtype2, dev):
120100
return _to_device_supported_dtype(dpt.float64, dev), dtype2
121101
else:
122102
return max_dtype, dtype2
123-
elif isinstance(
124-
dtype2,
125-
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
126-
):
103+
elif _is_weak_dtype(dtype2):
127104
max_dt_num_kind, max_dtype = max(
128105
[
129106
(_strong_dtype_num_kind(st_dtype), st_dtype),
@@ -152,15 +129,9 @@ def _resolve_one_strong_two_weak_types(st_dtype, dtype1, dtype2, dev):
152129

153130
def _resolve_one_strong_one_weak_types(st_dtype, dtype, dev):
154131
"Resolves one weak data type with one strong data type per NEP-0050"
155-
if isinstance(
156-
st_dtype,
157-
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
158-
):
132+
if _is_weak_dtype(st_dtype):
159133
raise ValueError
160-
if isinstance(
161-
dtype,
162-
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
163-
):
134+
if _is_weak_dtype(dtype):
164135
st_kind_num = _strong_dtype_num_kind(st_dtype)
165136
kind_num = _weak_type_num_kind(dtype)
166137
if kind_num > st_kind_num:

dpctl/tensor/_elementwise_common.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,7 @@ def __init__(
416416
docs,
417417
binary_inplace_fn=None,
418418
acceptance_fn=None,
419+
weak_type_resolver=None,
419420
):
420421
self.__name__ = "BinaryElementwiseFunc"
421422
self.name_ = name
@@ -428,6 +429,10 @@ def __init__(
428429
self.acceptance_fn_ = acceptance_fn
429430
else:
430431
self.acceptance_fn_ = _acceptance_fn_default_binary
432+
if callable(weak_type_resolver):
433+
self.weak_type_resolver_ = weak_type_resolver
434+
else:
435+
self.weak_type_resolver_ = _resolve_weak_types
431436

432437
def __str__(self):
433438
return f"<{self.__name__} '{self.name_}'>"
@@ -476,6 +481,26 @@ def get_type_promotion_path_acceptance_function(self):
476481
"""
477482
return self.acceptance_fn_
478483

484+
def get_array_dtype_scalar_type_resolver_function(self):
485+
"""Returns the function which determines how to treat
486+
Python scalar types for this elementwise binary function.
487+
488+
Resolver influences what type the scalar will be
489+
treated as prior to type promotion behavior.
490+
The function takes 3 arguments:
491+
492+
Args:
493+
o1_dtype (object, dtype):
494+
A class representing a Python scalar type or a ``dtype``
495+
o2_dtype (object, dtype):
496+
A class representing a Python scalar type or a ``dtype``
497+
sycl_dev (:class:`dpctl.SyclDevice`):
498+
Device on which function evaluation is carried out.
499+
500+
One of ``o1_dtype`` and ``o2_dtype`` must be a ``dtype`` instance.
501+
"""
502+
return self.weak_type_resolver_
503+
479504
@property
480505
def nin(self):
481506
"""
@@ -579,7 +604,9 @@ def __call__(self, o1, o2, /, *, out=None, order="K"):
579604
if not all(_validate_dtype(o) for o in (o1_dtype, o2_dtype)):
580605
raise ValueError("Operands have unsupported data types")
581606

582-
o1_dtype, o2_dtype = _resolve_weak_types(o1_dtype, o2_dtype, sycl_dev)
607+
o1_dtype, o2_dtype = self.weak_type_resolver_(
608+
o1_dtype, o2_dtype, sycl_dev
609+
)
583610

584611
buf1_dt, buf2_dt, res_dt = _find_buf_dtype2(
585612
o1_dtype,

dpctl/tensor/_elementwise_funcs.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
_acceptance_fn_negative,
2323
_acceptance_fn_reciprocal,
2424
_acceptance_fn_subtract,
25+
_resolve_weak_types_comparisons,
2526
)
2627

2728
# U01: ==== ABS (x)
@@ -690,7 +691,11 @@
690691
"""
691692

692693
equal = BinaryElementwiseFunc(
693-
"equal", ti._equal_result_type, ti._equal, _equal_docstring_
694+
"equal",
695+
ti._equal_result_type,
696+
ti._equal,
697+
_equal_docstring_,
698+
weak_type_resolver=_resolve_weak_types_comparisons,
694699
)
695700
del _equal_docstring_
696701

@@ -845,7 +850,11 @@
845850
"""
846851

847852
greater = BinaryElementwiseFunc(
848-
"greater", ti._greater_result_type, ti._greater, _greater_docstring_
853+
"greater",
854+
ti._greater_result_type,
855+
ti._greater,
856+
_greater_docstring_,
857+
weak_type_resolver=_resolve_weak_types_comparisons,
849858
)
850859
del _greater_docstring_
851860

@@ -881,6 +890,7 @@
881890
ti._greater_equal_result_type,
882891
ti._greater_equal,
883892
_greater_equal_docstring_,
893+
weak_type_resolver=_resolve_weak_types_comparisons,
884894
)
885895
del _greater_equal_docstring_
886896

@@ -1027,7 +1037,11 @@
10271037
"""
10281038

10291039
less = BinaryElementwiseFunc(
1030-
"less", ti._less_result_type, ti._less, _less_docstring_
1040+
"less",
1041+
ti._less_result_type,
1042+
ti._less,
1043+
_less_docstring_,
1044+
weak_type_resolver=_resolve_weak_types_comparisons,
10311045
)
10321046
del _less_docstring_
10331047

@@ -1063,6 +1077,7 @@
10631077
ti._less_equal_result_type,
10641078
ti._less_equal,
10651079
_less_equal_docstring_,
1080+
weak_type_resolver=_resolve_weak_types_comparisons,
10661081
)
10671082
del _less_equal_docstring_
10681083

@@ -1499,7 +1514,11 @@
14991514
"""
15001515

15011516
not_equal = BinaryElementwiseFunc(
1502-
"not_equal", ti._not_equal_result_type, ti._not_equal, _not_equal_docstring_
1517+
"not_equal",
1518+
ti._not_equal_result_type,
1519+
ti._not_equal,
1520+
_not_equal_docstring_,
1521+
weak_type_resolver=_resolve_weak_types_comparisons,
15031522
)
15041523
del _not_equal_docstring_
15051524

dpctl/tensor/_type_utils.py

Lines changed: 63 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -346,21 +346,17 @@ def _strong_dtype_num_kind(o):
346346
raise ValueError(f"Unrecognized kind {k} for dtype {o}")
347347

348348

349+
def _is_weak_dtype(dtype):
350+
return isinstance(
351+
dtype,
352+
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
353+
)
354+
355+
349356
def _resolve_weak_types(o1_dtype, o2_dtype, dev):
350357
"Resolves weak data type per NEP-0050"
351-
if isinstance(
352-
o1_dtype,
353-
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
354-
):
355-
if isinstance(
356-
o2_dtype,
357-
(
358-
WeakBooleanType,
359-
WeakIntegralType,
360-
WeakFloatingType,
361-
WeakComplexType,
362-
),
363-
):
358+
if _is_weak_dtype(o1_dtype):
359+
if _is_weak_dtype(o2_dtype):
364360
raise ValueError
365361
o1_kind_num = _weak_type_num_kind(o1_dtype)
366362
o2_kind_num = _strong_dtype_num_kind(o2_dtype)
@@ -377,10 +373,54 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
377373
return _to_device_supported_dtype(dpt.float64, dev), o2_dtype
378374
else:
379375
return o2_dtype, o2_dtype
380-
elif isinstance(
381-
o2_dtype,
382-
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
383-
):
376+
elif _is_weak_dtype(o2_dtype):
377+
o1_kind_num = _strong_dtype_num_kind(o1_dtype)
378+
o2_kind_num = _weak_type_num_kind(o2_dtype)
379+
if o2_kind_num > o1_kind_num:
380+
if isinstance(o2_dtype, WeakIntegralType):
381+
return o1_dtype, dpt.dtype(ti.default_device_int_type(dev))
382+
if isinstance(o2_dtype, WeakComplexType):
383+
if o1_dtype is dpt.float16 or o1_dtype is dpt.float32:
384+
return o1_dtype, dpt.complex64
385+
return o1_dtype, _to_device_supported_dtype(dpt.complex128, dev)
386+
return (
387+
o1_dtype,
388+
_to_device_supported_dtype(dpt.float64, dev),
389+
)
390+
else:
391+
return o1_dtype, o1_dtype
392+
else:
393+
return o1_dtype, o2_dtype
394+
395+
396+
def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev):
397+
"Resolves weak data type per NEP-0050 for comparisons,"
398+
"where result type is known to be `bool` and special behavior"
399+
"is needed to handle mixed integer kinds"
400+
if _is_weak_dtype(o1_dtype):
401+
if _is_weak_dtype(o2_dtype):
402+
raise ValueError
403+
o1_kind_num = _weak_type_num_kind(o1_dtype)
404+
o2_kind_num = _strong_dtype_num_kind(o2_dtype)
405+
if o1_kind_num > o2_kind_num:
406+
if isinstance(o1_dtype, WeakIntegralType):
407+
return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype
408+
if isinstance(o1_dtype, WeakComplexType):
409+
if o2_dtype is dpt.float16 or o2_dtype is dpt.float32:
410+
return dpt.complex64, o2_dtype
411+
return (
412+
_to_device_supported_dtype(dpt.complex128, dev),
413+
o2_dtype,
414+
)
415+
return _to_device_supported_dtype(dpt.float64, dev), o2_dtype
416+
else:
417+
if isinstance(o1_dtype, WeakIntegralType):
418+
if o2_dtype.kind == "u":
419+
# Python scalar may be negative, assumes mixed int loops
420+
# exist
421+
return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype
422+
return o2_dtype, o2_dtype
423+
elif _is_weak_dtype(o2_dtype):
384424
o1_kind_num = _strong_dtype_num_kind(o1_dtype)
385425
o2_kind_num = _weak_type_num_kind(o2_dtype)
386426
if o2_kind_num > o1_kind_num:
@@ -395,6 +435,11 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
395435
_to_device_supported_dtype(dpt.float64, dev),
396436
)
397437
else:
438+
if isinstance(o2_dtype, WeakIntegralType):
439+
if o1_dtype.kind == "u":
440+
# Python scalar may be negative, assumes mixed int loops
441+
# exist
442+
return o1_dtype, dpt.dtype(ti.default_device_int_type(dev))
398443
return o1_dtype, o1_dtype
399444
else:
400445
return o1_dtype, o2_dtype
@@ -789,6 +834,7 @@ def _default_accumulation_dtype_fp_types(inp_dt, q):
789834
"_acceptance_fn_negative",
790835
"_acceptance_fn_subtract",
791836
"_resolve_weak_types",
837+
"_resolve_weak_types_comparisons",
792838
"_weak_type_num_kind",
793839
"_strong_dtype_num_kind",
794840
"can_cast",

dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,25 @@ template <typename argT1, typename argT2, typename resT> struct EqualFunctor
7676
#endif
7777
}
7878
else {
79-
return (in1 == in2);
79+
if constexpr (std::is_integral_v<argT1> &&
80+
std::is_integral_v<argT2> &&
81+
std::is_signed_v<argT1> != std::is_signed_v<argT2>)
82+
{
83+
if constexpr (std::is_signed_v<argT1> &&
84+
!std::is_signed_v<argT2>) {
85+
return (in1 < 0) ? false : (static_cast<argT2>(in1) == in2);
86+
}
87+
else {
88+
if constexpr (!std::is_signed_v<argT1> &&
89+
std::is_signed_v<argT2>) {
90+
return (in2 < 0) ? false
91+
: (in1 == static_cast<argT1>(in2));
92+
}
93+
}
94+
}
95+
else {
96+
return (in1 == in2);
97+
}
8098
}
8199
}
82100

@@ -151,6 +169,10 @@ template <typename T1, typename T2> struct EqualOutputType
151169
bool>,
152170
td_ns::
153171
BinaryTypeMapResultEntry<T1, std::int64_t, T2, std::int64_t, bool>,
172+
td_ns::
173+
BinaryTypeMapResultEntry<T1, std::uint64_t, T2, std::int64_t, bool>,
174+
td_ns::
175+
BinaryTypeMapResultEntry<T1, std::int64_t, T2, std::uint64_t, bool>,
154176
td_ns::BinaryTypeMapResultEntry<T1, sycl::half, T2, sycl::half, bool>,
155177
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, bool>,
156178
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, bool>,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,25 @@ template <typename argT1, typename argT2, typename resT> struct GreaterFunctor
7171
return greater_complex<argT1>(in1, in2);
7272
}
7373
else {
74-
return (in1 > in2);
74+
if constexpr (std::is_integral_v<argT1> &&
75+
std::is_integral_v<argT2> &&
76+
std::is_signed_v<argT1> != std::is_signed_v<argT2>)
77+
{
78+
if constexpr (std::is_signed_v<argT1> &&
79+
!std::is_signed_v<argT2>) {
80+
return (in1 < 0) ? false : (static_cast<argT2>(in1) > in2);
81+
}
82+
else {
83+
if constexpr (!std::is_signed_v<argT1> &&
84+
std::is_signed_v<argT2>) {
85+
return (in2 < 0) ? true
86+
: (in1 > static_cast<argT1>(in2));
87+
}
88+
}
89+
}
90+
else {
91+
return (in1 > in2);
92+
}
7593
}
7694
}
7795

@@ -148,6 +166,10 @@ template <typename T1, typename T2> struct GreaterOutputType
148166
bool>,
149167
td_ns::
150168
BinaryTypeMapResultEntry<T1, std::int64_t, T2, std::int64_t, bool>,
169+
td_ns::
170+
BinaryTypeMapResultEntry<T1, std::uint64_t, T2, std::int64_t, bool>,
171+
td_ns::
172+
BinaryTypeMapResultEntry<T1, std::int64_t, T2, std::uint64_t, bool>,
151173
td_ns::BinaryTypeMapResultEntry<T1, sycl::half, T2, sycl::half, bool>,
152174
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, bool>,
153175
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, bool>,

0 commit comments

Comments
 (0)