Skip to content

Commit c832ccb

Browse files
authored
PERF: faster searchsorted in tzconversion (#46329)
1 parent a72fa1b commit c832ccb

File tree

3 files changed

+51
-24
lines changed

3 files changed

+51
-24
lines changed

pandas/_libs/tslibs/conversion.pyx

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ from pandas._libs.tslibs.nattype cimport (
6969
checknull_with_nat,
7070
)
7171
from pandas._libs.tslibs.tzconversion cimport (
72+
bisect_right_i8,
7273
tz_convert_utc_to_tzlocal,
7374
tz_localize_to_utc_single,
7475
)
@@ -536,6 +537,7 @@ cdef _TSObject _create_tsobject_tz_using_offset(npy_datetimestruct dts,
536537
int64_t value # numpy dt64
537538
datetime dt
538539
ndarray[int64_t] trans
540+
int64_t* tdata
539541
int64_t[::1] deltas
540542

541543
value = dtstruct_to_dt64(&dts)
@@ -556,7 +558,8 @@ cdef _TSObject _create_tsobject_tz_using_offset(npy_datetimestruct dts,
556558
trans, deltas, typ = get_dst_info(tz)
557559

558560
if typ == 'dateutil':
559-
pos = trans.searchsorted(obj.value, side='right') - 1
561+
tdata = <int64_t*>cnp.PyArray_DATA(trans)
562+
pos = bisect_right_i8(tdata, obj.value, trans.shape[0]) - 1
560563
obj.fold = _infer_tsobject_fold(obj, trans, deltas, pos)
561564

562565
# Keep the converter same as PyDateTime's
@@ -708,7 +711,8 @@ cdef inline void _localize_tso(_TSObject obj, tzinfo tz):
708711
ndarray[int64_t] trans
709712
int64_t[::1] deltas
710713
int64_t local_val
711-
Py_ssize_t pos
714+
int64_t* tdata
715+
Py_ssize_t pos, ntrans
712716
str typ
713717

714718
assert obj.tzinfo is None
@@ -723,17 +727,20 @@ cdef inline void _localize_tso(_TSObject obj, tzinfo tz):
723727
else:
724728
# Adjust datetime64 timestamp, recompute datetimestruct
725729
trans, deltas, typ = get_dst_info(tz)
730+
ntrans = trans.shape[0]
726731

727732
if typ == "pytz":
728733
# i.e. treat_tz_as_pytz(tz)
729-
pos = trans.searchsorted(obj.value, side="right") - 1
734+
tdata = <int64_t*>cnp.PyArray_DATA(trans)
735+
pos = bisect_right_i8(tdata, obj.value, ntrans) - 1
730736
local_val = obj.value + deltas[pos]
731737

732738
# find right representation of dst etc in pytz timezone
733739
tz = tz._tzinfos[tz._transition_info[pos]]
734740
elif typ == "dateutil":
735741
# i.e. treat_tz_as_dateutil(tz)
736-
pos = trans.searchsorted(obj.value, side="right") - 1
742+
tdata = <int64_t*>cnp.PyArray_DATA(trans)
743+
pos = bisect_right_i8(tdata, obj.value, ntrans) - 1
737744
local_val = obj.value + deltas[pos]
738745

739746
# dateutil supports fold, so we infer fold from value

pandas/_libs/tslibs/tzconversion.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,5 @@ cpdef int64_t tz_convert_from_utc_single(int64_t val, tzinfo tz)
99
cdef int64_t tz_localize_to_utc_single(
1010
int64_t val, tzinfo tz, object ambiguous=*, object nonexistent=*
1111
) except? -1
12+
13+
cdef Py_ssize_t bisect_right_i8(int64_t *data, int64_t val, Py_ssize_t n)

pandas/_libs/tslibs/tzconversion.pyx

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,8 @@ timedelta-like}
116116
"""
117117
cdef:
118118
int64_t[::1] deltas
119-
int64_t[:] idx_shifted, idx_shifted_left, idx_shifted_right
120119
ndarray[uint8_t, cast=True] ambiguous_array, both_nat, both_eq
121-
Py_ssize_t i, idx, pos, ntrans, n = vals.shape[0]
120+
Py_ssize_t i, isl, isr, idx, pos, ntrans, n = vals.shape[0]
122121
Py_ssize_t delta_idx_offset, delta_idx, pos_left, pos_right
123122
int64_t *tdata
124123
int64_t v, left, right, val, v_left, v_right, new_local, remaining_mins
@@ -194,21 +193,28 @@ timedelta-like}
194193
result_a[:] = NPY_NAT
195194
result_b[:] = NPY_NAT
196195

197-
idx_shifted_left = (np.maximum(0, trans.searchsorted(
198-
vals - DAY_NANOS, side='right') - 1)).astype(np.int64)
199-
200-
idx_shifted_right = (np.maximum(0, trans.searchsorted(
201-
vals + DAY_NANOS, side='right') - 1)).astype(np.int64)
202-
203196
for i in range(n):
204197
val = vals[i]
205-
v_left = val - deltas[idx_shifted_left[i]]
198+
if val == NPY_NAT:
199+
continue
200+
201+
# TODO: be careful of overflow in val-DAY_NANOS
202+
isl = bisect_right_i8(tdata, val - DAY_NANOS, ntrans) - 1
203+
if isl < 0:
204+
isl = 0
205+
206+
v_left = val - deltas[isl]
206207
pos_left = bisect_right_i8(tdata, v_left, ntrans) - 1
207208
# timestamp falls to the left side of the DST transition
208209
if v_left + deltas[pos_left] == val:
209210
result_a[i] = v_left
210211

211-
v_right = val - deltas[idx_shifted_right[i]]
212+
# TODO: be careful of overflow in val+DAY_NANOS
213+
isr = bisect_right_i8(tdata, val + DAY_NANOS, ntrans) - 1
214+
if isr < 0:
215+
isr = 0
216+
217+
v_right = val - deltas[isr]
212218
pos_right = bisect_right_i8(tdata, v_right, ntrans) - 1
213219
# timestamp falls to the right side of the DST transition
214220
if v_right + deltas[pos_right] == val:
@@ -309,7 +315,9 @@ timedelta-like}
309315
# Subtract 1 since the beginning hour is _inclusive_ of
310316
# nonexistent times
311317
new_local = val - remaining_mins - 1
312-
delta_idx = trans.searchsorted(new_local, side='right')
318+
319+
delta_idx = bisect_right_i8(tdata, new_local, ntrans)
320+
313321
# Shift the delta_idx if the UTC offset of
314322
# the target tz is greater than 0 and we're moving forward
315323
# or vice versa
@@ -333,17 +341,22 @@ timedelta-like}
333341

334342
cdef inline Py_ssize_t bisect_right_i8(int64_t *data,
335343
int64_t val, Py_ssize_t n):
344+
# Caller is responsible for checking n > 0
345+
# This looks very similar to local_search_right in the ndarray.searchsorted
346+
# implementation.
336347
cdef:
337348
Py_ssize_t pivot, left = 0, right = n
338349

339-
assert n >= 1
340-
341350
# edge cases
342351
if val > data[n - 1]:
343352
return n
344353

345-
if val < data[0]:
346-
return 0
354+
# Caller is responsible for ensuring 'val >= data[0]'. This is
355+
# ensured by the fact that 'data' comes from get_dst_info where data[0]
356+
# is *always* NPY_NAT+1. If that ever changes, we will need to restore
357+
# the following disabled check.
358+
# if val < data[0]:
359+
# return 0
347360

348361
while left < right:
349362
pivot = left + (right - left) // 2
@@ -403,6 +416,7 @@ cpdef int64_t tz_convert_from_utc_single(int64_t val, tzinfo tz):
403416
int64_t delta
404417
int64_t[::1] deltas
405418
ndarray[int64_t, ndim=1] trans
419+
int64_t* tdata
406420
intp_t pos
407421

408422
if val == NPY_NAT:
@@ -418,7 +432,8 @@ cpdef int64_t tz_convert_from_utc_single(int64_t val, tzinfo tz):
418432
return val + delta
419433
else:
420434
trans, deltas, _ = get_dst_info(tz)
421-
pos = trans.searchsorted(val, side="right") - 1
435+
tdata = <int64_t*>cnp.PyArray_DATA(trans)
436+
pos = bisect_right_i8(tdata, val, trans.shape[0]) - 1
422437
return val + deltas[pos]
423438

424439

@@ -462,10 +477,11 @@ cdef const int64_t[:] _tz_convert_from_utc(const int64_t[:] vals, tzinfo tz):
462477
"""
463478
cdef:
464479
int64_t[::1] converted, deltas
465-
Py_ssize_t i, n = vals.shape[0]
480+
Py_ssize_t i, ntrans = -1, n = vals.shape[0]
466481
int64_t val, delta = 0 # avoid not-initialized-warning
467-
intp_t[:] pos
482+
intp_t pos
468483
ndarray[int64_t] trans
484+
int64_t* tdata = NULL
469485
str typ
470486
bint use_tzlocal = False, use_fixed = False, use_utc = True
471487

@@ -479,13 +495,14 @@ cdef const int64_t[:] _tz_convert_from_utc(const int64_t[:] vals, tzinfo tz):
479495
use_tzlocal = True
480496
else:
481497
trans, deltas, typ = get_dst_info(tz)
498+
ntrans = trans.shape[0]
482499

483500
if typ not in ["pytz", "dateutil"]:
484501
# FixedOffset, we know len(deltas) == 1
485502
delta = deltas[0]
486503
use_fixed = True
487504
else:
488-
pos = trans.searchsorted(vals, side="right") - 1
505+
tdata = <int64_t*>cnp.PyArray_DATA(trans)
489506

490507
converted = np.empty(n, dtype=np.int64)
491508

@@ -502,7 +519,8 @@ cdef const int64_t[:] _tz_convert_from_utc(const int64_t[:] vals, tzinfo tz):
502519
elif use_fixed:
503520
converted[i] = val + delta
504521
else:
505-
converted[i] = val + deltas[pos[i]]
522+
pos = bisect_right_i8(tdata, val, ntrans) - 1
523+
converted[i] = val + deltas[pos]
506524

507525
return converted
508526

0 commit comments

Comments
 (0)