Skip to content

Commit 723b8da

Browse files
committed
BUG: read_parquet does not respect index for arrow dtype backend
1 parent ce32601 commit 723b8da

File tree

2 files changed

+53
-8
lines changed

2 files changed

+53
-8
lines changed

pandas/io/parquet.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424

2525
from pandas import (
2626
DataFrame,
27+
Index,
2728
MultiIndex,
29+
RangeIndex,
2830
arrays,
2931
get_option,
3032
)
@@ -250,14 +252,28 @@ def read(
250252
if dtype_backend == "pandas":
251253
result = pa_table.to_pandas(**to_pandas_kwargs)
252254
elif dtype_backend == "pyarrow":
253-
result = DataFrame(
254-
{
255-
col_name: arrays.ArrowExtensionArray(pa_col)
256-
for col_name, pa_col in zip(
257-
pa_table.column_names, pa_table.itercolumns()
258-
)
259-
}
260-
)
255+
index_columns = pa_table.schema.pandas_metadata.get("index_columns", [])
256+
result_dc = {
257+
col_name: arrays.ArrowExtensionArray(pa_col)
258+
for col_name, pa_col in zip(
259+
pa_table.column_names, pa_table.itercolumns()
260+
)
261+
}
262+
if len(index_columns) == 1 and isinstance(index_columns[0], dict):
263+
params = index_columns[0]
264+
idx = RangeIndex(
265+
params.get("start"), params.get("stop"), params.get("step")
266+
)
267+
268+
else:
269+
index_data = [
270+
result_dc.pop(index_col) for index_col in index_columns
271+
]
272+
if len(index_data) == 1:
273+
idx = Index(index_data[0], name=index_columns[0])
274+
else:
275+
idx = MultiIndex.from_arrays(index_data, names=index_columns)
276+
result = DataFrame(result_dc, index=idx)
261277
if manager == "array":
262278
result = result._as_manager("array", copy=False)
263279
return result

pandas/tests/io/test_parquet.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import pandas.util._test_decorators as td
1919

2020
import pandas as pd
21+
from pandas import RangeIndex
2122
import pandas._testing as tm
2223
from pandas.util.version import Version
2324

@@ -1225,3 +1226,31 @@ def test_bytes_file_name(self, engine):
12251226

12261227
result = read_parquet(path, engine=engine)
12271228
tm.assert_frame_equal(result, df)
1229+
1230+
@pytest.mark.parametrize("index", ["A", ["A", "B"]])
def test_pyarrow_backed_df_index(self, index, pa):
    """
    Check that a column-based index survives a parquet round trip
    when reading with the pyarrow dtype backend.

    Parametrized over a single index column ("A") and a two-column
    index (["A", "B"]).
    """
    # GH#48944
    # NOTE(review): the ``pa`` fixture is not used in the body; presumably
    # it only gates the test on pyarrow being available -- confirm.
    source = pd.DataFrame(data={"A": [0, 1], "B": [1, 0], "C": 1})
    indexed = source.set_index(index)

    # Build the expectation from the un-indexed frame so the index levels
    # are converted to pyarrow-backed int64 before becoming the index.
    expected = source.astype("int64[pyarrow]").set_index(index)

    with tm.ensure_clean("test.parquet") as path:
        # Write through an explicit binary handle opened on the bytes path.
        with open(path.encode(), "wb") as handle:
            indexed.to_parquet(handle)

        with pd.option_context("mode.dtype_backend", "pyarrow"):
            result = read_parquet(path, engine="pyarrow")

    tm.assert_frame_equal(result, expected)
1243+
1244+
def test_pyarrow_backed_df_range_index(self, pa):
    """
    Check that a RangeIndex with a non-zero start survives a parquet
    round trip when reading with the pyarrow dtype backend.
    """
    # GH#48944
    # NOTE(review): the ``pa`` fixture is not used in the body; presumably
    # it only gates the test on pyarrow being available -- confirm.
    offset_index = RangeIndex(start=100, stop=102)
    frame = pd.DataFrame(data={"A": [0, 1], "B": [1, 0]}, index=offset_index)

    # Only the column dtypes change in the expectation; the index is
    # whatever ``astype`` carries over from the original frame.
    expected = frame.astype("int64[pyarrow]")

    with tm.ensure_clean("test.parquet") as path:
        # Write through an explicit binary handle opened on the bytes path.
        with open(path.encode(), "wb") as handle:
            frame.to_parquet(handle)

        with pd.option_context("mode.dtype_backend", "pyarrow"):
            result = read_parquet(path, engine="pyarrow")

    tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)