Skip to content

Commit 58c8b25

Browse files
author
Prashant Kumar
committed
Replacing is with == for the dtype check.
>>> a = np.ndarray([1,1]).astype(np.half) >>> a array([[0.007812]], dtype=float16) >>> a.dtype dtype('float16') >>> a.dtype == np.half True >>> a.dtype == np.float16 True >>> a.dtype is np.float16 False Checking with `is` leads to inconsistency in checking. Reviewed By: silvas Differential Revision: https://reviews.llvm.org/D139121
1 parent 0fb74d0 commit 58c8b25

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

mlir/python/mlir/runtime/np_to_memref.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,14 @@ class F16(ctypes.Structure):
2323
_fields_ = [("f16", ctypes.c_int16)]
2424

2525

26+
# https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
2627
def as_ctype(dtp):
2728
"""Converts dtype to ctype."""
28-
if dtp is np.dtype(np.complex128):
29+
if dtp == np.dtype(np.complex128):
2930
return C128
30-
if dtp is np.dtype(np.complex64):
31+
if dtp == np.dtype(np.complex64):
3132
return C64
32-
if dtp is np.dtype(np.float16):
33+
if dtp == np.dtype(np.float16):
3334
return F16
3435
return np.ctypeslib.as_ctypes_type(dtp)
3536

0 commit comments

Comments
 (0)