Skip to content

Commit 310929d

Browse files
committed
Fix casting for the array API concat() and stack()
1 parent 6789a74 commit 310929d

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

numpy/array_api/_manipulation_functions.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@ def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[i
1414
1515
See its docstring for more information.
1616
"""
17+
# Note: Casting rules here are different from the np.concatenate default
18+
# (no for scalars with axis=None, no cross-kind casting)
19+
dtype = result_type(*arrays)
1720
arrays = tuple(a._array for a in arrays)
18-
# Call result type here just to raise on disallowed type combinations
19-
result_type(*arrays)
20-
return Array._new(np.concatenate(arrays, axis=axis))
21+
return Array._new(np.concatenate(arrays, axis=axis, dtype=dtype))
2122

2223
def expand_dims(x: Array, /, *, axis: int) -> Array:
2324
"""
@@ -65,7 +66,7 @@ def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) ->
6566
6667
See its docstring for more information.
6768
"""
68-
arrays = tuple(a._array for a in arrays)
6969
# Call result type here just to raise on disallowed type combinations
7070
result_type(*arrays)
71+
arrays = tuple(a._array for a in arrays)
7172
return Array._new(np.stack(arrays, axis=axis))

0 commit comments

Comments
 (0)