Skip to content

Commit 48cc745

Browse files
committed
Fix numpy/cupy sum(), prod(), and trace() for complex upcasting with dtype=None
1 parent 398279b commit 48cc745

File tree

2 files changed

+18
-7
lines changed

2 files changed

+18
-7
lines changed

array_api_compat/common/_aliases.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -396,9 +396,12 @@ def sum(
396396
keepdims: bool = False,
397397
**kwargs,
398398
) -> ndarray:
399-
# `xp.sum` already upcasts integers, but not floats
400-
if dtype is None and x.dtype == xp.float32:
401-
dtype = xp.float64
399+
# `xp.sum` already upcasts integers, but not floats or complexes
400+
if dtype is None:
401+
if x.dtype == xp.float32:
402+
dtype = xp.float64
403+
elif x.dtype == xp.complex64:
404+
dtype = xp.complex128
402405
return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs)
403406

404407
def prod(
@@ -411,8 +414,11 @@ def prod(
411414
keepdims: bool = False,
412415
**kwargs,
413416
) -> ndarray:
414-
if dtype is None and x.dtype == xp.float32:
415-
dtype = xp.float64
417+
if dtype is None:
418+
if x.dtype == xp.float32:
419+
dtype = xp.float64
420+
elif x.dtype == xp.complex64:
421+
dtype = xp.complex128
416422
return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims, **kwargs)
417423

418424
# ceil, floor, and trunc return integers for integer inputs

array_api_compat/common/_linalg.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,13 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]
136136
def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray:
137137
return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)
138138

139-
def trace(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray:
140-
return xp.asarray(xp.trace(x, offset=offset, axis1=-2, axis2=-1, **kwargs))
139+
def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray:
140+
if dtype is None:
141+
if x.dtype == xp.float32:
142+
dtype = xp.float64
143+
elif x.dtype == xp.complex64:
144+
dtype = xp.complex128
145+
return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))
141146

142147
__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
143148
'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet',

0 commit comments

Comments
 (0)