-
-
Notifications
You must be signed in to change notification settings - Fork 18.6k
Add interpolation options to rolling quantile #20497
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 15 commits
f31a448
7afa6b5
a02183a
412eb98
f5fb6cb
4f80369
6d5c77d
dc5e74d
5280f95
2ac734c
f21d21a
f9b1e7e
3a2e431
b986a5d
fd9568a
9cc7a71
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1357,77 +1357,129 @@ cdef _roll_min_max(ndarray[numeric] input, int64_t win, int64_t minp, | |
return output | ||
|
||
|
||
cdef enum InterpolationType: | ||
LINEAR, | ||
LOWER, | ||
HIGHER, | ||
NEAREST, | ||
MIDPOINT | ||
|
||
|
||
interpolation_types = { | ||
'linear': LINEAR, | ||
'lower': LOWER, | ||
'higher': HIGHER, | ||
'nearest': NEAREST, | ||
'midpoint': MIDPOINT, | ||
} | ||
|
||
|
||
def roll_quantile(ndarray[float64_t, cast=True] input, int64_t win, | ||
int64_t minp, object index, object closed, | ||
double quantile): | ||
double quantile, str interpolation): | ||
""" | ||
O(N log(window)) implementation using skip list | ||
""" | ||
cdef: | ||
double val, prev, midpoint | ||
IndexableSkiplist skiplist | ||
double val, prev, midpoint, idx_with_fraction | ||
skiplist_t *skiplist | ||
int64_t nobs = 0, i, j, s, e, N | ||
Py_ssize_t idx | ||
bint is_variable | ||
ndarray[int64_t] start, end | ||
ndarray[double_t] output | ||
double vlow, vhigh | ||
InterpolationType interpolation_type | ||
int ret = 0 | ||
|
||
if quantile <= 0.0 or quantile >= 1.0: | ||
raise ValueError("quantile value {0} not in [0, 1]".format(quantile)) | ||
|
||
try: | ||
interpolation_type = interpolation_types[interpolation] | ||
except KeyError: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is there a test case to cover that this raises the expected error message when passing an invalid argument? If not, can you add one? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Minor nit, but can you place the name of the passed interpolation in single quotes? It helps distinguish it from the rest of the text in the error message (the test will need updating as well). |
||
raise ValueError("Interpolation '{}' is not supported" | ||
.format(interpolation)) | ||
|
||
# we use the Fixed/Variable Indexer here as the | ||
# actual skiplist ops outweigh any window computation costs | ||
start, end, N, win, minp, is_variable = get_window_indexer( | ||
input, win, | ||
minp, index, closed, | ||
use_mock=False) | ||
output = np.empty(N, dtype=float) | ||
skiplist = IndexableSkiplist(win) | ||
|
||
for i in range(0, N): | ||
s = start[i] | ||
e = end[i] | ||
|
||
if i == 0: | ||
|
||
# setup | ||
val = input[i] | ||
if val == val: | ||
nobs += 1 | ||
skiplist.insert(val) | ||
skiplist = skiplist_init(<int>win) | ||
if skiplist == NULL: | ||
raise MemoryError("skiplist_init failed") | ||
|
||
else: | ||
with nogil: | ||
for i in range(0, N): | ||
s = start[i] | ||
e = end[i] | ||
|
||
# calculate deletes | ||
for j in range(start[i - 1], s): | ||
val = input[j] | ||
if val == val: | ||
skiplist.remove(val) | ||
nobs -= 1 | ||
if i == 0: | ||
|
||
# calculate adds | ||
for j in range(end[i - 1], e): | ||
val = input[j] | ||
# setup | ||
val = input[i] | ||
if val == val: | ||
nobs += 1 | ||
skiplist.insert(val) | ||
skiplist_insert(skiplist, val) | ||
|
||
if nobs >= minp: | ||
idx = int(quantile * <double>(nobs - 1)) | ||
else: | ||
|
||
# Single value in skip list | ||
if nobs == 1: | ||
output[i] = skiplist.get(0) | ||
# calculate deletes | ||
for j in range(start[i - 1], s): | ||
val = input[j] | ||
if val == val: | ||
skiplist_remove(skiplist, val) | ||
nobs -= 1 | ||
|
||
# Interpolated quantile | ||
# calculate adds | ||
for j in range(end[i - 1], e): | ||
val = input[j] | ||
if val == val: | ||
nobs += 1 | ||
skiplist_insert(skiplist, val) | ||
|
||
if nobs >= minp: | ||
if nobs == 1: | ||
# Single value in skip list | ||
output[i] = skiplist_get(skiplist, 0, &ret) | ||
else: | ||
idx_with_fraction = quantile * (nobs - 1) | ||
idx = <int> idx_with_fraction | ||
|
||
if idx_with_fraction == idx: | ||
# no need to interpolate | ||
output[i] = skiplist_get(skiplist, idx, &ret) | ||
continue | ||
|
||
if interpolation_type == LINEAR: | ||
vlow = skiplist_get(skiplist, idx, &ret) | ||
vhigh = skiplist_get(skiplist, idx + 1, &ret) | ||
output[i] = ((vlow + (vhigh - vlow) * | ||
(idx_with_fraction - idx))) | ||
elif interpolation_type == LOWER: | ||
output[i] = skiplist_get(skiplist, idx, &ret) | ||
elif interpolation_type == HIGHER: | ||
output[i] = skiplist_get(skiplist, idx + 1, &ret) | ||
elif interpolation_type == NEAREST: | ||
# the same behaviour as round() | ||
if idx_with_fraction - idx == 0.5: | ||
if idx % 2 == 0: | ||
output[i] = skiplist_get(skiplist, idx, &ret) | ||
else: | ||
output[i] = skiplist_get(skiplist, idx + 1, &ret) | ||
elif idx_with_fraction - idx < 0.5: | ||
output[i] = skiplist_get(skiplist, idx, &ret) | ||
else: | ||
output[i] = skiplist_get(skiplist, idx + 1, &ret) | ||
elif interpolation_type == MIDPOINT: | ||
vlow = skiplist_get(skiplist, idx, &ret) | ||
vhigh = skiplist_get(skiplist, idx + 1, &ret) | ||
output[i] = <double> (vlow + vhigh) / 2 | ||
else: | ||
vlow = skiplist.get(idx) | ||
vhigh = skiplist.get(idx + 1) | ||
output[i] = ((vlow + (vhigh - vlow) * | ||
(quantile * (nobs - 1) - idx))) | ||
else: | ||
output[i] = NaN | ||
output[i] = NaN | ||
|
||
return output | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1276,9 +1276,53 @@ def kurt(self, **kwargs): | |
Parameters | ||
---------- | ||
quantile : float | ||
0 <= quantile <= 1""") | ||
0 <= quantile <= 1 | ||
interpolation : {'linear', 'lower', 'higher', 'midpoint', 'nearest'} | ||
.. versionadded:: 0.23.0 | ||
|
||
This optional parameter specifies the interpolation method to use, | ||
when the desired quantile lies between two data points `i` and `j`: | ||
|
||
* linear: `i + (j - i) * fraction`, where `fraction` is the | ||
fractional part of the index surrounded by `i` and `j`. | ||
* lower: `i`. | ||
* higher: `j`. | ||
* nearest: `i` or `j` whichever is nearest. | ||
* midpoint: (`i` + `j`) / 2. | ||
|
||
Returns | ||
------- | ||
Series or DataFrame | ||
Returned object type is determined by the caller of the %(name)s | ||
calculation. | ||
|
||
Examples | ||
-------- | ||
>>> s = Series([1, 2, 3, 4]) | ||
>>> s.rolling(2).quantile(.4, interpolation='lower') | ||
0 NaN | ||
1 1.0 | ||
2 2.0 | ||
3 3.0 | ||
dtype: float64 | ||
|
||
>>> s.rolling(2).quantile(.4, interpolation='midpoint') | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you add a blank line here? |
||
0 NaN | ||
1 1.5 | ||
2 2.5 | ||
3 3.5 | ||
dtype: float64 | ||
|
||
See Also | ||
-------- | ||
pandas.Series.quantile : Computes value at the given quantile over all data | ||
in Series. | ||
pandas.DataFrame.quantile : Computes values at the given quantile over | ||
requested axis in DataFrame. | ||
|
||
""") | ||
|
||
def quantile(self, quantile, **kwargs): | ||
def quantile(self, quantile, interpolation='linear', **kwargs): | ||
window = self._get_window() | ||
index, indexi = self._get_index() | ||
|
||
|
@@ -1292,7 +1336,8 @@ def f(arg, *args, **kwargs): | |
self.closed) | ||
else: | ||
return _window.roll_quantile(arg, window, minp, indexi, | ||
self.closed, quantile) | ||
self.closed, quantile, | ||
interpolation) | ||
|
||
return self._apply(f, 'quantile', quantile=quantile, | ||
**kwargs) | ||
|
@@ -1613,8 +1658,10 @@ def kurt(self, **kwargs): | |
@Substitution(name='rolling') | ||
@Appender(_doc_template) | ||
@Appender(_shared_docs['quantile']) | ||
def quantile(self, quantile, **kwargs): | ||
return super(Rolling, self).quantile(quantile=quantile, **kwargs) | ||
def quantile(self, quantile, interpolation='linear', **kwargs): | ||
return super(Rolling, self).quantile(quantile=quantile, | ||
interpolation=interpolation, | ||
**kwargs) | ||
|
||
@Substitution(name='rolling') | ||
@Appender(_doc_template) | ||
|
@@ -1872,8 +1919,10 @@ def kurt(self, **kwargs): | |
@Substitution(name='expanding') | ||
@Appender(_doc_template) | ||
@Appender(_shared_docs['quantile']) | ||
def quantile(self, quantile, **kwargs): | ||
return super(Expanding, self).quantile(quantile=quantile, **kwargs) | ||
def quantile(self, quantile, interpolation='linear', **kwargs): | ||
return super(Expanding, self).quantile(quantile=quantile, | ||
interpolation=interpolation, | ||
**kwargs) | ||
|
||
@Substitution(name='expanding') | ||
@Appender(_doc_template) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
from datetime import datetime, timedelta | ||
from numpy.random import randn | ||
import numpy as np | ||
from pandas import _np_version_under1p12 | ||
|
||
import pandas as pd | ||
from pandas import (Series, DataFrame, bdate_range, | ||
|
@@ -1166,15 +1167,40 @@ def test_rolling_quantile_np_percentile(self): | |
|
||
tm.assert_almost_equal(df_quantile.values, np.array(np_percentile)) | ||
|
||
def test_rolling_quantile_series(self): | ||
# #16211: Tests that rolling window's quantile default behavior | ||
# is analogous to Series' quantile | ||
arr = np.arange(100) | ||
s = Series(arr) | ||
q1 = s.quantile(0.1) | ||
q2 = s.rolling(100).quantile(0.1).iloc[-1] | ||
@pytest.mark.skipif(_np_version_under1p12, | ||
reason='numpy midpoint interpolation is broken') | ||
@pytest.mark.parametrize('quantile', [0.0, 0.1, 0.45, 0.5, 1]) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we add examples with NA data? |
||
@pytest.mark.parametrize('interpolation', ['linear', 'lower', 'higher', | ||
'nearest', 'midpoint']) | ||
@pytest.mark.parametrize('data', [[1., 2., 3., 4., 5., 6., 7.], | ||
[8., 1., 3., 4., 5., 2., 6., 7.], | ||
[0., np.nan, 0.2, np.nan, 0.4], | ||
[np.nan, np.nan, np.nan, np.nan], | ||
[np.nan, 0.1, np.nan, 0.3, 0.4, 0.5], | ||
[0.5], [np.nan, 0.7, 0.6]]) | ||
def test_rolling_quantile_interpolation_options(self, quantile, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think this and the above test cover practically the same use case. If that's the case, I'd get rid of the test above. |
||
interpolation, data): | ||
# Tests that rolling window's quantile behavior is analogous to | ||
# Series' quantile for each interpolation option | ||
s = Series(data) | ||
|
||
q1 = s.quantile(quantile, interpolation) | ||
q2 = s.expanding(min_periods=1).quantile( | ||
quantile, interpolation).iloc[-1] | ||
|
||
if np.isnan(q1): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @WillAyd, it is not a Series type. I edited the test data a bit and was able to get rid of |
||
assert np.isnan(q2) | ||
else: | ||
assert q1 == q2 | ||
|
||
def test_invalid_quantile_value(self): | ||
data = np.arange(5) | ||
s = Series(data) | ||
|
||
tm.assert_almost_equal(q1, q2) | ||
with pytest.raises(ValueError, match="Interpolation 'invalid'" | ||
" is not supported"): | ||
s.rolling(len(data), min_periods=1).quantile( | ||
0.5, interpolation='invalid') | ||
|
||
def test_rolling_quantile_param(self): | ||
ser = Series([0.0, .1, .5, .9, 1.0]) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wouldn't this raise a KeyError, not a ValueError?