Skip to content

Commit 6ca949d

Browse files
committed
MAINT: quantile/percentile/median
Deviate from the standard argument handling: - percentile and median delegate to quantile at the wrapper level; this is just less code - keepdims=True handling is inline in _reductions.py::quantile, not in a decorator. The standard decorator expects the axis as the second argument, while here it is the third one. Can be fixed, but seems to be more hassle then worth TBH.
1 parent 0b36f2d commit 6ca949d

File tree

3 files changed

+32
-25
lines changed

3 files changed

+32
-25
lines changed

torch_np/_detail/_reductions.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def average(a_tensor, axis, w_tensor):
337337
return result, denominator
338338

339339

340-
def quantile(a_tensor, q_tensor, axis, method):
340+
def quantile(a_tensor, q_tensor, axis, method, keepdims=False):
341341

342342
if (0 > q_tensor).any() or (q_tensor > 1).any():
343343
raise ValueError("Quantiles must be in range [0, 1], got %s" % q_tensor)
@@ -350,15 +350,24 @@ def quantile(a_tensor, q_tensor, axis, method):
350350
if a_tensor.dtype == torch.float16:
351351
a_tensor = a_tensor.to(torch.float32)
352352

353+
# TODO: consider moving this normalize_axis_tuple dance to normalize axis? Across the board if at all.
353354
# axis
354355
if axis is not None:
355356
axis = _util.normalize_axis_tuple(axis, a_tensor.ndim)
356357
axis = _util.allow_only_single_axis(axis)
357358

358359
q_tensor = _util.cast_if_needed(q_tensor, a_tensor.dtype)
359360

361+
362+
# axis=None ravels, so store the originals to reuse with keepdims=True below
363+
ax, ndim = axis, a_tensor.ndim
360364
(a_tensor, q_tensor), axis = _util.axis_none_ravel(a_tensor, q_tensor, axis=axis)
361365

362366
result = torch.quantile(a_tensor, q_tensor, axis=axis, interpolation=method)
363367

368+
# NB: not using @emulate_keepdims here because the signature is (a, q, axis, ...)
369+
# while the decorator expects (a, axis, ...)
370+
# this can be fixed, of course, but the cure seems worse then the desease
371+
if keepdims:
372+
result = _util.apply_keepdims(result, ax, ndim)
364373
return result

torch_np/_funcs.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,3 +459,22 @@ def cumsum(a: ArrayLike, axis: AxisLike = None, dtype: DTypeLike = None, out=Non
459459
def cumprod(a: ArrayLike, axis: AxisLike = None, dtype: DTypeLike = None, out=None):
460460
result = _reductions.cumprod(a, axis=axis, dtype=dtype)
461461
return _helpers.result_or_out(result, out)
462+
463+
464+
@normalizer
465+
def quantile(
466+
a : ArrayLike,
467+
q : ArrayLike,
468+
axis: AxisLike=None,
469+
out=None,
470+
overwrite_input=False,
471+
method="linear",
472+
keepdims=False,
473+
*,
474+
interpolation=None,
475+
):
476+
if interpolation is not None:
477+
raise ValueError("'interpolation' argument is deprecated; use 'method' instead")
478+
479+
result = _reductions.quantile(a, q, axis, method=method, keepdims=keepdims)
480+
return _helpers.result_or_out(result, out, promote_scalar=True)

torch_np/_wrapper.py

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,7 @@ def average(a, axis=None, weights=None, returned=False, *, keepdims=NoValue):
674674
return _helpers.array_from(result)
675675

676676

677+
# Normalizations (ArrayLike et al) are done in quantile.
677678
def percentile(
678679
a,
679680
q,
@@ -685,36 +686,14 @@ def percentile(
685686
*,
686687
interpolation=None,
687688
):
688-
return quantile(
689+
return _funcs.quantile(
689690
a, asarray(q) / 100.0, axis, out, overwrite_input, method, keepdims=keepdims
690691
)
691692

692693

693-
def quantile(
694-
a,
695-
q,
696-
axis=None,
697-
out=None,
698-
overwrite_input=False,
699-
method="linear",
700-
keepdims=False,
701-
*,
702-
interpolation=None,
703-
):
704-
if interpolation is not None:
705-
raise ValueError("'interpolation' argument is deprecated; use 'method' instead")
706-
707-
a_tensor, q_tensor = _helpers.to_tensors(a, q)
708-
result = _reductions.quantile(a_tensor, q_tensor, axis, method)
709-
710-
# keepdims
711-
if keepdims:
712-
result = _util.apply_keepdims(result, axis, a_tensor.ndim)
713-
return _helpers.result_or_out(result, out, promote_scalar=True)
714-
715694

716695
def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
717-
return quantile(
696+
return _funcs.quantile(
718697
a, 0.5, axis=axis, overwrite_input=overwrite_input, out=out, keepdims=keepdims
719698
)
720699

0 commit comments

Comments
 (0)