Skip to content

Commit 0bc9bac

Browse files
Changes to support kDLBool type added in DLPack v0.8
1 parent 7658010 commit 0bc9bac

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

dpctl/tensor/_dlpack.pyx

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ cdef extern from 'dlpack/dlpack.h' nogil:
7171
kDLFloat
7272
kDLBfloat
7373
kDLComplex
74+
kDLBool
7475

7576
ctypedef struct DLDataType:
7677
uint8_t code
@@ -244,7 +245,7 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
244245
dl_tensor.dtype.lanes = <uint16_t>1
245246
dl_tensor.dtype.bits = <uint8_t>(ary_dt.itemsize * 8)
246247
if (ary_dtk == "b"):
247-
dl_tensor.dtype.code = <uint8_t>kDLUInt
248+
dl_tensor.dtype.code = <uint8_t>kDLBool
248249
elif (ary_dtk == "u"):
249250
dl_tensor.dtype.code = <uint8_t>kDLUInt
250251
elif (ary_dtk == "i"):
@@ -444,6 +445,8 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
444445
ary_dt = np.dtype("f" + str(element_bytesize))
445446
elif (dlm_tensor.dl_tensor.dtype.code == kDLComplex):
446447
ary_dt = np.dtype("c" + str(element_bytesize))
448+
elif (dlm_tensor.dl_tensor.dtype.code == kDLBool):
449+
ary_dt = np.dtype("?")
447450
else:
448451
raise BufferError(
449452
"Can not import DLPack tensor with type code {}.".format(

0 commit comments

Comments
 (0)