Skip to content

Commit e5ac01c

Browse files
mzeitlin11 authored and AlexeyGy committed
CLN/PERF: group quantile (pandas-dev#43510)
1 parent d893131 commit e5ac01c

File tree

3 files changed

+32
-21
lines changed

3 files changed

+32
-21
lines changed

pandas/_libs/groupby.pyi

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,8 @@ from typing import Literal
22

33
import numpy as np
44

5+
from pandas._typing import npt
6+
57
def group_median_float64(
68
out: np.ndarray, # ndarray[float64_t, ndim=2]
79
counts: np.ndarray, # ndarray[int64_t]
@@ -85,11 +87,12 @@ def group_ohlc(
8587
min_count: int = ...,
8688
) -> None: ...
8789
def group_quantile(
88-
out: np.ndarray, # ndarray[float64_t, ndim=2]
90+
out: npt.NDArray[np.float64],
8991
values: np.ndarray, # ndarray[numeric, ndim=1]
90-
labels: np.ndarray, # ndarray[int64_t]
91-
mask: np.ndarray, # ndarray[uint8_t]
92-
qs: np.ndarray, # const float64_t[:]
92+
labels: npt.NDArray[np.intp],
93+
mask: npt.NDArray[np.uint8],
94+
sort_indexer: npt.NDArray[np.intp], # const
95+
qs: npt.NDArray[np.float64], # const
9396
interpolation: Literal["linear", "lower", "higher", "nearest", "midpoint"],
9497
) -> None: ...
9598
def group_last(

pandas/_libs/groupby.pyx

Lines changed: 7 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -803,6 +803,7 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
803803
ndarray[numeric, ndim=1] values,
804804
ndarray[intp_t] labels,
805805
ndarray[uint8_t] mask,
806+
const intp_t[:] sort_indexer,
806807
const float64_t[:] qs,
807808
str interpolation) -> None:
808809
"""
@@ -816,6 +817,8 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
816817
Array containing the values to apply the function against.
817818
labels : ndarray[np.intp]
818819
Array containing the group label of each element of ``values``.
820+
sort_indexer : ndarray[np.intp]
821+
Indices that sort the data by labels first and then by values.
819822
qs : ndarray[float64_t]
820823
The quantile values to search for.
821824
interpolation : {'linear', 'lower', 'higher', 'nearest', 'midpoint'}
@@ -829,9 +832,9 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
829832
Py_ssize_t i, N=len(labels), ngroups, grp_sz, non_na_sz, k, nqs
830833
Py_ssize_t grp_start=0, idx=0
831834
intp_t lab
832-
uint8_t interp
835+
InterpolationEnumType interp
833836
float64_t q_val, q_idx, frac, val, next_val
834-
ndarray[int64_t] counts, non_na_counts, sort_arr
837+
int64_t[::1] counts, non_na_counts
835838

836839
assert values.shape[0] == N
837840

@@ -866,16 +869,6 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
866869
if not mask[i]:
867870
non_na_counts[lab] += 1
868871

869-
# Get an index of values sorted by labels and then values
870-
if labels.any():
871-
# Put '-1' (NaN) labels as the last group so it does not interfere
872-
# with the calculations.
873-
labels_for_lexsort = np.where(labels == -1, labels.max() + 1, labels)
874-
else:
875-
labels_for_lexsort = labels
876-
order = (values, labels_for_lexsort)
877-
sort_arr = np.lexsort(order).astype(np.int64, copy=False)
878-
879872
with nogil:
880873
for i in range(ngroups):
881874
# Figure out how many group elements there are
@@ -893,7 +886,7 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
893886
# Casting to int will intentionally truncate result
894887
idx = grp_start + <int64_t>(q_val * <float64_t>(non_na_sz - 1))
895888

896-
val = values[sort_arr[idx]]
889+
val = values[sort_indexer[idx]]
897890
# If requested quantile falls evenly on a particular index
898891
# then write that index's value out. Otherwise interpolate
899892
q_idx = q_val * (non_na_sz - 1)
@@ -902,7 +895,7 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
902895
if frac == 0.0 or interp == INTERPOLATION_LOWER:
903896
out[i, k] = val
904897
else:
905-
next_val = values[sort_arr[idx + 1]]
898+
next_val = values[sort_indexer[idx + 1]]
906899
if interp == INTERPOLATION_LINEAR:
907900
out[i, k] = val + (next_val - val) * frac
908901
elif interp == INTERPOLATION_HIGHER:

pandas/core/groupby/groupby.py

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2648,24 +2648,39 @@ def post_processor(vals: np.ndarray, inference: np.dtype | None) -> np.ndarray:
26482648
libgroupby.group_quantile, labels=ids, qs=qs, interpolation=interpolation
26492649
)
26502650

2651+
# Put '-1' (NaN) labels as the last group so it does not interfere
2652+
# with the calculations. Note: length check avoids failure on empty
2653+
# labels. In that case, the value doesn't matter
2654+
na_label_for_sorting = ids.max() + 1 if len(ids) > 0 else 0
2655+
labels_for_lexsort = np.where(ids == -1, na_label_for_sorting, ids)
2656+
26512657
def blk_func(values: ArrayLike) -> ArrayLike:
26522658
mask = isna(values)
26532659
vals, inference = pre_processor(values)
26542660

26552661
ncols = 1
26562662
if vals.ndim == 2:
26572663
ncols = vals.shape[0]
2664+
shaped_labels = np.broadcast_to(
2665+
labels_for_lexsort, (ncols, len(labels_for_lexsort))
2666+
)
2667+
else:
2668+
shaped_labels = labels_for_lexsort
26582669

26592670
out = np.empty((ncols, ngroups, nqs), dtype=np.float64)
26602671

2672+
# Get an index of values sorted by labels and then values
# (np.lexsort uses the last key as the primary sort key)
2673+
order = (vals, shaped_labels)
2674+
sort_arr = np.lexsort(order).astype(np.intp, copy=False)
2675+
26612676
if vals.ndim == 1:
2662-
func(out[0], values=vals, mask=mask)
2677+
func(out[0], values=vals, mask=mask, sort_indexer=sort_arr)
26632678
else:
26642679
for i in range(ncols):
2665-
func(out[i], values=vals[i], mask=mask[i])
2680+
func(out[i], values=vals[i], mask=mask[i], sort_indexer=sort_arr[i])
26662681

26672682
if vals.ndim == 1:
2668-
out = out[0].ravel("K")
2683+
out = out.ravel("K")
26692684
else:
26702685
out = out.reshape(ncols, ngroups * nqs)
26712686
return post_processor(out, inference)

0 commit comments

Comments
 (0)