@@ -803,6 +803,7 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
803
803
ndarray[numeric , ndim = 1 ] values,
804
804
ndarray[intp_t] labels ,
805
805
ndarray[uint8_t] mask ,
806
+ const intp_t[:] sort_indexer ,
806
807
const float64_t[:] qs ,
807
808
str interpolation ) -> None:
808
809
"""
@@ -816,6 +817,8 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
816
817
Array containing the values to apply the function against.
817
818
labels : ndarray[np.intp]
818
819
Array containing the unique group labels.
820
+ sort_indexer : ndarray[np.intp]
821
+ Indices describing sort order by values and labels.
819
822
qs : ndarray[float64_t]
820
823
The quantile values to search for.
821
824
interpolation : {'linear', 'lower', 'highest', 'nearest', 'midpoint'}
@@ -829,9 +832,9 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
829
832
Py_ssize_t i , N = len (labels), ngroups , grp_sz , non_na_sz , k , nqs
830
833
Py_ssize_t grp_start = 0 , idx = 0
831
834
intp_t lab
832
- uint8_t interp
835
+ InterpolationEnumType interp
833
836
float64_t q_val , q_idx , frac , val , next_val
834
- ndarray[ int64_t] counts , non_na_counts , sort_arr
837
+ int64_t[::1 ] counts , non_na_counts
835
838
836
839
assert values.shape[0] == N
837
840
@@ -866,16 +869,6 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
866
869
if not mask[i]:
867
870
non_na_counts[lab] += 1
868
871
869
- # Get an index of values sorted by labels and then values
870
- if labels.any():
871
- # Put '-1' (NaN) labels as the last group so it does not interfere
872
- # with the calculations.
873
- labels_for_lexsort = np.where(labels == - 1 , labels.max() + 1 , labels)
874
- else :
875
- labels_for_lexsort = labels
876
- order = (values, labels_for_lexsort)
877
- sort_arr = np.lexsort(order).astype(np.int64, copy = False )
878
-
879
872
with nogil:
880
873
for i in range (ngroups):
881
874
# Figure out how many group elements there are
@@ -893,7 +886,7 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
893
886
# Casting to int will intentionally truncate result
894
887
idx = grp_start + < int64_t> (q_val * < float64_t> (non_na_sz - 1 ))
895
888
896
- val = values[sort_arr [idx]]
889
+ val = values[sort_indexer [idx]]
897
890
# If requested quantile falls evenly on a particular index
898
891
# then write that index's value out. Otherwise interpolate
899
892
q_idx = q_val * (non_na_sz - 1 )
@@ -902,7 +895,7 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
902
895
if frac == 0.0 or interp == INTERPOLATION_LOWER:
903
896
out[i, k] = val
904
897
else :
905
- next_val = values[sort_arr [idx + 1 ]]
898
+ next_val = values[sort_indexer [idx + 1 ]]
906
899
if interp == INTERPOLATION_LINEAR:
907
900
out[i, k] = val + (next_val - val) * frac
908
901
elif interp == INTERPOLATION_HIGHER:
0 commit comments