Skip to content

Commit a02183a

Browse files
committed
Use enum instead of integers
1 parent 7afa6b5 commit a02183a

File tree

1 file changed

+25
-29
lines changed

1 file changed

+25
-29
lines changed

pandas/_libs/window.pyx

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1357,27 +1357,21 @@ cdef _roll_min_max(ndarray[numeric] input, int64_t win, int64_t minp,
13571357
return output
13581358

13591359

1360-
def _get_interpolation_id(str interpolation):
1361-
"""
1362-
Converts string to interpolation id
1360+
cdef enum InterpolationType:
1361+
LINEAR,
1362+
LOWER,
1363+
HIGHER,
1364+
NEAREST,
1365+
MIDPOINT
13631366

1364-
Parameters
1365-
----------
1366-
interpolation: 'linear', 'lower', 'higher', 'nearest', 'midpoint'
1367-
"""
1368-
if interpolation == 'linear':
1369-
return 0
1370-
elif interpolation == 'lower':
1371-
return 1
1372-
elif interpolation == 'higher':
1373-
return 2
1374-
elif interpolation == 'nearest':
1375-
return 3
1376-
elif interpolation == 'midpoint':
1377-
return 4
1378-
else:
1379-
raise ValueError("Interpolation {} is not supported"
1380-
.format(interpolation))
1367+
1368+
interpolation_types = {
1369+
'linear': LINEAR,
1370+
'lower': LOWER,
1371+
'higher': HIGHER,
1372+
'nearest': NEAREST,
1373+
'midpoint': MIDPOINT,
1374+
}
13811375

13821376

13831377
def roll_quantile(ndarray[float64_t, cast=True] input, int64_t win,
@@ -1395,14 +1389,16 @@ def roll_quantile(ndarray[float64_t, cast=True] input, int64_t win,
13951389
ndarray[int64_t] start, end
13961390
ndarray[double_t] output
13971391
double vlow, vhigh
1398-
int interpolation_id
1392+
InterpolationType interpolation_type
13991393

14001394
if quantile <= 0.0 or quantile >= 1.0:
14011395
raise ValueError("quantile value {0} not in [0, 1]".format(quantile))
14021396

1403-
# interpolation_id is needed to avoid string comparisons inside the loop
1404-
# I tried to use callback but it resulted in worse performance
1405-
interpolation_id = _get_interpolation_id(interpolation)
1397+
try:
1398+
interpolation_type = interpolation_types[interpolation]
1399+
except ValueError:
1400+
raise ValueError("Interpolation {} is not supported"
1401+
.format(interpolation))
14061402

14071403
# we use the Fixed/Variable Indexer here as the
14081404
# actual skiplist ops outweigh any window computation costs
@@ -1449,21 +1445,21 @@ def roll_quantile(ndarray[float64_t, cast=True] input, int64_t win,
14491445
idx_with_fraction = quantile * <double> (nobs - 1)
14501446
idx = int(idx_with_fraction)
14511447

1452-
if interpolation_id == 0: # linear
1448+
if interpolation_type == LINEAR:
14531449
vlow = skiplist.get(idx)
14541450
vhigh = skiplist.get(idx + 1)
14551451
output[i] = ((vlow + (vhigh - vlow) *
14561452
(idx_with_fraction - idx)))
1457-
elif interpolation_id == 1: # lower
1453+
elif interpolation_type == LOWER:
14581454
output[i] = skiplist.get(idx)
1459-
elif interpolation_id == 2: # higher
1455+
elif interpolation_type == HIGHER:
14601456
output[i] = skiplist.get(idx + 1)
1461-
elif interpolation_id == 3: # nearest
1457+
elif interpolation_type == NEAREST:
14621458
if idx_with_fraction - idx < 0.5:
14631459
output[i] = skiplist.get(idx)
14641460
else:
14651461
output[i] = skiplist.get(idx + 1)
1466-
elif interpolation_id == 4: # midpoint
1462+
elif interpolation_type == MIDPOINT:
14671463
vlow = skiplist.get(idx)
14681464
vhigh = skiplist.get(idx + 1)
14691465
output[i] = <double> (vlow + vhigh) / 2

0 commit comments

Comments
 (0)