Skip to content

Commit b6433ea

Browse files
mzeitlin11AlexeyGy
authored andcommitted
CLN/PERF: group quantile (pandas-dev#43510)
1 parent 01975bb commit b6433ea

File tree

3 files changed

+32
-21
lines changed

3 files changed

+32
-21
lines changed

pandas/_libs/groupby.pyi

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ from typing import Literal
22

33
import numpy as np
44

5+
from pandas._typing import npt
6+
57
def group_median_float64(
68
out: np.ndarray, # ndarray[float64_t, ndim=2]
79
counts: np.ndarray, # ndarray[int64_t]
@@ -87,11 +89,12 @@ def group_ohlc(
8789
min_count: int = ...,
8890
) -> None: ...
8991
def group_quantile(
90-
out: np.ndarray, # ndarray[float64_t, ndim=2]
92+
out: npt.NDArray[np.float64],
9193
values: np.ndarray, # ndarray[numeric, ndim=1]
92-
labels: np.ndarray, # ndarray[int64_t]
93-
mask: np.ndarray, # ndarray[uint8_t]
94-
qs: np.ndarray, # const float64_t[:]
94+
labels: npt.NDArray[np.intp],
95+
mask: npt.NDArray[np.uint8],
96+
sort_indexer: npt.NDArray[np.intp], # const
97+
qs: npt.NDArray[np.float64], # const
9598
interpolation: Literal["linear", "lower", "higher", "nearest", "midpoint"],
9699
) -> None: ...
97100
def group_last(

pandas/_libs/groupby.pyx

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -810,6 +810,7 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
810810
ndarray[numeric, ndim=1] values,
811811
ndarray[intp_t] labels,
812812
ndarray[uint8_t] mask,
813+
const intp_t[:] sort_indexer,
813814
const float64_t[:] qs,
814815
str interpolation) -> None:
815816
"""
@@ -823,6 +824,8 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
823824
Array containing the values to apply the function against.
824825
labels : ndarray[np.intp]
825826
Array containing the unique group labels.
827+
sort_indexer : ndarray[np.intp]
828+
Indices describing sort order by values and labels.
826829
qs : ndarray[float64_t]
827830
The quantile values to search for.
828831
interpolation : {'linear', 'lower', 'highest', 'nearest', 'midpoint'}
@@ -836,9 +839,9 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
836839
Py_ssize_t i, N=len(labels), ngroups, grp_sz, non_na_sz, k, nqs
837840
Py_ssize_t grp_start=0, idx=0
838841
intp_t lab
839-
uint8_t interp
842+
InterpolationEnumType interp
840843
float64_t q_val, q_idx, frac, val, next_val
841-
ndarray[int64_t] counts, non_na_counts, sort_arr
844+
int64_t[::1] counts, non_na_counts
842845

843846
assert values.shape[0] == N
844847

@@ -873,16 +876,6 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
873876
if not mask[i]:
874877
non_na_counts[lab] += 1
875878

876-
# Get an index of values sorted by labels and then values
877-
if labels.any():
878-
# Put '-1' (NaN) labels as the last group so it does not interfere
879-
# with the calculations.
880-
labels_for_lexsort = np.where(labels == -1, labels.max() + 1, labels)
881-
else:
882-
labels_for_lexsort = labels
883-
order = (values, labels_for_lexsort)
884-
sort_arr = np.lexsort(order).astype(np.int64, copy=False)
885-
886879
with nogil:
887880
for i in range(ngroups):
888881
# Figure out how many group elements there are
@@ -900,7 +893,7 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
900893
# Casting to int will intentionally truncate result
901894
idx = grp_start + <int64_t>(q_val * <float64_t>(non_na_sz - 1))
902895

903-
val = values[sort_arr[idx]]
896+
val = values[sort_indexer[idx]]
904897
# If requested quantile falls evenly on a particular index
905898
# then write that index's value out. Otherwise interpolate
906899
q_idx = q_val * (non_na_sz - 1)
@@ -909,7 +902,7 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
909902
if frac == 0.0 or interp == INTERPOLATION_LOWER:
910903
out[i, k] = val
911904
else:
912-
next_val = values[sort_arr[idx + 1]]
905+
next_val = values[sort_indexer[idx + 1]]
913906
if interp == INTERPOLATION_LINEAR:
914907
out[i, k] = val + (next_val - val) * frac
915908
elif interp == INTERPOLATION_HIGHER:

pandas/core/groupby/groupby.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2648,24 +2648,39 @@ def post_processor(vals: np.ndarray, inference: np.dtype | None) -> np.ndarray:
26482648
libgroupby.group_quantile, labels=ids, qs=qs, interpolation=interpolation
26492649
)
26502650

2651+
# Put '-1' (NaN) labels as the last group so it does not interfere
2652+
# with the calculations. Note: length check avoids failure on empty
2653+
# labels. In that case, the value doesn't matter
2654+
na_label_for_sorting = ids.max() + 1 if len(ids) > 0 else 0
2655+
labels_for_lexsort = np.where(ids == -1, na_label_for_sorting, ids)
2656+
26512657
def blk_func(values: ArrayLike) -> ArrayLike:
26522658
mask = isna(values)
26532659
vals, inference = pre_processor(values)
26542660

26552661
ncols = 1
26562662
if vals.ndim == 2:
26572663
ncols = vals.shape[0]
2664+
shaped_labels = np.broadcast_to(
2665+
labels_for_lexsort, (ncols, len(labels_for_lexsort))
2666+
)
2667+
else:
2668+
shaped_labels = labels_for_lexsort
26582669

26592670
out = np.empty((ncols, ngroups, nqs), dtype=np.float64)
26602671

2672+
# Get an index of values sorted by values and then labels
2673+
order = (vals, shaped_labels)
2674+
sort_arr = np.lexsort(order).astype(np.intp, copy=False)
2675+
26612676
if vals.ndim == 1:
2662-
func(out[0], values=vals, mask=mask)
2677+
func(out[0], values=vals, mask=mask, sort_indexer=sort_arr)
26632678
else:
26642679
for i in range(ncols):
2665-
func(out[i], values=vals[i], mask=mask[i])
2680+
func(out[i], values=vals[i], mask=mask[i], sort_indexer=sort_arr[i])
26662681

26672682
if vals.ndim == 1:
2668-
out = out[0].ravel("K")
2683+
out = out.ravel("K")
26692684
else:
26702685
out = out.reshape(ncols, ngroups * nqs)
26712686
return post_processor(out, inference)

0 commit comments

Comments
 (0)