Skip to content

REF: share code for scalar validation in datetimelike array methods #34076

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
8 commits merged on May 12, 2020
114 changes: 59 additions & 55 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import datetime, timedelta
import operator
from typing import Any, Callable, Sequence, Type, TypeVar, Union, cast
from typing import Any, Callable, Sequence, Tuple, Type, TypeVar, Union, cast
import warnings

import numpy as np
Expand Down Expand Up @@ -437,6 +437,7 @@ class DatetimeLikeArrayMixin(
"""

_is_recognized_dtype: Callable[[DtypeObj], bool]
_recognized_scalars: Tuple[Type, ...]

# ------------------------------------------------------------------
# NDArrayBackedExtensionArray compat
Expand Down Expand Up @@ -718,16 +719,14 @@ def _validate_fill_value(self, fill_value):
------
ValueError
"""
if is_valid_nat_for_dtype(fill_value, self.dtype):
fill_value = NaT
elif isinstance(fill_value, self._recognized_scalars):
fill_value = self._scalar_type(fill_value)
else:
raise ValueError(
f"'fill_value' should be a {self._scalar_type}. "
f"Got '{str(fill_value)}'."
)

msg = (
f"'fill_value' should be a {self._scalar_type}. "
f"Got '{str(fill_value)}'."
)
try:
fill_value = self._validate_scalar(fill_value, msg)
except TypeError as err:
raise ValueError(msg) from err
return self._unbox(fill_value)

def _validate_shift_value(self, fill_value):
Expand Down Expand Up @@ -757,6 +756,41 @@ def _validate_shift_value(self, fill_value):

return self._unbox(fill_value)

def _validate_scalar(self, value, msg: str, cast_str: bool = False):
"""
Validate that the input value can be cast to our scalar_type.

Parameters
----------
value : object
msg : str
Message to raise in TypeError on invalid input.
cast_str : bool, default False
Whether to try to parse string input to scalar_type.

Returns
-------
self._scalar_type or NaT
"""
if cast_str and isinstance(value, str):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a docstring here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

# NB: Careful about tzawareness
try:
value = self._scalar_from_string(value)
except ValueError as err:
raise TypeError(msg) from err

elif is_valid_nat_for_dtype(value, self.dtype):
# GH#18295
value = NaT

elif isinstance(value, self._recognized_scalars):
value = self._scalar_type(value) # type: ignore

else:
raise TypeError(msg)

return value

def _validate_listlike(
self, value, opname: str, cast_str: bool = False, allow_object: bool = False,
):
Expand Down Expand Up @@ -795,72 +829,42 @@ def _validate_listlike(
return value

def _validate_searchsorted_value(self, value):
if isinstance(value, str):
try:
value = self._scalar_from_string(value)
except ValueError as err:
raise TypeError(
"searchsorted requires compatible dtype or scalar"
) from err

elif is_valid_nat_for_dtype(value, self.dtype):
value = NaT

elif isinstance(value, self._recognized_scalars):
value = self._scalar_type(value)

elif not is_list_like(value):
raise TypeError(f"Unexpected type for 'value': {type(value)}")

msg = "searchsorted requires compatible dtype or scalar"
if not is_list_like(value):
value = self._validate_scalar(value, msg, cast_str=True)
else:
# TODO: cast_str? we accept it for scalar
value = self._validate_listlike(value, "searchsorted")

return self._unbox(value)

def _validate_setitem_value(self, value):

msg = (
f"'value' should be a '{self._scalar_type.__name__}', 'NaT', "
f"or array of those. Got '{type(value).__name__}' instead."
)
if is_list_like(value):
value = self._validate_listlike(value, "setitem", cast_str=True)

elif isinstance(value, self._recognized_scalars):
value = self._scalar_type(value)
elif is_valid_nat_for_dtype(value, self.dtype):
value = NaT
else:
msg = (
f"'value' should be a '{self._scalar_type.__name__}', 'NaT', "
f"or array of those. Got '{type(value).__name__}' instead."
)
raise TypeError(msg)
Copy link
Member

@gfyoung gfyoung May 8, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This message, for example, is a nicer, user-friendly one that would be good to keep.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, @gfyoung?

# TODO: cast_str for consistency?
value = self._validate_scalar(value, msg, cast_str=False)

self._check_compatible_with(value, setitem=True)
return self._unbox(value)

def _validate_insert_value(self, value):
if isinstance(value, self._recognized_scalars):
value = self._scalar_type(value)
elif is_valid_nat_for_dtype(value, self.dtype):
# GH#18295
value = NaT
else:
raise TypeError(
f"cannot insert {type(self).__name__} with incompatible label"
)
msg = f"cannot insert {type(self).__name__} with incompatible label"
value = self._validate_scalar(value, msg, cast_str=False)

self._check_compatible_with(value, setitem=True)
# TODO: if we dont have compat, should we raise or astype(object)?
# PeriodIndex does astype(object)
return value

def _validate_where_value(self, other):
if is_valid_nat_for_dtype(other, self.dtype):
other = NaT
elif isinstance(other, self._recognized_scalars):
other = self._scalar_type(other)
elif not is_list_like(other):
raise TypeError(f"Where requires matching dtype, not {type(other)}")

msg = f"Where requires matching dtype, not {type(other)}"
if not is_list_like(other):
other = self._validate_scalar(other, msg)
else:
other = self._validate_listlike(other, "where")
self._check_compatible_with(other, setitem=True)
Expand Down
22 changes: 6 additions & 16 deletions pandas/core/indexes/timedeltas.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
""" implement the TimedeltaIndex """

from pandas._libs import NaT, Timedelta, index as libindex, lib
from pandas._libs import Timedelta, index as libindex, lib
from pandas._typing import DtypeObj, Label
from pandas.util._decorators import doc

Expand All @@ -13,7 +13,6 @@
is_timedelta64_ns_dtype,
pandas_dtype,
)
from pandas.core.dtypes.missing import is_valid_nat_for_dtype

from pandas.core.arrays import datetimelike as dtl
from pandas.core.arrays.timedeltas import TimedeltaArray
Expand Down Expand Up @@ -214,20 +213,11 @@ def get_loc(self, key, method=None, tolerance=None):
if not is_scalar(key):
raise InvalidIndexError(key)

if is_valid_nat_for_dtype(key, self.dtype):
key = NaT

elif isinstance(key, str):
try:
key = Timedelta(key)
except ValueError as err:
raise KeyError(key) from err

elif isinstance(key, self._data._recognized_scalars) or key is NaT:
key = Timedelta(key)

else:
raise KeyError(key)
msg = str(key)
try:
key = self._data._validate_scalar(key, msg, cast_str=True)
except TypeError as err:
raise KeyError(key) from err

return Index.get_loc(self, key, method, tolerance)

Expand Down