Skip to content

Commit 318ca62

Browse files
committed
MAINT: simplify/normalize average
1 parent 6ca949d commit 318ca62

File tree

3 files changed

+37
-40
lines changed

3 files changed

+37
-40
lines changed

torch_np/_detail/_reductions.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,26 @@ def cumsum(tensor, axis, dtype=None):
292292
return result
293293

294294

295-
def average(a_tensor, axis, w_tensor):
295+
296+
def average(a, axis, weights, returned=False, keepdims=False):
297+
if weights is None:
298+
result, wsum = average_noweights(a, axis, keepdims=keepdims)
299+
else:
300+
result, wsum = average_weights(a, axis, weights, keepdims=keepdims)
301+
302+
if returned:
303+
if wsum.shape != result.shape:
304+
wsum = torch.broadcast_to(wsum, result.shape).clone()
305+
return result, wsum
306+
307+
308+
def average_noweights(a_tensor, axis, keepdims=False):
309+
result = mean(a_tensor, axis=axis, keepdims=keepdims)
310+
scl = torch.as_tensor(a_tensor.numel() / result.numel(), dtype=result.dtype)
311+
return result, scl
312+
313+
314+
def average_weights(a_tensor, axis, w_tensor, keepdims=False):
296315

297316
# dtype
298317
# FIXME: 1. use result_type
@@ -306,6 +325,9 @@ def average(a_tensor, axis, w_tensor):
306325
a_tensor = _util.cast_if_needed(a_tensor, result_dtype)
307326
w_tensor = _util.cast_if_needed(w_tensor, result_dtype)
308327

328+
# axis=None ravels, so store the originals to reuse with keepdims=True below
329+
ax, ndim = axis, a_tensor.ndim
330+
309331
# axis
310332
if axis is None:
311333
(a_tensor, w_tensor), axis = _util.axis_none_ravel(
@@ -334,6 +356,10 @@ def average(a_tensor, axis, w_tensor):
334356
denominator = w_tensor.sum(axis)
335357
result = numerator / denominator
336358

359+
# keepdims
360+
if keepdims:
361+
result = _util.apply_keepdims(result, ax, ndim)
362+
337363
return result, denominator
338364

339365

torch_np/_funcs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,9 @@ def cumprod(a: ArrayLike, axis: AxisLike = None, dtype: DTypeLike = None, out=No
461461
return _helpers.result_or_out(result, out)
462462

463463

464+
cumproduct = cumprod
465+
466+
464467
@normalizer
465468
def quantile(
466469
a : ArrayLike,

torch_np/_wrapper.py

Lines changed: 7 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -631,50 +631,16 @@ def tri(N, M=None, k=0, dtype: DTypeLike = float, *, like: SubokLike = None):
631631

632632
###### reductions
633633

634-
635-
def cumprod(a, axis=None, dtype=None, out=None):
636-
arr = asarray(a)
637-
return arr.cumprod(axis=axis, dtype=dtype, out=out)
638-
639-
640-
cumproduct = cumprod
641-
642-
643-
def cumsum(a, axis=None, dtype=None, out=None):
644-
arr = asarray(a)
645-
return arr.cumsum(axis=axis, dtype=dtype, out=out)
646-
647-
648-
def average(a, axis=None, weights=None, returned=False, *, keepdims=NoValue):
649-
650-
if weights is None:
651-
result = _funcs.mean(a, axis=axis, keepdims=keepdims)
652-
if returned:
653-
scl = result.dtype.type(a.size / result.size)
654-
return result, scl
655-
return result
656-
657-
a_tensor, w_tensor = _helpers.to_tensors(a, weights)
658-
659-
result, wsum = _reductions.average(a_tensor, axis, w_tensor)
660-
661-
# keepdims
662-
if keepdims:
663-
result = _util.apply_keepdims(result, axis, a_tensor.ndim)
664-
665-
# returned
634+
@normalizer
635+
def average(a: ArrayLike, axis=None, weights: ArrayLike=None, returned=False, *, keepdims=NoValue):
636+
result, wsum = _reductions.average(a, axis, weights, returned=returned, keepdims=keepdims)
666637
if returned:
667-
scl = wsum
668-
if scl.shape != result.shape:
669-
scl = torch.broadcast_to(scl, result.shape).clone()
670-
671-
return _helpers.array_from(result), _helpers.array_from(scl)
672-
638+
return _helpers.tuple_arrays_from((result, wsum))
673639
else:
674640
return _helpers.array_from(result)
675641

676642

677-
# Normalizations (ArrayLike et al) are done in quantile.
643+
# Normalizations (ArrayLike et al) in percentile and median are done in `_funcs.py/quantile`.
678644
def percentile(
679645
a,
680646
q,
@@ -710,6 +676,8 @@ def outer(a: ArrayLike, b: ArrayLike, out=None):
710676
return _helpers.result_or_out(result, out)
711677

712678

679+
# ### FIXME: this is a stub
680+
713681
@normalizer
714682
def nanmean(
715683
a: ArrayLike,

0 commit comments

Comments
 (0)