@@ -346,21 +346,17 @@ def _strong_dtype_num_kind(o):
346
346
raise ValueError (f"Unrecognized kind { k } for dtype { o } " )
347
347
348
348
349
+ def _is_weak_dtype (dtype ):
350
+ return isinstance (
351
+ dtype ,
352
+ (WeakBooleanType , WeakIntegralType , WeakFloatingType , WeakComplexType ),
353
+ )
354
+
355
+
349
356
def _resolve_weak_types (o1_dtype , o2_dtype , dev ):
350
357
"Resolves weak data type per NEP-0050"
351
- if isinstance (
352
- o1_dtype ,
353
- (WeakBooleanType , WeakIntegralType , WeakFloatingType , WeakComplexType ),
354
- ):
355
- if isinstance (
356
- o2_dtype ,
357
- (
358
- WeakBooleanType ,
359
- WeakIntegralType ,
360
- WeakFloatingType ,
361
- WeakComplexType ,
362
- ),
363
- ):
358
+ if _is_weak_dtype (o1_dtype ):
359
+ if _is_weak_dtype (o2_dtype ):
364
360
raise ValueError
365
361
o1_kind_num = _weak_type_num_kind (o1_dtype )
366
362
o2_kind_num = _strong_dtype_num_kind (o2_dtype )
@@ -377,10 +373,7 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
377
373
return _to_device_supported_dtype (dpt .float64 , dev ), o2_dtype
378
374
else :
379
375
return o2_dtype , o2_dtype
380
- elif isinstance (
381
- o2_dtype ,
382
- (WeakBooleanType , WeakIntegralType , WeakFloatingType , WeakComplexType ),
383
- ):
376
+ elif _is_weak_dtype (o2_dtype ):
384
377
o1_kind_num = _strong_dtype_num_kind (o1_dtype )
385
378
o2_kind_num = _weak_type_num_kind (o2_dtype )
386
379
if o2_kind_num > o1_kind_num :
@@ -404,19 +397,8 @@ def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev):
404
397
"Resolves weak data type per NEP-0050 for comparisons,"
405
398
"where result type is known to be `bool` and special behavior"
406
399
"is needed to handle mixed integer kinds"
407
- if isinstance (
408
- o1_dtype ,
409
- (WeakBooleanType , WeakIntegralType , WeakFloatingType , WeakComplexType ),
410
- ):
411
- if isinstance (
412
- o2_dtype ,
413
- (
414
- WeakBooleanType ,
415
- WeakIntegralType ,
416
- WeakFloatingType ,
417
- WeakComplexType ,
418
- ),
419
- ):
400
+ if _is_weak_dtype (o1_dtype ):
401
+ if _is_weak_dtype (o2_dtype ):
420
402
raise ValueError
421
403
o1_kind_num = _weak_type_num_kind (o1_dtype )
422
404
o2_kind_num = _strong_dtype_num_kind (o2_dtype )
@@ -438,10 +420,7 @@ def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev):
438
420
# exist
439
421
return dpt .dtype (ti .default_device_int_type (dev )), o2_dtype
440
422
return o2_dtype , o2_dtype
441
- elif isinstance (
442
- o2_dtype ,
443
- (WeakBooleanType , WeakIntegralType , WeakFloatingType , WeakComplexType ),
444
- ):
423
+ elif _is_weak_dtype (o2_dtype ):
445
424
o1_kind_num = _strong_dtype_num_kind (o1_dtype )
446
425
o2_kind_num = _weak_type_num_kind (o2_dtype )
447
426
if o2_kind_num > o1_kind_num :
0 commit comments