Skip to content

Commit 7b65822

Browse files
committed
add complex
1 parent df4f45e commit 7b65822

File tree

6 files changed

+24
-0
lines changed

6 files changed

+24
-0
lines changed

onnx_array_api/_helpers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ def np_dtype_to_tensor_dtype(dtype: Any):
4040
dt = TensorProto.INT64
4141
elif dtype is float:
4242
dt = TensorProto.DOUBLE
43+
elif dtype == np.complex64:
44+
dt = TensorProto.COMPLEX64
45+
elif dtype == np.complex128:
46+
dt = TensorProto.COMPLEX128
4347
else:
4448
raise KeyError(f"Unable to guess type for dtype={dtype}.") # noqa: B904
4549
return dt

onnx_array_api/annotations.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ def wrapper(self, *args: List[Any], **kwargs: Dict[str, Any]) -> Any:
6464
np.uint64: TensorProto.UINT64,
6565
np.bool_: TensorProto.BOOL,
6666
np.str_: TensorProto.STRING,
67+
np.complex64: TensorProto.COMPLEX64,
68+
np.complex128: TensorProto.COMPLEX128,
6769
}
6870

6971

onnx_array_api/array_api/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ def _finfo(dtype):
4747
continue
4848
if isinstance(v, (np.float32, np.float64, np.float16)):
4949
d[k] = float(v)
50+
elif isinstance(v, (np.complex128, np.complex64)):
51+
d[k] = complex(v)
5052
else:
5153
d[k] = v
5254
d["dtype"] = DType(np_dtype_to_tensor_dtype(dt))
@@ -124,6 +126,8 @@ def _finalize_array_api(module, function_names, TEagerTensor):
124126
module.float16 = DType(TensorProto.FLOAT16)
125127
module.float32 = DType(TensorProto.FLOAT)
126128
module.float64 = DType(TensorProto.DOUBLE)
129+
module.complex64 = DType(TensorProto.COMPLEX64)
130+
module.complex128 = DType(TensorProto.COMPLEX128)
127131
module.int8 = DType(TensorProto.INT8)
128132
module.int16 = DType(TensorProto.INT16)
129133
module.int32 = DType(TensorProto.INT32)

onnx_array_api/npx/npx_var.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,13 +1171,17 @@ def __init__(self, cst: Any):
11711171
Var.__init__(self, np.array(cst, dtype=np.int64), op="Identity")
11721172
elif isinstance(cst, float):
11731173
Var.__init__(self, np.array(cst, dtype=np.float64), op="Identity")
1174+
elif isinstance(cst, complex):
1175+
Var.__init__(self, np.array(cst, dtype=np.complex128), op="Identity")
11741176
elif isinstance(cst, list):
11751177
if all(isinstance(t, bool) for t in cst):
11761178
Var.__init__(self, np.array(cst, dtype=np.bool_), op="Identity")
11771179
elif all(isinstance(t, (int, bool)) for t in cst):
11781180
Var.__init__(self, np.array(cst, dtype=np.int64), op="Identity")
11791181
elif all(isinstance(t, (float, int, bool)) for t in cst):
11801182
Var.__init__(self, np.array(cst, dtype=np.float64), op="Identity")
1183+
elif all(isinstance(t, (float, int, bool, complex)) for t in cst):
1184+
Var.__init__(self, np.array(cst, dtype=np.complex128), op="Identity")
11811185
else:
11821186
raise ValueError(
11831187
f"Unable to convert cst (type={type(cst)}), value={cst}."

onnx_array_api/reference/evaluator_yield.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,12 @@ def generate_input(info: ValueInfoProto) -> np.ndarray:
485485
return (value.astype(np.float16) / p).astype(np.float16).reshape(new_shape)
486486
if elem_type == TensorProto.DOUBLE:
487487
return (value.astype(np.float64) / p).astype(np.float64).reshape(new_shape)
488+
if elem_type == TensorProto.COMPLEX64:
489+
return (value.astype(np.complex64) / p).astype(np.complex64).reshape(new_shape)
490+
if elem_type == TensorProto.COMPLEX128:
491+
return (
492+
(value.astype(np.complex128) / p).astype(np.complex128).reshape(new_shape)
493+
)
488494
raise RuntimeError(f"Unexpected element_type {elem_type} for info={info}")
489495

490496

onnx_array_api/reference/ops/op_constant_of_shape.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ def _process(value):
1919
cst = np.int64(cst)
2020
elif isinstance(cst, float):
2121
cst = np.float64(cst)
22+
elif isinstance(cst, complex):
23+
cst = np.complex128(cst)
2224
elif cst is None:
2325
cst = np.float32(0)
2426
if not isinstance(
@@ -27,6 +29,8 @@ def _process(value):
2729
np.float16,
2830
np.float32,
2931
np.float64,
32+
np.complex64,
33+
np.complex128,
3034
np.int64,
3135
np.int32,
3236
np.int16,

0 commit comments

Comments
 (0)