File tree Expand file tree Collapse file tree 1 file changed +4
-1
lines changed Expand file tree Collapse file tree 1 file changed +4
-1
lines changed Original file line number Diff line number Diff line change @@ -71,6 +71,7 @@ cdef extern from 'dlpack/dlpack.h' nogil:
71
71
kDLFloat
72
72
kDLBfloat
73
73
kDLComplex
74
+ kDLBool
74
75
75
76
ctypedef struct DLDataType:
76
77
uint8_t code
@@ -244,7 +245,7 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary) except+:
244
245
dl_tensor.dtype.lanes = < uint16_t> 1
245
246
dl_tensor.dtype.bits = < uint8_t> (ary_dt.itemsize * 8 )
246
247
if (ary_dtk == " b" ):
247
- dl_tensor.dtype.code = < uint8_t> kDLUInt
248
+ dl_tensor.dtype.code = < uint8_t> kDLBool
248
249
elif (ary_dtk == " u" ):
249
250
dl_tensor.dtype.code = < uint8_t> kDLUInt
250
251
elif (ary_dtk == " i" ):
@@ -444,6 +445,8 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
444
445
ary_dt = np.dtype(" f" + str (element_bytesize))
445
446
elif (dlm_tensor.dl_tensor.dtype.code == kDLComplex):
446
447
ary_dt = np.dtype(" c" + str (element_bytesize))
448
+ elif (dlm_tensor.dl_tensor.dtype.code == kDLBool):
449
+ ary_dt = np.dtype(" ?" )
447
450
else :
448
451
raise BufferError(
449
452
" Can not import DLPack tensor with type code {}." .format(
You can’t perform that action at this time.
0 commit comments