@@ -1357,27 +1357,21 @@ cdef _roll_min_max(ndarray[numeric] input, int64_t win, int64_t minp,
1357
1357
return output
1358
1358
1359
1359
1360
- def _get_interpolation_id (str interpolation ):
1361
- """
1362
- Converts string to interpolation id
1360
+ cdef enum InterpolationType:
1361
+ LINEAR,
1362
+ LOWER,
1363
+ HIGHER,
1364
+ NEAREST,
1365
+ MIDPOINT
1363
1366
1364
- Parameters
1365
- ----------
1366
- interpolation: 'linear', 'lower', 'higher', 'nearest', 'midpoint'
1367
- """
1368
- if interpolation == ' linear' :
1369
- return 0
1370
- elif interpolation == ' lower' :
1371
- return 1
1372
- elif interpolation == ' higher' :
1373
- return 2
1374
- elif interpolation == ' nearest' :
1375
- return 3
1376
- elif interpolation == ' midpoint' :
1377
- return 4
1378
- else :
1379
- raise ValueError (" Interpolation {} is not supported"
1380
- .format(interpolation))
1367
+
1368
+ interpolation_types = {
1369
+ ' linear' : LINEAR,
1370
+ ' lower' : LOWER,
1371
+ ' higher' : HIGHER,
1372
+ ' nearest' : NEAREST,
1373
+ ' midpoint' : MIDPOINT,
1374
+ }
1381
1375
1382
1376
1383
1377
def roll_quantile (ndarray[float64_t , cast = True ] input , int64_t win ,
@@ -1395,14 +1389,16 @@ def roll_quantile(ndarray[float64_t, cast=True] input, int64_t win,
1395
1389
ndarray[int64_t] start, end
1396
1390
ndarray[double_t] output
1397
1391
double vlow, vhigh
1398
- int interpolation_id
1392
+ InterpolationType interpolation_type
1399
1393
1400
1394
if quantile <= 0.0 or quantile >= 1.0 :
1401
1395
raise ValueError (" quantile value {0} not in [0, 1]" .format(quantile))
1402
1396
1403
- # interpolation_id is needed to avoid string comparisons inside the loop
1404
- # I tried to use callback but it resulted in worse performance
1405
- interpolation_id = _get_interpolation_id(interpolation)
1397
+ try :
1398
+ interpolation_type = interpolation_types[interpolation]
1399
+ except ValueError :
1400
+ raise ValueError (" Interpolation {} is not supported"
1401
+ .format(interpolation))
1406
1402
1407
1403
# we use the Fixed/Variable Indexer here as the
1408
1404
# actual skiplist ops outweigh any window computation costs
@@ -1449,21 +1445,21 @@ def roll_quantile(ndarray[float64_t, cast=True] input, int64_t win,
1449
1445
idx_with_fraction = quantile * < double > (nobs - 1 )
1450
1446
idx = int (idx_with_fraction)
1451
1447
1452
- if interpolation_id == 0 : # linear
1448
+ if interpolation_type == LINEAR:
1453
1449
vlow = skiplist.get(idx)
1454
1450
vhigh = skiplist.get(idx + 1 )
1455
1451
output[i] = ((vlow + (vhigh - vlow) *
1456
1452
(idx_with_fraction - idx)))
1457
- elif interpolation_id == 1 : # lower
1453
+ elif interpolation_type == LOWER:
1458
1454
output[i] = skiplist.get(idx)
1459
- elif interpolation_id == 2 : # higher
1455
+ elif interpolation_type == HIGHER:
1460
1456
output[i] = skiplist.get(idx + 1 )
1461
- elif interpolation_id == 3 : # nearest
1457
+ elif interpolation_type == NEAREST:
1462
1458
if idx_with_fraction - idx < 0.5 :
1463
1459
output[i] = skiplist.get(idx)
1464
1460
else :
1465
1461
output[i] = skiplist.get(idx + 1 )
1466
- elif interpolation_id == 4 : # midpoint
1462
+ elif interpolation_type == MIDPOINT:
1467
1463
vlow = skiplist.get(idx)
1468
1464
vhigh = skiplist.get(idx + 1 )
1469
1465
output[i] = < double > (vlow + vhigh) / 2
0 commit comments