Skip to content

Commit e8c3c42

Browse files
committed
add missing line
1 parent 7b65822 commit e8c3c42

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

onnx_array_api/npx/npx_numpy_tensors.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,13 +265,35 @@ def __float__(self):
265265
DType(TensorProto.DOUBLE),
266266
DType(TensorProto.FLOAT16),
267267
DType(TensorProto.BFLOAT16),
268+
DType(TensorProto.COMPLEX64),
269+
DType(TensorProto.COMPLEX128),
268270
}:
269271
raise TypeError(
270272
f"Conversion to float only works for float scalar, "
271273
f"not for dtype={self.dtype}."
272274
)
273275
return float(self._tensor)
274276

277+
def __complex__(self):
278+
"Implicit conversion to complex."
279+
if self.shape:
280+
raise ValueError(
281+
f"Conversion to bool only works for scalar, not for {self!r}."
282+
)
283+
if self.dtype not in {
284+
DType(TensorProto.FLOAT),
285+
DType(TensorProto.DOUBLE),
286+
DType(TensorProto.FLOAT16),
287+
DType(TensorProto.BFLOAT16),
288+
DType(TensorProto.COMPLEX64),
289+
DType(TensorProto.COMPLEX128),
290+
}:
291+
raise TypeError(
292+
f"Conversion to float only works for float scalar, "
293+
f"not for dtype={self.dtype}."
294+
)
295+
return complex(self._tensor)
296+
275297
def __iter__(self):
276298
"""
277299
The :epkg:`Array API` does not define this function (2022/12).

0 commit comments

Comments
 (0)