Skip to content

Commit 81b55bf

Browse files
committed
Comparisons with signed and unsigned integers behave correctly
Comparions between signed and unsigned integer data previously did not work correctly in some cases, as signed integers could be promoted to uint64 if one input was uint64 Additionally, `-1 < x` for some `x` with unsigned integer data type would always fail, as the -1 would initialize an array of `x.dtype` which would always underflow, leading to undefined behavior These problems were addressed by adding signed and unsigned 64-bit integer combinations to the type maps for the comparisons, and adding constexpr branches to the comparisons between mixed-kind integers
1 parent fff92a6 commit 81b55bf

File tree

9 files changed

+249
-12
lines changed

9 files changed

+249
-12
lines changed

dpctl/tensor/_elementwise_common.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,7 @@ def __init__(
416416
docs,
417417
binary_inplace_fn=None,
418418
acceptance_fn=None,
419+
weak_type_resolver=None,
419420
):
420421
self.__name__ = "BinaryElementwiseFunc"
421422
self.name_ = name
@@ -428,6 +429,10 @@ def __init__(
428429
self.acceptance_fn_ = acceptance_fn
429430
else:
430431
self.acceptance_fn_ = _acceptance_fn_default_binary
432+
if callable(weak_type_resolver):
433+
self.weak_type_resolver_ = weak_type_resolver
434+
else:
435+
self.weak_type_resolver_ = _resolve_weak_types
431436

432437
def __str__(self):
433438
return f"<{self.__name__} '{self.name_}'>"
@@ -476,6 +481,22 @@ def get_type_promotion_path_acceptance_function(self):
476481
"""
477482
return self.acceptance_fn_
478483

484+
def get_array_dtype_scalar_type_resolver_function(self):
485+
"""Returns the function which determines how to treat
486+
Python scalar types for this elementwise binary function.
487+
488+
Resolver influences what type the scalar will be
489+
treated as prior to type promotion behavior.
490+
The function takes 3 arguments:
491+
o1_dtype - A class representing a Python scalar type or a dtype
492+
o2_dtype - A class representing a Python scalar type or a dtype
493+
sycl_dev - :class:`dpctl.SyclDevice` on which function evaluation
494+
is carried out.
495+
496+
One of o1_dtype and o2_dtype must be a dtype
497+
"""
498+
return self.weak_type_resolver_
499+
479500
@property
480501
def nin(self):
481502
"""
@@ -579,7 +600,9 @@ def __call__(self, o1, o2, /, *, out=None, order="K"):
579600
if not all(_validate_dtype(o) for o in (o1_dtype, o2_dtype)):
580601
raise ValueError("Operands have unsupported data types")
581602

582-
o1_dtype, o2_dtype = _resolve_weak_types(o1_dtype, o2_dtype, sycl_dev)
603+
o1_dtype, o2_dtype = self.weak_type_resolver_(
604+
o1_dtype, o2_dtype, sycl_dev
605+
)
583606

584607
buf1_dt, buf2_dt, res_dt = _find_buf_dtype2(
585608
o1_dtype,

dpctl/tensor/_elementwise_funcs.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
_acceptance_fn_negative,
2323
_acceptance_fn_reciprocal,
2424
_acceptance_fn_subtract,
25+
_resolve_weak_types_comparisons,
2526
)
2627

2728
# U01: ==== ABS (x)
@@ -690,7 +691,11 @@
690691
"""
691692

692693
equal = BinaryElementwiseFunc(
693-
"equal", ti._equal_result_type, ti._equal, _equal_docstring_
694+
"equal",
695+
ti._equal_result_type,
696+
ti._equal,
697+
_equal_docstring_,
698+
weak_type_resolver=_resolve_weak_types_comparisons,
694699
)
695700
del _equal_docstring_
696701

@@ -845,7 +850,11 @@
845850
"""
846851

847852
greater = BinaryElementwiseFunc(
848-
"greater", ti._greater_result_type, ti._greater, _greater_docstring_
853+
"greater",
854+
ti._greater_result_type,
855+
ti._greater,
856+
_greater_docstring_,
857+
weak_type_resolver=_resolve_weak_types_comparisons,
849858
)
850859
del _greater_docstring_
851860

@@ -881,6 +890,7 @@
881890
ti._greater_equal_result_type,
882891
ti._greater_equal,
883892
_greater_equal_docstring_,
893+
weak_type_resolver=_resolve_weak_types_comparisons,
884894
)
885895
del _greater_equal_docstring_
886896

@@ -1027,7 +1037,11 @@
10271037
"""
10281038

10291039
less = BinaryElementwiseFunc(
1030-
"less", ti._less_result_type, ti._less, _less_docstring_
1040+
"less",
1041+
ti._less_result_type,
1042+
ti._less,
1043+
_less_docstring_,
1044+
weak_type_resolver=_resolve_weak_types_comparisons,
10311045
)
10321046
del _less_docstring_
10331047

@@ -1063,6 +1077,7 @@
10631077
ti._less_equal_result_type,
10641078
ti._less_equal,
10651079
_less_equal_docstring_,
1080+
weak_type_resolver=_resolve_weak_types_comparisons,
10661081
)
10671082
del _less_equal_docstring_
10681083

@@ -1499,7 +1514,11 @@
14991514
"""
15001515

15011516
not_equal = BinaryElementwiseFunc(
1502-
"not_equal", ti._not_equal_result_type, ti._not_equal, _not_equal_docstring_
1517+
"not_equal",
1518+
ti._not_equal_result_type,
1519+
ti._not_equal,
1520+
_not_equal_docstring_,
1521+
weak_type_resolver=_resolve_weak_types_comparisons,
15031522
)
15041523
del _not_equal_docstring_
15051524

dpctl/tensor/_type_utils.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,72 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
400400
return o1_dtype, o2_dtype
401401

402402

403+
def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev):
404+
"Resolves weak data type per NEP-0050 for comparisons,"
405+
"where result type is known to be `bool` and special behavior"
406+
"is needed to handle mixed integer kinds"
407+
if isinstance(
408+
o1_dtype,
409+
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
410+
):
411+
if isinstance(
412+
o2_dtype,
413+
(
414+
WeakBooleanType,
415+
WeakIntegralType,
416+
WeakFloatingType,
417+
WeakComplexType,
418+
),
419+
):
420+
raise ValueError
421+
o1_kind_num = _weak_type_num_kind(o1_dtype)
422+
o2_kind_num = _strong_dtype_num_kind(o2_dtype)
423+
if o1_kind_num > o2_kind_num:
424+
if isinstance(o1_dtype, WeakIntegralType):
425+
return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype
426+
if isinstance(o1_dtype, WeakComplexType):
427+
if o2_dtype is dpt.float16 or o2_dtype is dpt.float32:
428+
return dpt.complex64, o2_dtype
429+
return (
430+
_to_device_supported_dtype(dpt.complex128, dev),
431+
o2_dtype,
432+
)
433+
return _to_device_supported_dtype(dpt.float64, dev), o2_dtype
434+
else:
435+
if isinstance(o1_dtype, WeakIntegralType):
436+
if o2_dtype.kind == "u":
437+
# Python scalar may be negative, assumes mixed int loops
438+
# exist
439+
return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype
440+
return o2_dtype, o2_dtype
441+
elif isinstance(
442+
o2_dtype,
443+
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
444+
):
445+
o1_kind_num = _strong_dtype_num_kind(o1_dtype)
446+
o2_kind_num = _weak_type_num_kind(o2_dtype)
447+
if o2_kind_num > o1_kind_num:
448+
if isinstance(o2_dtype, WeakIntegralType):
449+
return o1_dtype, dpt.dtype(ti.default_device_int_type(dev))
450+
if isinstance(o2_dtype, WeakComplexType):
451+
if o1_dtype is dpt.float16 or o1_dtype is dpt.float32:
452+
return o1_dtype, dpt.complex64
453+
return o1_dtype, _to_device_supported_dtype(dpt.complex128, dev)
454+
return (
455+
o1_dtype,
456+
_to_device_supported_dtype(dpt.float64, dev),
457+
)
458+
else:
459+
if isinstance(o2_dtype, WeakIntegralType):
460+
if o1_dtype.kind == "u":
461+
# Python scalar may be negative, assumes mixed int loops
462+
# exist
463+
return o1_dtype, dpt.dtype(ti.default_device_int_type(dev))
464+
return o1_dtype, o1_dtype
465+
else:
466+
return o1_dtype, o2_dtype
467+
468+
403469
class finfo_object:
404470
"""
405471
`numpy.finfo` subclass which returns Python floating-point scalars for
@@ -789,6 +855,7 @@ def _default_accumulation_dtype_fp_types(inp_dt, q):
789855
"_acceptance_fn_negative",
790856
"_acceptance_fn_subtract",
791857
"_resolve_weak_types",
858+
"_resolve_weak_types_comparisons",
792859
"_weak_type_num_kind",
793860
"_strong_dtype_num_kind",
794861
"can_cast",

dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,25 @@ template <typename argT1, typename argT2, typename resT> struct EqualFunctor
7676
#endif
7777
}
7878
else {
79-
return (in1 == in2);
79+
if constexpr (std::is_integral_v<argT1> &&
80+
std::is_integral_v<argT2> &&
81+
std::is_signed_v<argT1> != std::is_signed_v<argT2>)
82+
{
83+
if constexpr (std::is_signed_v<argT1> &&
84+
!std::is_signed_v<argT2>) {
85+
return (in1 < 0) ? false : (static_cast<argT2>(in1) == in2);
86+
}
87+
else {
88+
if constexpr (!std::is_signed_v<argT1> &&
89+
std::is_signed_v<argT2>) {
90+
return (in2 < 0) ? false
91+
: (in1 == static_cast<argT1>(in2));
92+
}
93+
}
94+
}
95+
else {
96+
return (in1 == in2);
97+
}
8098
}
8199
}
82100

@@ -151,6 +169,10 @@ template <typename T1, typename T2> struct EqualOutputType
151169
bool>,
152170
td_ns::
153171
BinaryTypeMapResultEntry<T1, std::int64_t, T2, std::int64_t, bool>,
172+
td_ns::
173+
BinaryTypeMapResultEntry<T1, std::uint64_t, T2, std::int64_t, bool>,
174+
td_ns::
175+
BinaryTypeMapResultEntry<T1, std::int64_t, T2, std::uint64_t, bool>,
154176
td_ns::BinaryTypeMapResultEntry<T1, sycl::half, T2, sycl::half, bool>,
155177
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, bool>,
156178
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, bool>,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater.hpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,25 @@ template <typename argT1, typename argT2, typename resT> struct GreaterFunctor
7171
return greater_complex<argT1>(in1, in2);
7272
}
7373
else {
74-
return (in1 > in2);
74+
if constexpr (std::is_integral_v<argT1> &&
75+
std::is_integral_v<argT2> &&
76+
std::is_signed_v<argT1> != std::is_signed_v<argT2>)
77+
{
78+
if constexpr (std::is_signed_v<argT1> &&
79+
!std::is_signed_v<argT2>) {
80+
return (in1 < 0) ? false : (static_cast<argT2>(in1) > in2);
81+
}
82+
else {
83+
if constexpr (!std::is_signed_v<argT1> &&
84+
std::is_signed_v<argT2>) {
85+
return (in2 < 0) ? true
86+
: (in1 > static_cast<argT1>(in2));
87+
}
88+
}
89+
}
90+
else {
91+
return (in1 > in2);
92+
}
7593
}
7694
}
7795

@@ -148,6 +166,10 @@ template <typename T1, typename T2> struct GreaterOutputType
148166
bool>,
149167
td_ns::
150168
BinaryTypeMapResultEntry<T1, std::int64_t, T2, std::int64_t, bool>,
169+
td_ns::
170+
BinaryTypeMapResultEntry<T1, std::uint64_t, T2, std::int64_t, bool>,
171+
td_ns::
172+
BinaryTypeMapResultEntry<T1, std::int64_t, T2, std::uint64_t, bool>,
151173
td_ns::BinaryTypeMapResultEntry<T1, sycl::half, T2, sycl::half, bool>,
152174
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, bool>,
153175
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, bool>,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/greater_equal.hpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,25 @@ struct GreaterEqualFunctor
7272
return greater_equal_complex<argT1>(in1, in2);
7373
}
7474
else {
75-
return (in1 >= in2);
75+
if constexpr (std::is_integral_v<argT1> &&
76+
std::is_integral_v<argT2> &&
77+
std::is_signed_v<argT1> != std::is_signed_v<argT2>)
78+
{
79+
if constexpr (std::is_signed_v<argT1> &&
80+
!std::is_signed_v<argT2>) {
81+
return (in1 < 0) ? false : (static_cast<argT2>(in1) >= in2);
82+
}
83+
else {
84+
if constexpr (!std::is_signed_v<argT1> &&
85+
std::is_signed_v<argT2>) {
86+
return (in2 < 0) ? true
87+
: (in1 >= static_cast<argT1>(in2));
88+
}
89+
}
90+
}
91+
else {
92+
return (in1 >= in2);
93+
}
7694
}
7795
}
7896

@@ -149,6 +167,10 @@ template <typename T1, typename T2> struct GreaterEqualOutputType
149167
bool>,
150168
td_ns::
151169
BinaryTypeMapResultEntry<T1, std::int64_t, T2, std::int64_t, bool>,
170+
td_ns::
171+
BinaryTypeMapResultEntry<T1, std::uint64_t, T2, std::int64_t, bool>,
172+
td_ns::
173+
BinaryTypeMapResultEntry<T1, std::int64_t, T2, std::uint64_t, bool>,
152174
td_ns::BinaryTypeMapResultEntry<T1, sycl::half, T2, sycl::half, bool>,
153175
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, bool>,
154176
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, bool>,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/less.hpp

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,25 @@ template <typename argT1, typename argT2, typename resT> struct LessFunctor
7070
return less_complex<argT1>(in1, in2);
7171
}
7272
else {
73-
return (in1 < in2);
73+
if constexpr (std::is_integral_v<argT1> &&
74+
std::is_integral_v<argT2> &&
75+
std::is_signed_v<argT1> != std::is_signed_v<argT2>)
76+
{
77+
if constexpr (std::is_signed_v<argT1> &&
78+
!std::is_signed_v<argT2>) {
79+
return (in1 < 0) ? true : (static_cast<argT2>(in1) < in2);
80+
}
81+
else {
82+
if constexpr (!std::is_signed_v<argT1> &&
83+
std::is_signed_v<argT2>) {
84+
return (in2 < 0) ? false
85+
: (in1 < static_cast<argT1>(in2));
86+
}
87+
}
88+
}
89+
else {
90+
return (in1 < in2);
91+
}
7492
}
7593
}
7694

@@ -79,7 +97,6 @@ template <typename argT1, typename argT2, typename resT> struct LessFunctor
7997
operator()(const sycl::vec<argT1, vec_sz> &in1,
8098
const sycl::vec<argT2, vec_sz> &in2) const
8199
{
82-
83100
auto tmp = (in1 < in2);
84101

85102
if constexpr (std::is_same_v<resT,
@@ -147,6 +164,10 @@ template <typename T1, typename T2> struct LessOutputType
147164
bool>,
148165
td_ns::
149166
BinaryTypeMapResultEntry<T1, std::int64_t, T2, std::int64_t, bool>,
167+
td_ns::
168+
BinaryTypeMapResultEntry<T1, std::uint64_t, T2, std::int64_t, bool>,
169+
td_ns::
170+
BinaryTypeMapResultEntry<T1, std::int64_t, T2, std::uint64_t, bool>,
150171
td_ns::BinaryTypeMapResultEntry<T1, sycl::half, T2, sycl::half, bool>,
151172
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, bool>,
152173
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, bool>,

0 commit comments

Comments
 (0)