Skip to content

Commit a5c4fce

Browse files
committed
Refactor isinstance checks for any weak types into _is_weak_dtype utility function
Per suggestion by @oleksandr-pavlyk
1 parent 2331c1c commit a5c4fce

File tree

2 files changed

+20
-70
lines changed

2 files changed

+20
-70
lines changed

dpctl/tensor/_clip.py

Lines changed: 7 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,9 @@
3434
from dpctl.utils import ExecutionPlacementError
3535

3636
from ._type_utils import (
37-
WeakBooleanType,
3837
WeakComplexType,
39-
WeakFloatingType,
4038
WeakIntegralType,
39+
_is_weak_dtype,
4140
_strong_dtype_num_kind,
4241
_weak_type_num_kind,
4342
)
@@ -47,29 +46,10 @@ def _resolve_one_strong_two_weak_types(st_dtype, dtype1, dtype2, dev):
4746
"Resolves weak data types per NEP-0050,"
4847
"where the second and third arguments are"
4948
"permitted to be weak types"
50-
if isinstance(
51-
st_dtype,
52-
(
53-
WeakBooleanType,
54-
WeakIntegralType,
55-
WeakFloatingType,
56-
WeakComplexType,
57-
),
58-
):
49+
if _is_weak_dtype(st_dtype):
5950
raise ValueError
60-
if isinstance(
61-
dtype1,
62-
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
63-
):
64-
if isinstance(
65-
dtype2,
66-
(
67-
WeakBooleanType,
68-
WeakIntegralType,
69-
WeakFloatingType,
70-
WeakComplexType,
71-
),
72-
):
51+
if _is_weak_dtype(dtype1):
52+
if _is_weak_dtype(dtype2):
7353
kind_num1 = _weak_type_num_kind(dtype1)
7454
kind_num2 = _weak_type_num_kind(dtype2)
7555
st_kind_num = _strong_dtype_num_kind(st_dtype)
@@ -120,10 +100,7 @@ def _resolve_one_strong_two_weak_types(st_dtype, dtype1, dtype2, dev):
120100
return _to_device_supported_dtype(dpt.float64, dev), dtype2
121101
else:
122102
return max_dtype, dtype2
123-
elif isinstance(
124-
dtype2,
125-
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
126-
):
103+
elif _is_weak_dtype(dtype2):
127104
max_dt_num_kind, max_dtype = max(
128105
[
129106
(_strong_dtype_num_kind(st_dtype), st_dtype),
@@ -152,15 +129,9 @@ def _resolve_one_strong_two_weak_types(st_dtype, dtype1, dtype2, dev):
152129

153130
def _resolve_one_strong_one_weak_types(st_dtype, dtype, dev):
154131
"Resolves one weak data type with one strong data type per NEP-0050"
155-
if isinstance(
156-
st_dtype,
157-
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
158-
):
132+
if _is_weak_dtype(st_dtype):
159133
raise ValueError
160-
if isinstance(
161-
dtype,
162-
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
163-
):
134+
if _is_weak_dtype(dtype):
164135
st_kind_num = _strong_dtype_num_kind(st_dtype)
165136
kind_num = _weak_type_num_kind(dtype)
166137
if kind_num > st_kind_num:

dpctl/tensor/_type_utils.py

Lines changed: 13 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -346,21 +346,17 @@ def _strong_dtype_num_kind(o):
346346
raise ValueError(f"Unrecognized kind {k} for dtype {o}")
347347

348348

349+
def _is_weak_dtype(dtype):
350+
return isinstance(
351+
dtype,
352+
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
353+
)
354+
355+
349356
def _resolve_weak_types(o1_dtype, o2_dtype, dev):
350357
"Resolves weak data type per NEP-0050"
351-
if isinstance(
352-
o1_dtype,
353-
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
354-
):
355-
if isinstance(
356-
o2_dtype,
357-
(
358-
WeakBooleanType,
359-
WeakIntegralType,
360-
WeakFloatingType,
361-
WeakComplexType,
362-
),
363-
):
358+
if _is_weak_dtype(o1_dtype):
359+
if _is_weak_dtype(o2_dtype):
364360
raise ValueError
365361
o1_kind_num = _weak_type_num_kind(o1_dtype)
366362
o2_kind_num = _strong_dtype_num_kind(o2_dtype)
@@ -377,10 +373,7 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
377373
return _to_device_supported_dtype(dpt.float64, dev), o2_dtype
378374
else:
379375
return o2_dtype, o2_dtype
380-
elif isinstance(
381-
o2_dtype,
382-
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
383-
):
376+
elif _is_weak_dtype(o2_dtype):
384377
o1_kind_num = _strong_dtype_num_kind(o1_dtype)
385378
o2_kind_num = _weak_type_num_kind(o2_dtype)
386379
if o2_kind_num > o1_kind_num:
@@ -404,19 +397,8 @@ def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev):
404397
"Resolves weak data type per NEP-0050 for comparisons,"
405398
"where result type is known to be `bool` and special behavior"
406399
"is needed to handle mixed integer kinds"
407-
if isinstance(
408-
o1_dtype,
409-
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
410-
):
411-
if isinstance(
412-
o2_dtype,
413-
(
414-
WeakBooleanType,
415-
WeakIntegralType,
416-
WeakFloatingType,
417-
WeakComplexType,
418-
),
419-
):
400+
if _is_weak_dtype(o1_dtype):
401+
if _is_weak_dtype(o2_dtype):
420402
raise ValueError
421403
o1_kind_num = _weak_type_num_kind(o1_dtype)
422404
o2_kind_num = _strong_dtype_num_kind(o2_dtype)
@@ -438,10 +420,7 @@ def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev):
438420
# exist
439421
return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype
440422
return o2_dtype, o2_dtype
441-
elif isinstance(
442-
o2_dtype,
443-
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
444-
):
423+
elif _is_weak_dtype(o2_dtype):
445424
o1_kind_num = _strong_dtype_num_kind(o1_dtype)
446425
o2_kind_num = _weak_type_num_kind(o2_dtype)
447426
if o2_kind_num > o1_kind_num:

0 commit comments

Comments
 (0)