@@ -265,13 +265,35 @@ def __float__(self):
265
265
DType (TensorProto .DOUBLE ),
266
266
DType (TensorProto .FLOAT16 ),
267
267
DType (TensorProto .BFLOAT16 ),
268
+ DType (TensorProto .COMPLEX64 ),
269
+ DType (TensorProto .COMPLEX128 ),
268
270
}:
269
271
raise TypeError (
270
272
f"Conversion to float only works for float scalar, "
271
273
f"not for dtype={ self .dtype } ."
272
274
)
273
275
return float (self ._tensor )
274
276
277
+ def __complex__ (self ):
278
+ "Implicit conversion to complex."
279
+ if self .shape :
280
+ raise ValueError (
281
+ f"Conversion to bool only works for scalar, not for { self !r} ."
282
+ )
283
+ if self .dtype not in {
284
+ DType (TensorProto .FLOAT ),
285
+ DType (TensorProto .DOUBLE ),
286
+ DType (TensorProto .FLOAT16 ),
287
+ DType (TensorProto .BFLOAT16 ),
288
+ DType (TensorProto .COMPLEX64 ),
289
+ DType (TensorProto .COMPLEX128 ),
290
+ }:
291
+ raise TypeError (
292
+ f"Conversion to float only works for float scalar, "
293
+ f"not for dtype={ self .dtype } ."
294
+ )
295
+ return complex (self ._tensor )
296
+
275
297
def __iter__ (self ):
276
298
"""
277
299
The :epkg:`Array API` does not define this function (2022/12).
0 commit comments