Skip to content

Commit d77a082

Browse files
committed
MAINT: correct tuple/Optional annotation, fix up rebase
1 parent 8fbb43f commit d77a082

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

torch_np/_binary_ufuncs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def divmod(
8686
out1: Optional[NDArray] = None,
8787
out2: Optional[NDArray] = None,
8888
/,
89-
out: Optional[tuple[NDArray]] = (None, None),
89+
out: tuple[Optional[NDArray], Optional[NDArray]] = (None, None),
9090
*,
9191
where=True,
9292
casting="same_kind",

torch_np/_funcs.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,10 @@ def concatenate(
111111
out: Optional[NDArray] = None,
112112
dtype: DTypeLike = None,
113113
casting="same_kind",
114-
) -> OutArray:
114+
):
115115
_concat_check(ar_tuple, dtype, out=out)
116116
result = _concatenate(ar_tuple, axis=axis, out=out, dtype=dtype, casting=casting)
117-
return result, out
117+
return result
118118

119119

120120
@normalizer
@@ -163,7 +163,7 @@ def stack(
163163
*,
164164
dtype: DTypeLike = None,
165165
casting="same_kind",
166-
) -> OutArray:
166+
):
167167
_concat_check(arrays, dtype, out=out)
168168

169169
tensors = _concat_cast_helper(arrays, dtype=dtype, casting=casting)
@@ -173,7 +173,7 @@ def stack(
173173
result = torch.stack(tensors, axis=axis)
174174
except RuntimeError as e:
175175
raise ValueError(*e.args)
176-
return result, out
176+
return result
177177

178178

179179
# ### split ###
@@ -1013,7 +1013,7 @@ def clip(
10131013
# one of them to be None. Follow the more lax version.
10141014
if min is None and max is None:
10151015
raise ValueError("One of max or min must be given")
1016-
result = torch.clamp(min, max)
1016+
result = torch.clamp(a, min, max)
10171017
return result
10181018

10191019

@@ -1194,6 +1194,12 @@ def inner(a: ArrayLike, b: ArrayLike, /):
11941194
result = result.to(torch.bool)
11951195
return result
11961196

1197+
1198+
@normalizer
1199+
def outer(a: ArrayLike, b: ArrayLike, out: Optional[NDArray] = None):
1200+
return torch.outer(a, b)
1201+
1202+
11971203
# ### sort and partition ###
11981204

11991205

@@ -1371,7 +1377,7 @@ def imag(a: ArrayLike):
13711377

13721378

13731379
@normalizer
1374-
def round_(a: ArrayLike, decimals=0, out: Optional[NDArray] = None) -> OutArray:
1380+
def round_(a: ArrayLike, decimals=0, out: Optional[NDArray] = None):
13751381
if a.is_floating_point():
13761382
result = torch.round(a, decimals=decimals)
13771383
elif a.is_complex():
@@ -1730,7 +1736,7 @@ def imag(a: ArrayLike):
17301736

17311737

17321738
@normalizer
1733-
def round_(a: ArrayLike, decimals=0, out: Optional[NDArray] = None) -> OutArray:
1739+
def round_(a: ArrayLike, decimals=0, out: Optional[NDArray] = None):
17341740
if a.is_floating_point():
17351741
result = torch.round(a, decimals=decimals)
17361742
elif a.is_complex():
@@ -1742,7 +1748,7 @@ def round_(a: ArrayLike, decimals=0, out: Optional[NDArray] = None) -> OutArray:
17421748
else:
17431749
# RuntimeError: "round_cpu" not implemented for 'int'
17441750
result = a
1745-
return result, out
1751+
return result
17461752

17471753

17481754
around = round_

0 commit comments

Comments (0)