Skip to content

Commit ac85de8

Browse files
authored
REF: de-duplicate ensure_np_dtype (#54258)
* REF: de-duplicate ensure_np_dtype
* REF: simplify as_array
* mypy fixup
* Missing check in CoW case
1 parent 3565309 commit ac85de8

File tree

3 files changed

+28
-30
lines changed

3 files changed

+28
-30
lines changed

pandas/core/internals/array_manager.py

Lines changed: 3 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -35,11 +35,7 @@
3535
is_object_dtype,
3636
is_timedelta64_ns_dtype,
3737
)
38-
from pandas.core.dtypes.dtypes import (
39-
ExtensionDtype,
40-
NumpyEADtype,
41-
SparseDtype,
42-
)
38+
from pandas.core.dtypes.dtypes import ExtensionDtype
4339
from pandas.core.dtypes.generic import (
4440
ABCDataFrame,
4541
ABCSeries,
@@ -75,6 +71,7 @@
7571
from pandas.core.internals.base import (
7672
DataManager,
7773
SingleDataManager,
74+
ensure_np_dtype,
7875
interleaved_dtype,
7976
)
8077
from pandas.core.internals.blocks import (
@@ -1021,14 +1018,7 @@ def as_array(
10211018
if not dtype:
10221019
dtype = interleaved_dtype([arr.dtype for arr in self.arrays])
10231020

1024-
if isinstance(dtype, SparseDtype):
1025-
dtype = dtype.subtype
1026-
elif isinstance(dtype, NumpyEADtype):
1027-
dtype = dtype.numpy_dtype
1028-
elif isinstance(dtype, ExtensionDtype):
1029-
dtype = np.dtype("object")
1030-
elif dtype == np.dtype(str):
1031-
dtype = np.dtype("object")
1021+
dtype = ensure_np_dtype(dtype)
10321022

10331023
result = np.empty(self.shape_proper, dtype=dtype)
10341024

pandas/core/internals/base.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
TYPE_CHECKING,
99
Any,
1010
Literal,
11+
cast,
1112
final,
1213
)
1314

@@ -26,6 +27,10 @@
2627
find_common_type,
2728
np_can_hold_element,
2829
)
30+
from pandas.core.dtypes.dtypes import (
31+
ExtensionDtype,
32+
SparseDtype,
33+
)
2934

3035
from pandas.core.base import PandasObject
3136
from pandas.core.construction import extract_array
@@ -356,3 +361,16 @@ def interleaved_dtype(dtypes: list[DtypeObj]) -> DtypeObj | None:
356361
return None
357362

358363
return find_common_type(dtypes)
364+
365+
366+
def ensure_np_dtype(dtype: DtypeObj) -> np.dtype:
367+
# TODO: https://github.com/pandas-dev/pandas/issues/22791
368+
# Give EAs some input on what happens here. Sparse needs this.
369+
if isinstance(dtype, SparseDtype):
370+
dtype = dtype.subtype
371+
dtype = cast(np.dtype, dtype)
372+
elif isinstance(dtype, ExtensionDtype):
373+
dtype = np.dtype("object")
374+
elif dtype == np.dtype(str):
375+
dtype = np.dtype("object")
376+
return dtype

pandas/core/internals/managers.py

Lines changed: 7 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -65,6 +65,7 @@
6565
from pandas.core.internals.base import (
6666
DataManager,
6767
SingleDataManager,
68+
ensure_np_dtype,
6869
interleaved_dtype,
6970
)
7071
from pandas.core.internals.blocks import (
@@ -1623,16 +1624,12 @@ def as_array(
16231624
arr = blk.values.to_numpy( # type: ignore[union-attr]
16241625
dtype=dtype,
16251626
na_value=na_value,
1627+
copy=copy,
16261628
).reshape(blk.shape)
16271629
else:
1628-
arr = np.asarray(blk.get_values())
1629-
if dtype:
1630-
arr = arr.astype(dtype, copy=copy)
1631-
copy = False
1630+
arr = np.array(blk.values, dtype=dtype, copy=copy)
16321631

1633-
if copy:
1634-
arr = arr.copy()
1635-
elif using_copy_on_write():
1632+
if using_copy_on_write() and not copy:
16361633
arr = arr.view()
16371634
arr.flags.writeable = False
16381635
else:
@@ -1666,16 +1663,9 @@ def _interleave(
16661663
[blk.dtype for blk in self.blocks]
16671664
)
16681665

1669-
# TODO: https://github.com/pandas-dev/pandas/issues/22791
1670-
# Give EAs some input on what happens here. Sparse needs this.
1671-
if isinstance(dtype, SparseDtype):
1672-
dtype = dtype.subtype
1673-
dtype = cast(np.dtype, dtype)
1674-
elif isinstance(dtype, ExtensionDtype):
1675-
dtype = np.dtype("object")
1676-
elif dtype == np.dtype(str):
1677-
dtype = np.dtype("object")
1678-
1666+
# error: Argument 1 to "ensure_np_dtype" has incompatible type
1667+
# "Optional[dtype[Any]]"; expected "Union[dtype[Any], ExtensionDtype]"
1668+
dtype = ensure_np_dtype(dtype) # type: ignore[arg-type]
16791669
result = np.empty(self.shape, dtype=dtype)
16801670

16811671
itemmask = np.zeros(self.shape[0])

0 commit comments

Comments
 (0)