Skip to content

Commit b47362d

Browse files
authored
REF: share code for scalar validation in datetimelike array methods (#34076)
1 parent 7c3653a commit b47362d

File tree

2 files changed

+65
-71
lines changed

2 files changed

+65
-71
lines changed

pandas/core/arrays/datetimelike.py

Lines changed: 59 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from datetime import datetime, timedelta
22
import operator
3-
from typing import Any, Callable, Sequence, Type, TypeVar, Union, cast
3+
from typing import Any, Callable, Sequence, Tuple, Type, TypeVar, Union, cast
44
import warnings
55

66
import numpy as np
@@ -437,6 +437,7 @@ class DatetimeLikeArrayMixin(
437437
"""
438438

439439
_is_recognized_dtype: Callable[[DtypeObj], bool]
440+
_recognized_scalars: Tuple[Type, ...]
440441

441442
# ------------------------------------------------------------------
442443
# NDArrayBackedExtensionArray compat
@@ -718,16 +719,14 @@ def _validate_fill_value(self, fill_value):
718719
------
719720
ValueError
720721
"""
721-
if is_valid_nat_for_dtype(fill_value, self.dtype):
722-
fill_value = NaT
723-
elif isinstance(fill_value, self._recognized_scalars):
724-
fill_value = self._scalar_type(fill_value)
725-
else:
726-
raise ValueError(
727-
f"'fill_value' should be a {self._scalar_type}. "
728-
f"Got '{str(fill_value)}'."
729-
)
730-
722+
msg = (
723+
f"'fill_value' should be a {self._scalar_type}. "
724+
f"Got '{str(fill_value)}'."
725+
)
726+
try:
727+
fill_value = self._validate_scalar(fill_value, msg)
728+
except TypeError as err:
729+
raise ValueError(msg) from err
731730
return self._unbox(fill_value)
732731

733732
def _validate_shift_value(self, fill_value):
@@ -757,6 +756,41 @@ def _validate_shift_value(self, fill_value):
757756

758757
return self._unbox(fill_value)
759758

759+
def _validate_scalar(self, value, msg: str, cast_str: bool = False):
760+
"""
761+
Validate that the input value can be cast to our scalar_type.
762+
763+
Parameters
764+
----------
765+
value : object
766+
msg : str
767+
Message to raise in TypeError on invalid input.
768+
cast_str : bool, default False
769+
Whether to try to parse string input to scalar_type.
770+
771+
Returns
772+
-------
773+
self._scalar_type or NaT
774+
"""
775+
if cast_str and isinstance(value, str):
776+
# NB: Careful about tzawareness
777+
try:
778+
value = self._scalar_from_string(value)
779+
except ValueError as err:
780+
raise TypeError(msg) from err
781+
782+
elif is_valid_nat_for_dtype(value, self.dtype):
783+
# GH#18295
784+
value = NaT
785+
786+
elif isinstance(value, self._recognized_scalars):
787+
value = self._scalar_type(value) # type: ignore
788+
789+
else:
790+
raise TypeError(msg)
791+
792+
return value
793+
760794
def _validate_listlike(
761795
self, value, opname: str, cast_str: bool = False, allow_object: bool = False,
762796
):
@@ -795,72 +829,42 @@ def _validate_listlike(
795829
return value
796830

797831
def _validate_searchsorted_value(self, value):
798-
if isinstance(value, str):
799-
try:
800-
value = self._scalar_from_string(value)
801-
except ValueError as err:
802-
raise TypeError(
803-
"searchsorted requires compatible dtype or scalar"
804-
) from err
805-
806-
elif is_valid_nat_for_dtype(value, self.dtype):
807-
value = NaT
808-
809-
elif isinstance(value, self._recognized_scalars):
810-
value = self._scalar_type(value)
811-
812-
elif not is_list_like(value):
813-
raise TypeError(f"Unexpected type for 'value': {type(value)}")
814-
832+
msg = "searchsorted requires compatible dtype or scalar"
833+
if not is_list_like(value):
834+
value = self._validate_scalar(value, msg, cast_str=True)
815835
else:
816836
# TODO: cast_str? we accept it for scalar
817837
value = self._validate_listlike(value, "searchsorted")
818838

819839
return self._unbox(value)
820840

821841
def _validate_setitem_value(self, value):
822-
842+
msg = (
843+
f"'value' should be a '{self._scalar_type.__name__}', 'NaT', "
844+
f"or array of those. Got '{type(value).__name__}' instead."
845+
)
823846
if is_list_like(value):
824847
value = self._validate_listlike(value, "setitem", cast_str=True)
825-
826-
elif isinstance(value, self._recognized_scalars):
827-
value = self._scalar_type(value)
828-
elif is_valid_nat_for_dtype(value, self.dtype):
829-
value = NaT
830848
else:
831-
msg = (
832-
f"'value' should be a '{self._scalar_type.__name__}', 'NaT', "
833-
f"or array of those. Got '{type(value).__name__}' instead."
834-
)
835-
raise TypeError(msg)
849+
# TODO: cast_str for consistency?
850+
value = self._validate_scalar(value, msg, cast_str=False)
836851

837852
self._check_compatible_with(value, setitem=True)
838853
return self._unbox(value)
839854

840855
def _validate_insert_value(self, value):
841-
if isinstance(value, self._recognized_scalars):
842-
value = self._scalar_type(value)
843-
elif is_valid_nat_for_dtype(value, self.dtype):
844-
# GH#18295
845-
value = NaT
846-
else:
847-
raise TypeError(
848-
f"cannot insert {type(self).__name__} with incompatible label"
849-
)
856+
msg = f"cannot insert {type(self).__name__} with incompatible label"
857+
value = self._validate_scalar(value, msg, cast_str=False)
850858

851859
self._check_compatible_with(value, setitem=True)
852860
# TODO: if we dont have compat, should we raise or astype(object)?
853861
# PeriodIndex does astype(object)
854862
return value
855863

856864
def _validate_where_value(self, other):
857-
if is_valid_nat_for_dtype(other, self.dtype):
858-
other = NaT
859-
elif isinstance(other, self._recognized_scalars):
860-
other = self._scalar_type(other)
861-
elif not is_list_like(other):
862-
raise TypeError(f"Where requires matching dtype, not {type(other)}")
863-
865+
msg = f"Where requires matching dtype, not {type(other)}"
866+
if not is_list_like(other):
867+
other = self._validate_scalar(other, msg)
864868
else:
865869
other = self._validate_listlike(other, "where")
866870
self._check_compatible_with(other, setitem=True)

pandas/core/indexes/timedeltas.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
""" implement the TimedeltaIndex """
22

3-
from pandas._libs import NaT, Timedelta, index as libindex, lib
3+
from pandas._libs import Timedelta, index as libindex, lib
44
from pandas._typing import DtypeObj, Label
55
from pandas.util._decorators import doc
66

@@ -13,7 +13,6 @@
1313
is_timedelta64_ns_dtype,
1414
pandas_dtype,
1515
)
16-
from pandas.core.dtypes.missing import is_valid_nat_for_dtype
1716

1817
from pandas.core.arrays import datetimelike as dtl
1918
from pandas.core.arrays.timedeltas import TimedeltaArray
@@ -214,20 +213,11 @@ def get_loc(self, key, method=None, tolerance=None):
214213
if not is_scalar(key):
215214
raise InvalidIndexError(key)
216215

217-
if is_valid_nat_for_dtype(key, self.dtype):
218-
key = NaT
219-
220-
elif isinstance(key, str):
221-
try:
222-
key = Timedelta(key)
223-
except ValueError as err:
224-
raise KeyError(key) from err
225-
226-
elif isinstance(key, self._data._recognized_scalars) or key is NaT:
227-
key = Timedelta(key)
228-
229-
else:
230-
raise KeyError(key)
216+
msg = str(key)
217+
try:
218+
key = self._data._validate_scalar(key, msg, cast_str=True)
219+
except TypeError as err:
220+
raise KeyError(key) from err
231221

232222
return Index.get_loc(self, key, method, tolerance)
233223

0 commit comments

Comments
 (0)