Skip to content

Commit 6d5c77d

Browse files
committed
Add test with nan values. Do not interpolate if index is an integer
1 parent 4f80369 commit 6d5c77d

File tree

3 files changed

+39
-19
lines changed

3 files changed

+39
-19
lines changed

pandas/_libs/window.pyx

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1443,7 +1443,12 @@ def roll_quantile(ndarray[float64_t, cast=True] input, int64_t win,
14431443
output[i] = skiplist.get(0)
14441444
else:
14451445
idx_with_fraction = quantile * (nobs - 1)
1446-
idx = int(idx_with_fraction)
1446+
idx = <int> idx_with_fraction
1447+
1448+
if idx_with_fraction == idx:
1449+
# no need to interpolate
1450+
output[i] = skiplist.get(idx)
1451+
continue
14471452

14481453
if interpolation_type == LINEAR:
14491454
vlow = skiplist.get(idx)
@@ -1455,7 +1460,16 @@ def roll_quantile(ndarray[float64_t, cast=True] input, int64_t win,
14551460
elif interpolation_type == HIGHER:
14561461
output[i] = skiplist.get(idx + 1)
14571462
elif interpolation_type == NEAREST:
1458-
output[i] = skiplist.get(round(idx_with_fraction))
1463+
# the same behaviour as round()
1464+
if idx_with_fraction - idx == 0.5:
1465+
if idx % 2 == 0:
1466+
output[i] = skiplist.get(idx)
1467+
else:
1468+
output[i] = skiplist.get(idx + 1)
1469+
elif idx_with_fraction - idx < 0.5:
1470+
output[i] = skiplist.get(idx)
1471+
else:
1472+
output[i] = skiplist.get(idx + 1)
14591473
elif interpolation_type == MIDPOINT:
14601474
vlow = skiplist.get(idx)
14611475
vhigh = skiplist.get(idx + 1)

pandas/core/window.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1918,7 +1918,7 @@ def kurt(self, **kwargs):
19181918
@Substitution(name='expanding')
19191919
@Appender(_doc_template)
19201920
@Appender(_shared_docs['quantile'])
1921-
def quantile(self, quantile, interpolation='linear', **kwargs,):
1921+
def quantile(self, quantile, interpolation='linear', **kwargs):
19221922
return super(Expanding, self).quantile(quantile=quantile,
19231923
interpolation=interpolation,
19241924
**kwargs)

pandas/tests/test_window.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,31 +1177,37 @@ def test_rolling_quantile_series(self):
11771177
tm.assert_almost_equal(q1, q2)
11781178

11791179
@pytest.mark.parametrize('quantile', [0.0, 0.1, 0.45, 0.5, 1])
1180-
@pytest.mark.parametrize('na_probability', [0.0, 0.3])
11811180
@pytest.mark.parametrize('interpolation', ['linear', 'lower', 'higher',
11821181
'nearest', 'midpoint'])
1182+
@pytest.mark.parametrize('data', [[1., 2., 3., 4., 5., 6., 7.],
1183+
[8., 1., 3., 4., 5., 2., 6., 7.],
1184+
[0., np.nan, 0.2, np.nan, 0.4],
1185+
[np.nan, np.nan, np.nan, np.nan],
1186+
[np.nan, 0.1, np.nan, 0.3, 0.4, 0.5],
1187+
[0.5], [np.nan, 0.7, 0.5]])
11831188
def test_rolling_quantile_interpolation_options(self, quantile,
1184-
na_probability,
1185-
interpolation):
1189+
interpolation, data):
11861190
# Tests that rolling window's quantile behavior is analogous to
11871191
# Series' quantile for each interpolation option
1188-
size = 100
1189-
s = Series(np.random.rand(size))
1190-
1191-
# set NaN values
1192-
na_count = 0
1193-
na_total = int(size * na_probability)
1194-
while na_count < na_total:
1195-
index = np.random.randint(0, size)
1196-
if not np.isnan(s[index]):
1197-
s[index] = np.NaN
1198-
na_count += 1
1192+
s = Series(data)
11991193

12001194
q1 = s.quantile(quantile, interpolation)
1201-
q2 = s.rolling(size, min_periods=1).quantile(
1195+
q2 = s.rolling(len(data), min_periods=1).quantile(
12021196
quantile, interpolation).iloc[-1]
12031197

1204-
tm.assert_almost_equal(q1, q2)
1198+
if np.isnan(q1):
1199+
assert np.isnan(q2)
1200+
else:
1201+
assert round(q1, 15) == round(q2, 15)
1202+
1203+
def test_invalid_quantile_value(self):
1204+
data = np.arange(5)
1205+
s = Series(data)
1206+
1207+
with pytest.raises(ValueError, match="Interpolation invalid"
1208+
" is not supported"):
1209+
s.rolling(len(data), min_periods=1).quantile(
1210+
0.5, interpolation='invalid')
12051211

12061212
def test_rolling_quantile_param(self):
12071213
ser = Series([0.0, .1, .5, .9, 1.0])

0 commit comments

Comments
 (0)