@@ -393,10 +393,11 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
393
393
return o1_dtype , o2_dtype
394
394
395
395
396
- def _resolve_weak_types_comparisons (o1_dtype , o2_dtype , dev ):
397
- "Resolves weak data type per NEP-0050 for comparisons,"
398
- "where result type is known to be `bool` and special behavior"
399
- "is needed to handle mixed integer kinds"
396
+ def _resolve_weak_types_all_py_ints (o1_dtype , o2_dtype , dev ):
397
+ "Resolves weak data type per NEP-0050 for comparisons and"
398
+ " divide, where result type is known and special behavior"
399
+ "is needed to handle mixed integer kinds and Python integers"
400
+ "without overflow"
400
401
if _is_weak_dtype (o1_dtype ):
401
402
if _is_weak_dtype (o2_dtype ):
402
403
raise ValueError
@@ -415,10 +416,10 @@ def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev):
415
416
return _to_device_supported_dtype (dpt .float64 , dev ), o2_dtype
416
417
else :
417
418
if isinstance (o1_dtype , WeakIntegralType ):
418
- if o2_dtype . kind == "u" :
419
- # Python scalar may be negative, assumes mixed int loops
420
- # exist
421
- return dpt .dtype (ti . default_device_int_type ( dev )), o2_dtype
419
+ o1_val = o1_dtype . get ()
420
+ o2_iinfo = dpt . iinfo ( o2_dtype )
421
+ if ( o1_val < o2_iinfo . min ) or ( o1_val > o2_iinfo . max ):
422
+ return dpt .dtype (np . min_scalar_type ( o1_val )), o2_dtype
422
423
return o2_dtype , o2_dtype
423
424
elif _is_weak_dtype (o2_dtype ):
424
425
o1_kind_num = _strong_dtype_num_kind (o1_dtype )
@@ -436,10 +437,10 @@ def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev):
436
437
)
437
438
else :
438
439
if isinstance (o2_dtype , WeakIntegralType ):
439
- if o1_dtype . kind == "u" :
440
- # Python scalar may be negative, assumes mixed int loops
441
- # exist
442
- return o1_dtype , dpt .dtype (ti . default_device_int_type ( dev ))
440
+ o2_val = o2_dtype . get ()
441
+ o1_iinfo = dpt . iinfo ( o1_dtype )
442
+ if ( o2_val < o1_iinfo . min ) or ( o2_val > o1_iinfo . max ):
443
+ return o1_dtype , dpt .dtype (np . min_scalar_type ( o2_val ))
443
444
return o1_dtype , o1_dtype
444
445
else :
445
446
return o1_dtype , o2_dtype
@@ -834,7 +835,7 @@ def _default_accumulation_dtype_fp_types(inp_dt, q):
834
835
"_acceptance_fn_negative" ,
835
836
"_acceptance_fn_subtract" ,
836
837
"_resolve_weak_types" ,
837
- "_resolve_weak_types_comparisons " ,
838
+ "_resolve_weak_types_all_py_ints " ,
838
839
"_weak_type_num_kind" ,
839
840
"_strong_dtype_num_kind" ,
840
841
"can_cast" ,
0 commit comments