Skip to content

Commit a6d066c

Browse files
fixup pd.array and more testing of string_storage option
1 parent 6247a5b commit a6d066c

File tree

9 files changed

+164
-38
lines changed

9 files changed

+164
-38
lines changed

pandas/conftest.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,6 +1131,22 @@ def nullable_string_dtype(request):
11311131
return request.param
11321132

11331133

1134+
@pytest.fixture(
1135+
params=[
1136+
"python",
1137+
pytest.param("pyarrow", marks=td.skip_if_no("pyarrow", min_version="1.0.0")),
1138+
]
1139+
)
1140+
def string_storage(request):
1141+
"""
1142+
Parametrized fixture for pd.options.mode.string_storage.
1143+
1144+
* 'python'
1145+
* 'pyarrow'
1146+
"""
1147+
return request.param
1148+
1149+
11341150
@pytest.fixture(params=tm.BYTES_DTYPES)
11351151
def bytes_dtype(request):
11361152
"""

pandas/core/arrays/string_.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ def __init__(self, values, copy=False):
295295
super().__init__(values, copy=copy)
296296
# error: Incompatible types in assignment (expression has type "StringDtype",
297297
# variable has type "PandasDtype")
298-
NDArrayBacked.__init__(self, self._ndarray, StringDtype())
298+
NDArrayBacked.__init__(self, self._ndarray, StringDtype(storage="python"))
299299
if not isinstance(values, type(self)):
300300
self._validate()
301301

@@ -311,8 +311,9 @@ def _validate(self):
311311

312312
@classmethod
313313
def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy=False):
314-
if dtype:
315-
assert dtype == "string"
314+
if dtype and not (isinstance(dtype, str) and dtype == "string"):
315+
dtype = pandas_dtype(dtype)
316+
assert isinstance(dtype, StringDtype) and dtype.storage == "python"
316317

317318
from pandas.core.arrays.masked import BaseMaskedArray
318319

@@ -332,7 +333,7 @@ def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy=False):
332333
# Manually creating new array avoids the validation step in the __init__, so is
333334
# faster. Refactor need for validation?
334335
new_string_array = cls.__new__(cls)
335-
NDArrayBacked.__init__(new_string_array, result, StringDtype())
336+
NDArrayBacked.__init__(new_string_array, result, StringDtype(storage="python"))
336337

337338
return new_string_array
338339

@@ -501,7 +502,7 @@ def _str_map(
501502
from pandas.arrays import BooleanArray
502503

503504
if dtype is None:
504-
dtype = StringDtype()
505+
dtype = StringDtype(storage="python")
505506
if na_value is None:
506507
na_value = self.dtype.na_value
507508

pandas/core/arrays/string_arrow.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
is_object_dtype,
3636
is_scalar,
3737
is_string_dtype,
38+
pandas_dtype,
3839
)
3940
from pandas.core.dtypes.missing import isna
4041

@@ -154,6 +155,10 @@ def _from_sequence(cls, scalars, dtype: Dtype | None = None, copy: bool = False)
154155

155156
cls._chk_pyarrow_available()
156157

158+
if dtype and not (isinstance(dtype, str) and dtype == "string"):
159+
dtype = pandas_dtype(dtype)
160+
assert isinstance(dtype, StringDtype) and dtype.storage == "pyarrow"
161+
157162
if isinstance(scalars, BaseMaskedArray):
158163
# avoid costly conversion to object dtype in ensure_string_array and
159164
# numerical issues with Float32Dtype

pandas/core/construction.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,18 +113,22 @@ def array(
113113
114114
Currently, pandas will infer an extension dtype for sequences of
115115
116-
============================== =====================================
116+
============================== =======================================
117117
Scalar Type Array Type
118-
============================== =====================================
118+
============================== =======================================
119119
:class:`pandas.Interval` :class:`pandas.arrays.IntervalArray`
120120
:class:`pandas.Period` :class:`pandas.arrays.PeriodArray`
121121
:class:`datetime.datetime` :class:`pandas.arrays.DatetimeArray`
122122
:class:`datetime.timedelta` :class:`pandas.arrays.TimedeltaArray`
123123
:class:`int` :class:`pandas.arrays.IntegerArray`
124124
:class:`float` :class:`pandas.arrays.FloatingArray`
125-
:class:`str` :class:`pandas.arrays.StringArray`
125+
:class:`str` :class:`pandas.arrays.StringArray` or
126+
:class:`pandas.arrays.ArrowStringArray`
126127
:class:`bool` :class:`pandas.arrays.BooleanArray`
127-
============================== =====================================
128+
============================== =======================================
129+
130+
The ExtensionArray created when the scalar type is :class:`str` is determined by
131+
pd.options.mode.string_storage if the dtype is not explicitly given.
128132
129133
For all other cases, NumPy's usual inference rules will be used.
130134
@@ -240,6 +244,14 @@ def array(
240244
['a', <NA>, 'c']
241245
Length: 3, dtype: string[python]
242246
247+
>>> with pd.option_context("string_storage", "pyarrow"):
248+
... arr = pd.array(["a", None, "c"])
249+
...
250+
>>> arr
251+
<ArrowStringArray>
252+
['a', <NA>, 'c']
253+
Length: 3, dtype: string[pyarrow]
254+
243255
>>> pd.array([pd.Period('2000', freq="D"), pd.Period("2000", freq="D")])
244256
<PeriodArray>
245257
['2000-01-01', '2000-01-01']
@@ -292,10 +304,10 @@ def array(
292304
IntegerArray,
293305
IntervalArray,
294306
PandasArray,
295-
StringArray,
296307
TimedeltaArray,
297308
period_array,
298309
)
310+
from pandas.core.arrays.string_ import StringDtype
299311

300312
if lib.is_scalar(data):
301313
msg = f"Cannot pass scalar '{data}' to 'pandas.array'."
@@ -345,7 +357,8 @@ def array(
345357
return TimedeltaArray._from_sequence(data, copy=copy)
346358

347359
elif inferred_dtype == "string":
348-
return StringArray._from_sequence(data, copy=copy)
360+
# StringArray/ArrowStringArray depending on pd.options.mode.string_storage
361+
return StringDtype().construct_array_type()._from_sequence(data, copy=copy)
349362

350363
elif inferred_dtype == "integer":
351364
return IntegerArray._from_sequence(data, copy=copy)

pandas/tests/arrays/string_/test_string_arrow.py

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,27 +8,33 @@
88

99
pa = pytest.importorskip("pyarrow", minversion="1.0.0")
1010

11+
from pandas.core.arrays.string_ import (
12+
StringArray,
13+
StringDtype,
14+
)
1115
from pandas.core.arrays.string_arrow import ArrowStringArray
1216

1317

1418
def test_eq_all_na():
15-
a = pd.array([pd.NA, pd.NA], dtype=pd.StringDtype("pyarrow"))
19+
a = pd.array([pd.NA, pd.NA], dtype=StringDtype("pyarrow"))
1620
result = a == a
1721
expected = pd.array([pd.NA, pd.NA], dtype="boolean")
1822
tm.assert_extension_array_equal(result, expected)
1923

2024

21-
def test_config():
22-
# python by default
23-
assert pd.StringDtype().storage == "python"
24-
arr = pd.array(["a", "b"])
25-
assert arr.dtype.storage == "python"
25+
def test_config(string_storage):
26+
with pd.option_context("string_storage", string_storage):
27+
assert StringDtype().storage == string_storage
28+
result = pd.array(["a", "b"])
29+
assert result.dtype.storage == string_storage
2630

27-
with pd.option_context("mode.string_storage", "pyarrow"):
28-
assert pd.StringDtype().storage == "pyarrow"
29-
arr = pd.array(["a", "b"])
30-
assert arr.dtype.storage == "pyarrow"
31+
expected = (
32+
StringDtype(string_storage).construct_array_type()._from_sequence(["a", "b"])
33+
)
34+
tm.assert_equal(result, expected)
3135

36+
37+
def test_config_bad_storage_raises():
3238
msg = re.escape("Value must be one of python|pyarrow")
3339
with pytest.raises(ValueError, match=msg):
3440
pd.options.mode.string_storage = "foo"
@@ -50,3 +56,51 @@ def test_constructor_not_string_type_raises(array, chunked):
5056
)
5157
with pytest.raises(ValueError, match=msg):
5258
ArrowStringArray(arr)
59+
60+
61+
def test_from_sequence_wrong_dtype_raises():
62+
with pd.option_context("string_storage", "python"):
63+
ArrowStringArray._from_sequence(["a", None, "c"], dtype="string")
64+
65+
with pd.option_context("string_storage", "pyarrow"):
66+
ArrowStringArray._from_sequence(["a", None, "c"], dtype="string")
67+
68+
with pytest.raises(AssertionError, match=None):
69+
ArrowStringArray._from_sequence(["a", None, "c"], dtype="string[python]")
70+
71+
ArrowStringArray._from_sequence(["a", None, "c"], dtype="string[pyarrow]")
72+
73+
with pytest.raises(AssertionError, match=None):
74+
with pd.option_context("string_storage", "python"):
75+
ArrowStringArray._from_sequence(["a", None, "c"], dtype=StringDtype())
76+
77+
with pd.option_context("string_storage", "pyarrow"):
78+
ArrowStringArray._from_sequence(["a", None, "c"], dtype=StringDtype())
79+
80+
with pytest.raises(AssertionError, match=None):
81+
ArrowStringArray._from_sequence(["a", None, "c"], dtype=StringDtype("python"))
82+
83+
ArrowStringArray._from_sequence(["a", None, "c"], dtype=StringDtype("pyarrow"))
84+
85+
with pd.option_context("string_storage", "python"):
86+
StringArray._from_sequence(["a", None, "c"], dtype="string")
87+
88+
with pd.option_context("string_storage", "pyarrow"):
89+
StringArray._from_sequence(["a", None, "c"], dtype="string")
90+
91+
StringArray._from_sequence(["a", None, "c"], dtype="string[python]")
92+
93+
with pytest.raises(AssertionError, match=None):
94+
StringArray._from_sequence(["a", None, "c"], dtype="string[pyarrow]")
95+
96+
with pd.option_context("string_storage", "python"):
97+
StringArray._from_sequence(["a", None, "c"], dtype=StringDtype())
98+
99+
with pytest.raises(AssertionError, match=None):
100+
with pd.option_context("string_storage", "pyarrow"):
101+
StringArray._from_sequence(["a", None, "c"], dtype=StringDtype())
102+
103+
StringArray._from_sequence(["a", None, "c"], dtype=StringDtype("python"))
104+
105+
with pytest.raises(AssertionError, match=None):
106+
StringArray._from_sequence(["a", None, "c"], dtype=StringDtype("pyarrow"))

pandas/tests/arrays/test_array.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
IntegerArray,
1919
IntervalArray,
2020
SparseArray,
21-
StringArray,
2221
TimedeltaArray,
2322
)
2423
from pandas.core.arrays import (
@@ -132,8 +131,16 @@
132131
([1, None], "Int16", pd.array([1, None], dtype="Int16")),
133132
(pd.Series([1, 2]), None, PandasArray(np.array([1, 2], dtype=np.int64))),
134133
# String
135-
(["a", None], "string", StringArray._from_sequence(["a", None])),
136-
(["a", None], pd.StringDtype(), StringArray._from_sequence(["a", None])),
134+
(
135+
["a", None],
136+
"string",
137+
pd.StringDtype().construct_array_type()._from_sequence(["a", None]),
138+
),
139+
(
140+
["a", None],
141+
pd.StringDtype(),
142+
pd.StringDtype().construct_array_type()._from_sequence(["a", None]),
143+
),
137144
# Boolean
138145
([True, None], "boolean", BooleanArray._from_sequence([True, None])),
139146
([True, None], pd.BooleanDtype(), BooleanArray._from_sequence([True, None])),
@@ -253,8 +260,14 @@ def test_array_copy():
253260
([1, 2.0], FloatingArray._from_sequence([1.0, 2.0])),
254261
([1, np.nan, 2.0], FloatingArray._from_sequence([1.0, None, 2.0])),
255262
# string
256-
(["a", "b"], StringArray._from_sequence(["a", "b"])),
257-
(["a", None], StringArray._from_sequence(["a", None])),
263+
(
264+
["a", "b"],
265+
pd.StringDtype().construct_array_type()._from_sequence(["a", "b"]),
266+
),
267+
(
268+
["a", None],
269+
pd.StringDtype().construct_array_type()._from_sequence(["a", None]),
270+
),
258271
# Boolean
259272
([True, False], BooleanArray._from_sequence([True, False])),
260273
([True, None], BooleanArray._from_sequence([True, None])),

pandas/tests/arrays/test_datetimelike.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ def test_searchsorted(self):
298298
assert result == 10
299299

300300
@pytest.mark.parametrize("box", [None, "index", "series"])
301-
def test_searchsorted_castable_strings(self, arr1d, box, request):
301+
def test_searchsorted_castable_strings(self, arr1d, box, request, string_storage):
302302
if isinstance(arr1d, DatetimeArray):
303303
tz = arr1d.tz
304304
ts1, ts2 = arr1d[1:3]
@@ -341,14 +341,17 @@ def test_searchsorted_castable_strings(self, arr1d, box, request):
341341
):
342342
arr.searchsorted("foo")
343343

344-
with pytest.raises(
345-
TypeError,
346-
match=re.escape(
347-
f"value should be a '{arr1d._scalar_type.__name__}', 'NaT', "
348-
"or array of those. Got 'StringArray' instead."
349-
),
350-
):
351-
arr.searchsorted([str(arr[1]), "baz"])
344+
arr_type = "StringArray" if string_storage == "python" else "ArrowStringArray"
345+
346+
with pd.option_context("string_storage", string_storage):
347+
with pytest.raises(
348+
TypeError,
349+
match=re.escape(
350+
f"value should be a '{arr1d._scalar_type.__name__}', 'NaT', "
351+
f"or array of those. Got '{arr_type}' instead."
352+
),
353+
):
354+
arr.searchsorted([str(arr[1]), "baz"])
352355

353356
def test_getitem_near_implementation_bounds(self):
354357
# We only check tz-naive for DTA bc the bounds are slightly different

pandas/tests/series/methods/test_astype.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pandas._libs.tslibs import iNaT
1313
import pandas.util._test_decorators as td
1414

15+
import pandas as pd
1516
from pandas import (
1617
NA,
1718
Categorical,
@@ -377,17 +378,34 @@ class TestAstypeString:
377378
# currently no way to parse IntervalArray from a list of strings
378379
],
379380
)
380-
def test_astype_string_to_extension_dtype_roundtrip(self, data, dtype, request):
381+
def test_astype_string_to_extension_dtype_roundtrip(
382+
self, data, dtype, request, string_storage
383+
):
381384
if dtype == "boolean" or (
382385
dtype in ("period[M]", "datetime64[ns]", "timedelta64[ns]") and NaT in data
383386
):
384387
mark = pytest.mark.xfail(
385388
reason="TODO StringArray.astype() with missing values #GH40566"
386389
)
387390
request.node.add_marker(mark)
391+
392+
if string_storage == "pyarrow" and dtype in (
393+
"category",
394+
"datetime64[ns]",
395+
"datetime64[ns, US/Eastern]",
396+
"UInt16",
397+
"period[M]",
398+
):
399+
mark = pytest.mark.xfail(
400+
reason="TypeError: Cannot interpret ... as a data type"
401+
)
402+
request.node.add_marker(mark)
403+
388404
# GH-40351
389405
s = Series(data, dtype=dtype)
390-
tm.assert_series_equal(s, s.astype("string").astype(dtype))
406+
with pd.option_context("string_storage", string_storage):
407+
result = s.astype("string").astype(dtype)
408+
tm.assert_series_equal(result, s)
391409

392410

393411
class TestAstypeCategorical:

pandas/tests/strings/test_api.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
MultiIndex,
77
Series,
88
_testing as tm,
9+
get_option,
910
)
1011
from pandas.core import strings as strings
1112

@@ -128,7 +129,9 @@ def test_api_per_method(
128129
def test_api_for_categorical(any_string_method, any_string_dtype, request):
129130
# https://github.com/pandas-dev/pandas/issues/10661
130131

131-
if any_string_dtype == "string[pyarrow]":
132+
if any_string_dtype == "string[pyarrow]" or (
133+
any_string_dtype == "string" and get_option("string_storage") == "pyarrow"
134+
):
132135
# unsupported operand type(s) for +: 'ArrowStringArray' and 'str'
133136
mark = pytest.mark.xfail(raises=TypeError, reason="Not Implemented")
134137
request.node.add_marker(mark)

0 commit comments

Comments
 (0)