|
1 | 1 | from datetime import datetime, timedelta
|
2 | 2 | import operator
|
3 |
| -from typing import Any, Callable, Sequence, Type, TypeVar, Union, cast |
| 3 | +from typing import Any, Callable, Sequence, Tuple, Type, TypeVar, Union, cast |
4 | 4 | import warnings
|
5 | 5 |
|
6 | 6 | import numpy as np
|
@@ -437,6 +437,7 @@ class DatetimeLikeArrayMixin(
|
437 | 437 | """
|
438 | 438 |
|
439 | 439 | _is_recognized_dtype: Callable[[DtypeObj], bool]
|
| 440 | + _recognized_scalars: Tuple[Type, ...] |
440 | 441 |
|
441 | 442 | # ------------------------------------------------------------------
|
442 | 443 | # NDArrayBackedExtensionArray compat
|
@@ -718,16 +719,14 @@ def _validate_fill_value(self, fill_value):
|
718 | 719 | ------
|
719 | 720 | ValueError
|
720 | 721 | """
|
721 |
| - if is_valid_nat_for_dtype(fill_value, self.dtype): |
722 |
| - fill_value = NaT |
723 |
| - elif isinstance(fill_value, self._recognized_scalars): |
724 |
| - fill_value = self._scalar_type(fill_value) |
725 |
| - else: |
726 |
| - raise ValueError( |
727 |
| - f"'fill_value' should be a {self._scalar_type}. " |
728 |
| - f"Got '{str(fill_value)}'." |
729 |
| - ) |
730 |
| - |
| 722 | + msg = ( |
| 723 | + f"'fill_value' should be a {self._scalar_type}. " |
| 724 | + f"Got '{str(fill_value)}'." |
| 725 | + ) |
| 726 | + try: |
| 727 | + fill_value = self._validate_scalar(fill_value, msg) |
| 728 | + except TypeError as err: |
| 729 | + raise ValueError(msg) from err |
731 | 730 | return self._unbox(fill_value)
|
732 | 731 |
|
733 | 732 | def _validate_shift_value(self, fill_value):
|
@@ -757,6 +756,41 @@ def _validate_shift_value(self, fill_value):
|
757 | 756 |
|
758 | 757 | return self._unbox(fill_value)
|
759 | 758 |
|
| 759 | + def _validate_scalar(self, value, msg: str, cast_str: bool = False): |
| 760 | + """ |
| 761 | + Validate that the input value can be cast to our scalar_type. |
| 762 | +
|
| 763 | + Parameters |
| 764 | + ---------- |
| 765 | + value : object |
| 766 | + msg : str |
| 767 | + Message to raise in TypeError on invalid input. |
| 768 | + cast_str : bool, default False |
| 769 | + Whether to try to parse string input to scalar_type. |
| 770 | +
|
| 771 | + Returns |
| 772 | + ------- |
| 773 | + self._scalar_type or NaT |
| 774 | + """ |
| 775 | + if cast_str and isinstance(value, str): |
| 776 | + # NB: Careful about tzawareness |
| 777 | + try: |
| 778 | + value = self._scalar_from_string(value) |
| 779 | + except ValueError as err: |
| 780 | + raise TypeError(msg) from err |
| 781 | + |
| 782 | + elif is_valid_nat_for_dtype(value, self.dtype): |
| 783 | + # GH#18295 |
| 784 | + value = NaT |
| 785 | + |
| 786 | + elif isinstance(value, self._recognized_scalars): |
| 787 | + value = self._scalar_type(value) # type: ignore |
| 788 | + |
| 789 | + else: |
| 790 | + raise TypeError(msg) |
| 791 | + |
| 792 | + return value |
| 793 | + |
760 | 794 | def _validate_listlike(
|
761 | 795 | self, value, opname: str, cast_str: bool = False, allow_object: bool = False,
|
762 | 796 | ):
|
@@ -795,72 +829,42 @@ def _validate_listlike(
|
795 | 829 | return value
|
796 | 830 |
|
797 | 831 | def _validate_searchsorted_value(self, value):
|
798 |
| - if isinstance(value, str): |
799 |
| - try: |
800 |
| - value = self._scalar_from_string(value) |
801 |
| - except ValueError as err: |
802 |
| - raise TypeError( |
803 |
| - "searchsorted requires compatible dtype or scalar" |
804 |
| - ) from err |
805 |
| - |
806 |
| - elif is_valid_nat_for_dtype(value, self.dtype): |
807 |
| - value = NaT |
808 |
| - |
809 |
| - elif isinstance(value, self._recognized_scalars): |
810 |
| - value = self._scalar_type(value) |
811 |
| - |
812 |
| - elif not is_list_like(value): |
813 |
| - raise TypeError(f"Unexpected type for 'value': {type(value)}") |
814 |
| - |
| 832 | + msg = "searchsorted requires compatible dtype or scalar" |
| 833 | + if not is_list_like(value): |
| 834 | + value = self._validate_scalar(value, msg, cast_str=True) |
815 | 835 | else:
|
816 | 836 | # TODO: cast_str? we accept it for scalar
|
817 | 837 | value = self._validate_listlike(value, "searchsorted")
|
818 | 838 |
|
819 | 839 | return self._unbox(value)
|
820 | 840 |
|
821 | 841 | def _validate_setitem_value(self, value):
|
822 |
| - |
| 842 | + msg = ( |
| 843 | + f"'value' should be a '{self._scalar_type.__name__}', 'NaT', " |
| 844 | + f"or array of those. Got '{type(value).__name__}' instead." |
| 845 | + ) |
823 | 846 | if is_list_like(value):
|
824 | 847 | value = self._validate_listlike(value, "setitem", cast_str=True)
|
825 |
| - |
826 |
| - elif isinstance(value, self._recognized_scalars): |
827 |
| - value = self._scalar_type(value) |
828 |
| - elif is_valid_nat_for_dtype(value, self.dtype): |
829 |
| - value = NaT |
830 | 848 | else:
|
831 |
| - msg = ( |
832 |
| - f"'value' should be a '{self._scalar_type.__name__}', 'NaT', " |
833 |
| - f"or array of those. Got '{type(value).__name__}' instead." |
834 |
| - ) |
835 |
| - raise TypeError(msg) |
| 849 | + # TODO: cast_str for consistency? |
| 850 | + value = self._validate_scalar(value, msg, cast_str=False) |
836 | 851 |
|
837 | 852 | self._check_compatible_with(value, setitem=True)
|
838 | 853 | return self._unbox(value)
|
839 | 854 |
|
840 | 855 | def _validate_insert_value(self, value):
|
841 |
| - if isinstance(value, self._recognized_scalars): |
842 |
| - value = self._scalar_type(value) |
843 |
| - elif is_valid_nat_for_dtype(value, self.dtype): |
844 |
| - # GH#18295 |
845 |
| - value = NaT |
846 |
| - else: |
847 |
| - raise TypeError( |
848 |
| - f"cannot insert {type(self).__name__} with incompatible label" |
849 |
| - ) |
| 856 | + msg = f"cannot insert {type(self).__name__} with incompatible label" |
| 857 | + value = self._validate_scalar(value, msg, cast_str=False) |
850 | 858 |
|
851 | 859 | self._check_compatible_with(value, setitem=True)
|
852 | 860 | # TODO: if we dont have compat, should we raise or astype(object)?
|
853 | 861 | # PeriodIndex does astype(object)
|
854 | 862 | return value
|
855 | 863 |
|
856 | 864 | def _validate_where_value(self, other):
|
857 |
| - if is_valid_nat_for_dtype(other, self.dtype): |
858 |
| - other = NaT |
859 |
| - elif isinstance(other, self._recognized_scalars): |
860 |
| - other = self._scalar_type(other) |
861 |
| - elif not is_list_like(other): |
862 |
| - raise TypeError(f"Where requires matching dtype, not {type(other)}") |
863 |
| - |
| 865 | + msg = f"Where requires matching dtype, not {type(other)}" |
| 866 | + if not is_list_like(other): |
| 867 | + other = self._validate_scalar(other, msg) |
864 | 868 | else:
|
865 | 869 | other = self._validate_listlike(other, "where")
|
866 | 870 | self._check_compatible_with(other, setitem=True)
|
|
0 commit comments