Skip to content

REF: share more ExtensionIndex methods #43992

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ def _outer_indexer(

_typ: str = "index"
_data: ExtensionArray | np.ndarray
_data_cls: type[np.ndarray] | type[ExtensionArray] = np.ndarray
_id: object | None = None
_name: Hashable = None
# MultiIndex.levels previously allowed setting the index name. We
Expand Down Expand Up @@ -640,7 +641,7 @@ def _simple_new(cls: type[_IndexT], values, name: Hashable = None) -> _IndexT:

Must be careful not to recurse.
"""
-assert isinstance(values, np.ndarray), type(values)
+assert isinstance(values, cls._data_cls), type(values)

result = object.__new__(cls)
result._data = values
Expand Down Expand Up @@ -5020,6 +5021,14 @@ def equals(self, other: Any) -> bool:
# d-level MultiIndex can equal d-tuple Index
return other.equals(self)

if isinstance(self._values, ExtensionArray):
# Dispatch to the ExtensionArray's .equals method.
if not isinstance(other, type(self)):
return False

earr = cast(ExtensionArray, self._data)
return earr.equals(other._data)

if is_extension_array_dtype(other.dtype):
# All EA-backed Index subclasses override equals
return other.equals(self)
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/indexes/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
_index_doc_kwargs = dict(ibase._index_doc_kwargs)

_T = TypeVar("_T", bound="DatetimeIndexOpsMixin")
_TDT = TypeVar("_TDT", bound="DatetimeTimedeltaMixin")


@inherit_names(
Expand Down Expand Up @@ -529,7 +530,7 @@ def _can_fast_union(self: _T, other: _T) -> bool:
# Only need to "adjoin", not overlap
return (right_start == left_end + freq) or right_start in left

-def _fast_union(self: _T, other: _T, sort=None) -> _T:
+def _fast_union(self: _TDT, other: _TDT, sort=None) -> _TDT:
# Caller is responsible for ensuring self and other are non-empty

# to make our life easier, "sort" the two ranges
Expand Down
62 changes: 2 additions & 60 deletions pandas/core/indexes/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,14 @@
"""
from __future__ import annotations

-from typing import (
-Hashable,
-TypeVar,
-)
+from typing import TypeVar

import numpy as np

from pandas._typing import (
ArrayLike,
npt,
)
-from pandas.compat.numpy import function as nv
from pandas.util._decorators import (
cache_readonly,
doc,
Expand All @@ -27,13 +23,7 @@
)
from pandas.core.dtypes.generic import ABCDataFrame

-from pandas.core.arrays import (
-Categorical,
-DatetimeArray,
-IntervalArray,
-PeriodArray,
-TimedeltaArray,
-)
+from pandas.core.arrays import IntervalArray
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
from pandas.core.indexers import deprecate_ndim_indexing
from pandas.core.indexes.base import Index
Expand Down Expand Up @@ -148,38 +138,6 @@ class ExtensionIndex(Index):

_data: IntervalArray | NDArrayBackedExtensionArray

_data_cls: (
type[Categorical]
| type[DatetimeArray]
| type[TimedeltaArray]
| type[PeriodArray]
| type[IntervalArray]
)

@classmethod
def _simple_new(
cls,
array: IntervalArray | NDArrayBackedExtensionArray,
name: Hashable = None,
):
"""
Construct from an ExtensionArray of the appropriate type.

Parameters
----------
array : ExtensionArray
name : Label, default None
Attached as result.name
"""
assert isinstance(array, cls._data_cls), type(array)

result = object.__new__(cls)
result._data = array
result._name = name
result._cache = {}
result._reset_identity()
return result

# ---------------------------------------------------------------------
# NDarray-Like Methods

Expand All @@ -198,11 +156,6 @@ def __getitem__(self, key):

# ---------------------------------------------------------------------

def repeat(self, repeats, axis=None):
nv.validate_repeat((), {"axis": axis})
result = self._data.repeat(repeats, axis=axis)
return type(self)._simple_new(result, name=self.name)

def insert(self, loc: int, item) -> Index:
"""
Make new Index inserting new item at location. Follows
Expand Down Expand Up @@ -284,17 +237,6 @@ def _isnan(self) -> npt.NDArray[np.bool_]:
# "ndarray")
return self._data.isna() # type: ignore[return-value]

@doc(Index.equals)
def equals(self, other) -> bool:
# Dispatch to the ExtensionArray's .equals method.
if self.is_(other):
return True

if not isinstance(other, type(self)):
return False

return self._data.equals(other._data)


class NDArrayBackedExtensionIndex(ExtensionIndex):
"""
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2161,7 +2161,7 @@ def repeat(self, repeats: int, axis=None) -> MultiIndex:
return MultiIndex(
levels=self.levels,
codes=[
-level_codes.view(np.ndarray).astype(np.intp).repeat(repeats)
+level_codes.view(np.ndarray).astype(np.intp, copy=False).repeat(repeats)
for level_codes in self.codes
],
names=self.names,
Expand Down