Skip to content

Commit 3d7cc02

Browse files
authored
REF: share more ExtensionIndex methods (#43992)
1 parent 1d74b3e commit 3d7cc02

File tree

4 files changed

+15
-63
lines changed

4 files changed

+15
-63
lines changed

pandas/core/indexes/base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,7 @@ def _outer_indexer(
356356

357357
_typ: str = "index"
358358
_data: ExtensionArray | np.ndarray
359+
_data_cls: type[np.ndarray] | type[ExtensionArray] = np.ndarray
359360
_id: object | None = None
360361
_name: Hashable = None
361362
# MultiIndex.levels previously allowed setting the index name. We
@@ -640,7 +641,7 @@ def _simple_new(cls: type[_IndexT], values, name: Hashable = None) -> _IndexT:
640641
641642
Must be careful not to recurse.
642643
"""
643-
assert isinstance(values, np.ndarray), type(values)
644+
assert isinstance(values, cls._data_cls), type(values)
644645

645646
result = object.__new__(cls)
646647
result._data = values
@@ -5020,6 +5021,14 @@ def equals(self, other: Any) -> bool:
50205021
# d-level MultiIndex can equal d-tuple Index
50215022
return other.equals(self)
50225023

5024+
if isinstance(self._values, ExtensionArray):
5025+
# Dispatch to the ExtensionArray's .equals method.
5026+
if not isinstance(other, type(self)):
5027+
return False
5028+
5029+
earr = cast(ExtensionArray, self._data)
5030+
return earr.equals(other._data)
5031+
50235032
if is_extension_array_dtype(other.dtype):
50245033
# All EA-backed Index subclasses override equals
50255034
return other.equals(self)

pandas/core/indexes/datetimelike.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
_index_doc_kwargs = dict(ibase._index_doc_kwargs)
7070

7171
_T = TypeVar("_T", bound="DatetimeIndexOpsMixin")
72+
_TDT = TypeVar("_TDT", bound="DatetimeTimedeltaMixin")
7273

7374

7475
@inherit_names(
@@ -529,7 +530,7 @@ def _can_fast_union(self: _T, other: _T) -> bool:
529530
# Only need to "adjoin", not overlap
530531
return (right_start == left_end + freq) or right_start in left
531532

532-
def _fast_union(self: _T, other: _T, sort=None) -> _T:
533+
def _fast_union(self: _TDT, other: _TDT, sort=None) -> _TDT:
533534
# Caller is responsible for ensuring self and other are non-empty
534535

535536
# to make our life easier, "sort" the two ranges

pandas/core/indexes/extension.py

Lines changed: 2 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,14 @@
33
"""
44
from __future__ import annotations
55

6-
from typing import (
7-
Hashable,
8-
TypeVar,
9-
)
6+
from typing import TypeVar
107

118
import numpy as np
129

1310
from pandas._typing import (
1411
ArrayLike,
1512
npt,
1613
)
17-
from pandas.compat.numpy import function as nv
1814
from pandas.util._decorators import (
1915
cache_readonly,
2016
doc,
@@ -27,13 +23,7 @@
2723
)
2824
from pandas.core.dtypes.generic import ABCDataFrame
2925

30-
from pandas.core.arrays import (
31-
Categorical,
32-
DatetimeArray,
33-
IntervalArray,
34-
PeriodArray,
35-
TimedeltaArray,
36-
)
26+
from pandas.core.arrays import IntervalArray
3727
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
3828
from pandas.core.indexers import deprecate_ndim_indexing
3929
from pandas.core.indexes.base import Index
@@ -148,38 +138,6 @@ class ExtensionIndex(Index):
148138

149139
_data: IntervalArray | NDArrayBackedExtensionArray
150140

151-
_data_cls: (
152-
type[Categorical]
153-
| type[DatetimeArray]
154-
| type[TimedeltaArray]
155-
| type[PeriodArray]
156-
| type[IntervalArray]
157-
)
158-
159-
@classmethod
160-
def _simple_new(
161-
cls,
162-
array: IntervalArray | NDArrayBackedExtensionArray,
163-
name: Hashable = None,
164-
):
165-
"""
166-
Construct from an ExtensionArray of the appropriate type.
167-
168-
Parameters
169-
----------
170-
array : ExtensionArray
171-
name : Label, default None
172-
Attached as result.name
173-
"""
174-
assert isinstance(array, cls._data_cls), type(array)
175-
176-
result = object.__new__(cls)
177-
result._data = array
178-
result._name = name
179-
result._cache = {}
180-
result._reset_identity()
181-
return result
182-
183141
# ---------------------------------------------------------------------
184142
# NDarray-Like Methods
185143

@@ -198,11 +156,6 @@ def __getitem__(self, key):
198156

199157
# ---------------------------------------------------------------------
200158

201-
def repeat(self, repeats, axis=None):
202-
nv.validate_repeat((), {"axis": axis})
203-
result = self._data.repeat(repeats, axis=axis)
204-
return type(self)._simple_new(result, name=self.name)
205-
206159
def insert(self, loc: int, item) -> Index:
207160
"""
208161
Make new Index inserting new item at location. Follows
@@ -284,17 +237,6 @@ def _isnan(self) -> npt.NDArray[np.bool_]:
284237
# "ndarray")
285238
return self._data.isna() # type: ignore[return-value]
286239

287-
@doc(Index.equals)
288-
def equals(self, other) -> bool:
289-
# Dispatch to the ExtensionArray's .equals method.
290-
if self.is_(other):
291-
return True
292-
293-
if not isinstance(other, type(self)):
294-
return False
295-
296-
return self._data.equals(other._data)
297-
298240

299241
class NDArrayBackedExtensionIndex(ExtensionIndex):
300242
"""

pandas/core/indexes/multi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2161,7 +2161,7 @@ def repeat(self, repeats: int, axis=None) -> MultiIndex:
21612161
return MultiIndex(
21622162
levels=self.levels,
21632163
codes=[
2164-
level_codes.view(np.ndarray).astype(np.intp).repeat(repeats)
2164+
level_codes.view(np.ndarray).astype(np.intp, copy=False).repeat(repeats)
21652165
for level_codes in self.codes
21662166
],
21672167
names=self.names,

0 commit comments

Comments
 (0)