Skip to content

REF: share ExtensionIndex astype, __getitem__ with Index #44059

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 49 additions & 21 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@
deprecate_nonkeyword_arguments,
doc,
)
from pandas.util._exceptions import find_stack_level
from pandas.util._exceptions import (
find_stack_level,
rewrite_exception,
)

from pandas.core.dtypes.cast import (
can_hold_element,
Expand Down Expand Up @@ -985,20 +988,40 @@ def astype(self, dtype, copy=True):
dtype = pandas_dtype(dtype)

if is_dtype_equal(self.dtype, dtype):
# Ensure that self.astype(self.dtype) is self
return self.copy() if copy else self

if (
self.dtype == np.dtype("M8[ns]")
and isinstance(dtype, np.dtype)
and dtype.kind == "M"
and dtype != np.dtype("M8[ns]")
):
# For now DatetimeArray supports this by unwrapping ndarray,
# but DatetimeIndex doesn't
raise TypeError(f"Cannot cast {type(self).__name__} to dtype")

values = self._data
if isinstance(values, ExtensionArray):
with rewrite_exception(type(values).__name__, type(self).__name__):
new_values = values.astype(dtype, copy=copy)

elif isinstance(dtype, ExtensionDtype):
cls = dtype.construct_array_type()
new_values = cls._from_sequence(self, dtype=dtype, copy=False)
return Index(new_values, dtype=dtype, copy=copy, name=self.name)
# Note: for RangeIndex and CategoricalDtype self vs self._values
# behaves differently here.
new_values = cls._from_sequence(self, dtype=dtype, copy=copy)

try:
casted = self._values.astype(dtype, copy=copy)
except (TypeError, ValueError) as err:
raise TypeError(
f"Cannot cast {type(self).__name__} to dtype {dtype}"
) from err
return Index(casted, name=self.name, dtype=dtype)
else:
try:
new_values = values.astype(dtype, copy=copy)
except (TypeError, ValueError) as err:
raise TypeError(
f"Cannot cast {type(self).__name__} to dtype {dtype}"
) from err

# pass copy=False because any copying will be done in the astype above
return Index(new_values, name=self.name, dtype=new_values.dtype, copy=False)

_index_shared_docs[
"take"
Expand Down Expand Up @@ -4870,8 +4893,6 @@ def __getitem__(self, key):
corresponding `Index` subclass.

"""
# There's no custom logic to be implemented in __getslice__, so it's
# not overloaded intentionally.
getitem = self._data.__getitem__

if is_scalar(key):
Expand All @@ -4880,25 +4901,32 @@ def __getitem__(self, key):

if isinstance(key, slice):
# This case is separated from the conditional above to avoid
# pessimization of basic indexing.
# pessimization com.is_bool_indexer and ndim checks.
result = getitem(key)
# Going through simple_new for performance.
return type(self)._simple_new(result, name=self._name)

if com.is_bool_indexer(key):
# if we have list[bools, length=1e5] then doing this check+convert
# takes 166 µs + 2.1 ms and cuts the ndarray.__getitem__
# time below from 3.8 ms to 496 µs
# if we already have ndarray[bool], the overhead is 1.4 µs or .25%
key = np.asarray(key, dtype=bool)

result = getitem(key)
if not is_scalar(result):
if np.ndim(result) > 1:
deprecate_ndim_indexing(result)
return result
# NB: Using _constructor._simple_new would break if MultiIndex
# didn't override __getitem__
return self._constructor._simple_new(result, name=self._name)
else:
# Because we ruled out integer above, we always get an arraylike here
if result.ndim > 1:
deprecate_ndim_indexing(result)
if hasattr(result, "_ndarray"):
# i.e. NDArrayBackedExtensionArray
# Unpack to ndarray for MPL compat
return result._ndarray
return result

# NB: Using _constructor._simple_new would break if MultiIndex
# didn't override __getitem__
return self._constructor._simple_new(result, name=self._name)

def _getitem_slice(self: _IndexT, slobj: slice) -> _IndexT:
"""
Fastpath for __getitem__ when we know we have a slice.
Expand Down
49 changes: 0 additions & 49 deletions pandas/core/indexes/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,11 @@
cache_readonly,
doc,
)
from pandas.util._exceptions import rewrite_exception

from pandas.core.dtypes.common import (
is_dtype_equal,
pandas_dtype,
)
from pandas.core.dtypes.generic import ABCDataFrame

from pandas.core.arrays import IntervalArray
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
from pandas.core.indexers import deprecate_ndim_indexing
from pandas.core.indexes.base import Index

_T = TypeVar("_T", bound="NDArrayBackedExtensionIndex")
Expand Down Expand Up @@ -138,22 +132,6 @@ class ExtensionIndex(Index):

_data: IntervalArray | NDArrayBackedExtensionArray

# ---------------------------------------------------------------------
# NDarray-Like Methods

def __getitem__(self, key):
result = self._data[key]
if isinstance(result, type(self._data)):
if result.ndim == 1:
return type(self)(result, name=self._name)
# Unpack to ndarray for MPL compat

result = result._ndarray

# Includes cases where we get a 2D ndarray back for MPL compat
deprecate_ndim_indexing(result)
return result

# ---------------------------------------------------------------------

def insert(self, loc: int, item) -> Index:
Expand Down Expand Up @@ -204,33 +182,6 @@ def map(self, mapper, na_action=None):
except Exception:
return self.astype(object).map(mapper)

@doc(Index.astype)
def astype(self, dtype, copy: bool = True) -> Index:
dtype = pandas_dtype(dtype)
if is_dtype_equal(self.dtype, dtype):
if not copy:
# Ensure that self.astype(self.dtype) is self
return self
return self.copy()

# error: Non-overlapping equality check (left operand type: "dtype[Any]", right
# operand type: "Literal['M8[ns]']")
if (
isinstance(self.dtype, np.dtype)
and isinstance(dtype, np.dtype)
and dtype.kind == "M"
and dtype != "M8[ns]" # type: ignore[comparison-overlap]
):
# For now Datetime supports this by unwrapping ndarray, but DTI doesn't
raise TypeError(f"Cannot cast {type(self).__name__} to dtype")

with rewrite_exception(type(self._data).__name__, type(self).__name__):
new_values = self._data.astype(dtype, copy=copy)

# pass copy=False because any copying will be done in the
# _data.astype call above
return Index(new_values, dtype=new_values.dtype, name=self.name, copy=False)

@cache_readonly
def _isnan(self) -> npt.NDArray[np.bool_]:
# error: Incompatible return value type (got "ExtensionArray", expected
Expand Down