Skip to content

Commit f31a448

Browse files
committed
Add interpolation options to moving quantile
1 parent 41db527 commit f31a448

File tree

3 files changed

+86
-16
lines changed

3 files changed

+86
-16
lines changed

pandas/_libs/window.pyx

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1356,26 +1356,53 @@ cdef _roll_min_max(ndarray[numeric] input, int64_t win, int64_t minp,
13561356
# print("output: {0}".format(output))
13571357
return output
13581358

1359+
def _get_interpolation_id(str interpolation):
    """
    Map an interpolation name to its integer id.

    Parameters
    ----------
    interpolation : str
        One of 'linear', 'lower', 'higher', 'nearest', 'midpoint'.

    Returns
    -------
    int
        0 for 'linear', 1 for 'lower', 2 for 'higher', 3 for 'nearest'
        and 4 for 'midpoint'.

    Raises
    ------
    ValueError
        If `interpolation` is not one of the names listed above.
    """
    # Table lookup instead of a comparison chain; ids match the order above.
    known_ids = {
        'linear': 0,
        'lower': 1,
        'higher': 2,
        'nearest': 3,
        'midpoint': 4,
    }
    found = known_ids.get(interpolation)
    if found is None:
        raise ValueError("Interpolation {} is not supported"
                         .format(interpolation))
    return found
1380+
13591381

13601382
def roll_quantile(ndarray[float64_t, cast=True] input, int64_t win,
13611383
int64_t minp, object index, object closed,
1362-
double quantile):
1384+
double quantile, str interpolation):
13631385
"""
13641386
O(N log(window)) implementation using skip list
13651387
"""
13661388
cdef:
1367-
double val, prev, midpoint
1389+
double val, prev, midpoint, idx_with_fraction
13681390
IndexableSkiplist skiplist
13691391
int64_t nobs = 0, i, j, s, e, N
13701392
Py_ssize_t idx
13711393
bint is_variable
13721394
ndarray[int64_t] start, end
13731395
ndarray[double_t] output
13741396
double vlow, vhigh
1397+
int interpolation_id
13751398

13761399
if quantile <= 0.0 or quantile >= 1.0:
13771400
raise ValueError("quantile value {0} not in [0, 1]".format(quantile))
13781401

1402+
# Resolve the interpolation name to an integer id once, up front, so that
1403+
# no string comparisons happen inside the loop. A callback-based dispatch
# was also tried, but it resulted in worse performance.
1404+
interpolation_id = _get_interpolation_id(interpolation)
1405+
13791406
# we use the Fixed/Variable Indexer here as the
13801407
# actual skiplist ops outweigh any window computation costs
13811408
start, end, N, win, minp, is_variable = get_window_indexer(
@@ -1414,18 +1441,31 @@ def roll_quantile(ndarray[float64_t, cast=True] input, int64_t win,
14141441
skiplist.insert(val)
14151442

14161443
if nobs >= minp:
1417-
idx = int(quantile * <double>(nobs - 1))
1418-
1419-
# Single value in skip list
14201444
if nobs == 1:
1445+
# Single value in skip list
14211446
output[i] = skiplist.get(0)
1422-
1423-
# Interpolated quantile
14241447
else:
1425-
vlow = skiplist.get(idx)
1426-
vhigh = skiplist.get(idx + 1)
1427-
output[i] = ((vlow + (vhigh - vlow) *
1428-
(quantile * (nobs - 1) - idx)))
1448+
idx_with_fraction = quantile * <double> (nobs - 1)
1449+
idx = int(idx_with_fraction)
1450+
1451+
if interpolation_id == 0: # linear
1452+
vlow = skiplist.get(idx)
1453+
vhigh = skiplist.get(idx + 1)
1454+
output[i] = ((vlow + (vhigh - vlow) *
1455+
(idx_with_fraction - idx)))
1456+
elif interpolation_id == 1: # lower
1457+
output[i] = skiplist.get(idx)
1458+
elif interpolation_id == 2: # higher
1459+
output[i] = skiplist.get(idx + 1)
1460+
elif interpolation_id == 3: # nearest
1461+
if idx_with_fraction - idx < 0.5:
1462+
output[i] = skiplist.get(idx)
1463+
else:
1464+
output[i] = skiplist.get(idx + 1)
1465+
elif interpolation_id == 4: # midpoint
1466+
vlow = skiplist.get(idx)
1467+
vhigh = skiplist.get(idx + 1)
1468+
output[i] = <double> (vlow + vhigh) / 2
14291469
else:
14301470
output[i] = NaN
14311471

pandas/core/window.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1276,9 +1276,21 @@ def kurt(self, **kwargs):
12761276
Parameters
12771277
----------
12781278
quantile : float
1279-
0 <= quantile <= 1""")
1279+
0 <= quantile <= 1
1280+
interpolation : {'linear', 'lower', 'higher', 'midpoint', 'nearest'}
1281+
.. versionadded:: 0.23.0
12801282
1281-
def quantile(self, quantile, **kwargs):
1283+
This optional parameter specifies the interpolation method to use,
1284+
when the desired quantile lies between two data points `i` and `j`:
1285+
1286+
* linear: `i + (j - i) * fraction`, where `fraction` is the
1287+
fractional part of the index surrounded by `i` and `j`.
1288+
* lower: `i`.
1289+
* higher: `j`.
1290+
* nearest: `i` or `j` whichever is nearest.
1291+
* midpoint: (`i` + `j`) / 2.""")
1292+
1293+
def quantile(self, quantile, interpolation='linear', **kwargs):
12821294
window = self._get_window()
12831295
index, indexi = self._get_index()
12841296

@@ -1292,7 +1304,8 @@ def f(arg, *args, **kwargs):
12921304
self.closed)
12931305
else:
12941306
return _window.roll_quantile(arg, window, minp, indexi,
1295-
self.closed, quantile)
1307+
self.closed, quantile,
1308+
interpolation)
12961309

12971310
return self._apply(f, 'quantile', quantile=quantile,
12981311
**kwargs)
@@ -1613,8 +1626,10 @@ def kurt(self, **kwargs):
16131626
@Substitution(name='rolling')
16141627
@Appender(_doc_template)
16151628
@Appender(_shared_docs['quantile'])
1616-
def quantile(self, quantile, **kwargs):
1617-
return super(Rolling, self).quantile(quantile=quantile, **kwargs)
1629+
# Thin override: forwards `quantile` and `interpolation` (default 'linear'),
# plus any extra keyword arguments, unchanged to the parent class's
# `quantile`.
def quantile(self, quantile, interpolation='linear', **kwargs):
1630+
return super(Rolling, self).quantile(quantile=quantile,
1631+
interpolation=interpolation,
1632+
**kwargs)
16181633

16191634
@Substitution(name='rolling')
16201635
@Appender(_doc_template)

pandas/tests/test_window.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,7 +1173,22 @@ def test_rolling_quantile_series(self):
11731173
s = Series(arr)
11741174
q1 = s.quantile(0.1)
11751175
q2 = s.rolling(100).quantile(0.1).iloc[-1]
1176+
tm.assert_almost_equal(q1, q2)
1177+
1178+
q1 = s.quantile(0.1, interpolation='lower')
1179+
q2 = s.rolling(100).quantile(0.1, interpolation='lower').iloc[-1]
1180+
tm.assert_almost_equal(q1, q2)
1181+
1182+
q1 = s.quantile(0.1, interpolation='higher')
1183+
q2 = s.rolling(100).quantile(0.1, interpolation='higher').iloc[-1]
1184+
tm.assert_almost_equal(q1, q2)
1185+
1186+
q1 = s.quantile(0.1, interpolation='nearest')
1187+
q2 = s.rolling(100).quantile(0.1, interpolation='nearest').iloc[-1]
1188+
tm.assert_almost_equal(q1, q2)
11761189

1190+
q1 = s.quantile(0.1, interpolation='midpoint')
1191+
q2 = s.rolling(100).quantile(0.1, interpolation='midpoint').iloc[-1]
11771192
tm.assert_almost_equal(q1, q2)
11781193

11791194
def test_rolling_quantile_param(self):

0 commit comments

Comments
 (0)