Skip to content

Commit eae5033

Browse files
committed
MAINT: postprocess out= returns via a return annotation
Other returns are wrapped automagically, based on the return type
1 parent 91387f6 commit eae5033

File tree

6 files changed

+102
-91
lines changed

6 files changed

+102
-91
lines changed

torch_np/_binary_ufuncs.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from . import _helpers
66
from ._detail import _binary_ufuncs
7-
from ._normalizations import ArrayLike, DTypeLike, NDArray, SubokLike, normalizer
7+
from ._normalizations import ArrayLike, DTypeLike, NDArray, OutArray, SubokLike, normalizer
88

99
__all__ = [
1010
name
@@ -33,7 +33,7 @@ def wrapped(
3333
subok: SubokLike = False,
3434
signature=None,
3535
extobj=None,
36-
):
36+
) -> OutArray:
3737
tensors = _helpers.ufunc_preprocess(
3838
(x1, x2), out, where, casting, order, dtype, subok, signature, extobj
3939
)
@@ -44,7 +44,7 @@ def wrapped(
4444
tensors = tuple(torch.broadcast_to(t, shape) for t in tensors)
4545

4646
result = torch_func(*tensors)
47-
return _helpers.result_or_out(result, out)
47+
return result, out
4848

4949
return wrapped
5050

@@ -68,7 +68,7 @@ def matmul(
6868
extobj=None,
6969
axes=None,
7070
axis=None,
71-
):
71+
) -> OutArray:
7272
tensors = _helpers.ufunc_preprocess(
7373
(x1, x2), out, True, casting, order, dtype, subok, signature, extobj
7474
)
@@ -77,7 +77,7 @@ def matmul(
7777

7878
# NB: do not broadcast input tensors against the out=... array
7979
result = _binary_ufuncs.matmul(*tensors)
80-
return _helpers.result_or_out(result, out)
80+
return result, out
8181

8282

8383
#

torch_np/_decorators.py

Lines changed: 0 additions & 24 deletions
This file was deleted.

torch_np/_funcs.py

Lines changed: 42 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
AxisLike,
1111
DTypeLike,
1212
NDArray,
13+
OutArray,
1314
SubokLike,
1415
normalizer,
1516
)
@@ -41,11 +42,11 @@ def clip(
4142
min: Optional[ArrayLike] = None,
4243
max: Optional[ArrayLike] = None,
4344
out: Optional[NDArray] = None,
44-
):
45+
) -> OutArray:
4546
# np.clip requires both a_min and a_max not None, while ndarray.clip allows
4647
# one of them to be None. Follow the more lax version.
4748
result = _impl.clip(a, min, max)
48-
return _helpers.result_or_out(result, out)
49+
return result, out
4950

5051

5152
@normalizer
@@ -78,9 +79,9 @@ def trace(
7879
axis2=1,
7980
dtype: DTypeLike = None,
8081
out: Optional[NDArray] = None,
81-
):
82+
) -> OutArray:
8283
result = _impl.trace(a, offset, axis1, axis2, dtype)
83-
return _helpers.result_or_out(result, out)
84+
return result, out
8485

8586

8687
@normalizer
@@ -133,9 +134,9 @@ def vdot(a: ArrayLike, b: ArrayLike, /):
133134

134135

135136
@normalizer
136-
def dot(a: ArrayLike, b: ArrayLike, out: Optional[NDArray] = None):
137+
def dot(a: ArrayLike, b: ArrayLike, out: Optional[NDArray] = None) -> OutArray:
137138
result = _impl.dot(a, b)
138-
return _helpers.result_or_out(result, out)
139+
return result, out
139140

140141

141142
# ### sort and partition ###
@@ -232,9 +233,9 @@ def imag(a: ArrayLike):
232233

233234

234235
@normalizer
235-
def round_(a: ArrayLike, decimals=0, out: Optional[NDArray] = None):
236+
def round_(a: ArrayLike, decimals=0, out: Optional[NDArray] = None) -> OutArray:
236237
result = _impl.round(a, decimals)
237-
return _helpers.result_or_out(result, out)
238+
return result, out
238239

239240

240241
around = round_
@@ -253,11 +254,11 @@ def sum(
253254
keepdims=NoValue,
254255
initial=NoValue,
255256
where=NoValue,
256-
):
257+
) -> OutArray:
257258
result = _impl.sum(
258259
a, axis=axis, dtype=dtype, initial=initial, where=where, keepdims=keepdims
259260
)
260-
return _helpers.result_or_out(result, out)
261+
return result, out
261262

262263

263264
@normalizer
@@ -269,11 +270,11 @@ def prod(
269270
keepdims=NoValue,
270271
initial=NoValue,
271272
where=NoValue,
272-
):
273+
) -> OutArray:
273274
result = _impl.prod(
274275
a, axis=axis, dtype=dtype, initial=initial, where=where, keepdims=keepdims
275276
)
276-
return _helpers.result_or_out(result, out)
277+
return result, out
277278

278279

279280
product = prod
@@ -288,9 +289,9 @@ def mean(
288289
keepdims=NoValue,
289290
*,
290291
where=NoValue,
291-
):
292+
) -> OutArray:
292293
result = _impl.mean(a, axis=axis, dtype=dtype, where=NoValue, keepdims=keepdims)
293-
return _helpers.result_or_out(result, out)
294+
return result, out
294295

295296

296297
@normalizer
@@ -303,11 +304,11 @@ def var(
303304
keepdims=NoValue,
304305
*,
305306
where=NoValue,
306-
):
307+
) -> OutArray:
307308
result = _impl.var(
308309
a, axis=axis, dtype=dtype, ddof=ddof, where=where, keepdims=keepdims
309310
)
310-
return _helpers.result_or_out(result, out)
311+
return result, out
311312

312313

313314
@normalizer
@@ -320,11 +321,11 @@ def std(
320321
keepdims=NoValue,
321322
*,
322323
where=NoValue,
323-
):
324+
) -> OutArray:
324325
result = _impl.std(
325326
a, axis=axis, dtype=dtype, ddof=ddof, where=where, keepdims=keepdims
326327
)
327-
return _helpers.result_or_out(result, out)
328+
return result, out
328329

329330

330331
@normalizer
@@ -334,9 +335,9 @@ def argmin(
334335
out: Optional[NDArray] = None,
335336
*,
336337
keepdims=NoValue,
337-
):
338+
) -> OutArray:
338339
result = _impl.argmin(a, axis=axis, keepdims=keepdims)
339-
return _helpers.result_or_out(result, out)
340+
return result, out
340341

341342

342343
@normalizer
@@ -346,9 +347,9 @@ def argmax(
346347
out: Optional[NDArray] = None,
347348
*,
348349
keepdims=NoValue,
349-
):
350+
) -> OutArray:
350351
result = _impl.argmax(a, axis=axis, keepdims=keepdims)
351-
return _helpers.result_or_out(result, out)
352+
return result, out
352353

353354

354355
@normalizer
@@ -359,9 +360,9 @@ def amax(
359360
keepdims=NoValue,
360361
initial=NoValue,
361362
where=NoValue,
362-
):
363+
) -> OutArray:
363364
result = _impl.max(a, axis=axis, initial=initial, where=where, keepdims=keepdims)
364-
return _helpers.result_or_out(result, out)
365+
return result, out
365366

366367

367368
max = amax
@@ -375,9 +376,9 @@ def amin(
375376
keepdims=NoValue,
376377
initial=NoValue,
377378
where=NoValue,
378-
):
379+
) -> OutArray:
379380
result = _impl.min(a, axis=axis, initial=initial, where=where, keepdims=keepdims)
380-
return _helpers.result_or_out(result, out)
381+
return result, out
381382

382383

383384
min = amin
@@ -386,9 +387,9 @@ def amin(
386387
@normalizer
387388
def ptp(
388389
a: ArrayLike, axis: AxisLike = None, out: Optional[NDArray] = None, keepdims=NoValue
389-
):
390+
) -> OutArray:
390391
result = _impl.ptp(a, axis=axis, keepdims=keepdims)
391-
return _helpers.result_or_out(result, out)
392+
return result, out
392393

393394

394395
@normalizer
@@ -399,9 +400,9 @@ def all(
399400
keepdims=NoValue,
400401
*,
401402
where=NoValue,
402-
):
403+
) -> OutArray:
403404
result = _impl.all(a, axis=axis, where=where, keepdims=keepdims)
404-
return _helpers.result_or_out(result, out)
405+
return result, out
405406

406407

407408
@normalizer
@@ -412,9 +413,9 @@ def any(
412413
keepdims=NoValue,
413414
*,
414415
where=NoValue,
415-
):
416+
) -> OutArray:
416417
result = _impl.any(a, axis=axis, where=where, keepdims=keepdims)
417-
return _helpers.result_or_out(result, out)
418+
return result, out
418419

419420

420421
@normalizer
@@ -429,9 +430,9 @@ def cumsum(
429430
axis: AxisLike = None,
430431
dtype: DTypeLike = None,
431432
out: Optional[NDArray] = None,
432-
):
433+
) -> OutArray:
433434
result = _impl.cumsum(a, axis=axis, dtype=dtype)
434-
return _helpers.result_or_out(result, out)
435+
return result, out
435436

436437

437438
@normalizer
@@ -440,9 +441,9 @@ def cumprod(
440441
axis: AxisLike = None,
441442
dtype: DTypeLike = None,
442443
out: Optional[NDArray] = None,
443-
):
444+
) -> OutArray:
444445
result = _impl.cumprod(a, axis=axis, dtype=dtype)
445-
return _helpers.result_or_out(result, out)
446+
return result, out
446447

447448

448449
cumproduct = cumprod
@@ -459,7 +460,7 @@ def quantile(
459460
keepdims=False,
460461
*,
461462
interpolation=None,
462-
):
463+
) -> OutArray:
463464
result = _impl.quantile(
464465
a,
465466
q,
@@ -469,10 +470,10 @@ def quantile(
469470
keepdims=keepdims,
470471
interpolation=interpolation,
471472
)
472-
return _helpers.result_or_out(result, out, promote_scalar=True)
473+
return result, out
473474

474475

475-
@normalizer
476+
@normalizer(promote_scalar_result=True)
476477
def percentile(
477478
a: ArrayLike,
478479
q: ArrayLike,
@@ -483,7 +484,7 @@ def percentile(
483484
keepdims=False,
484485
*,
485486
interpolation=None,
486-
):
487+
) -> OutArray:
487488
result = _impl.percentile(
488489
a,
489490
q,
@@ -493,7 +494,7 @@ def percentile(
493494
keepdims=keepdims,
494495
interpolation=interpolation,
495496
)
496-
return _helpers.result_or_out(result, out, promote_scalar=True)
497+
return result, out
497498

498499

499500
def median(

0 commit comments

Comments
 (0)