Skip to content

Commit c58ed2f

Browse files
Support for nested ObjectIDs in polars conversion (#220)
* added support for nested fields in polars conversion _arrow_to_polars currently has no support to cast extension types for nested fields. This prohibits ObjectIDs to be read in case they are in nested fields. * add support for List type * add support for list type + reformat code * add test for struct and list in _arrow_to_polars * add types to functions * remove return type for _arrow_to_polars --------- Co-authored-by: Lazar Gugleta <[email protected]>
1 parent 2859a92 commit c58ed2f

File tree

2 files changed

+30
-26
lines changed

2 files changed

+30
-26
lines changed

bindings/python/pymongoarrow/api.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -295,36 +295,38 @@ def aggregate_numpy_all(collection, pipeline, *, schema=None, **kwargs):
295295
)
296296

297297

298-
def _cast_away_extension_types_on_array(array: pa.Array) -> pa.Array:
299-
"""Return an Array where ExtensionTypes have been cast to their base pyarrow types"""
300-
if isinstance(array.type, pa.ExtensionType):
301-
return array.cast(array.type.storage_type)
302-
# elif pa.types.is_struct(field.type):
303-
# ...
304-
# elif pa.types.is_list(field.type):
305-
# ...
306-
return array
307-
308-
309-
def _cast_away_extension_types_on_table(table: pa.Table) -> pa.Table:
310-
"""Given arrow_table that may ExtensionTypes, cast these to the base pyarrow types"""
311-
# Convert all fields in the Arrow table
312-
converted_fields = [
313-
_cast_away_extension_types_on_array(table.column(i)) for i in range(table.num_columns)
314-
]
315-
# Reconstruct the Arrow table
316-
return pa.Table.from_arrays(converted_fields, names=table.column_names)
317-
318-
319-
def _arrow_to_polars(arrow_table):
298+
def _cast_away_extension_type(field: pa.field) -> pa.field:
299+
if isinstance(field.type, pa.ExtensionType):
300+
field_without_extension = pa.field(field.name, field.type.storage_type)
301+
elif isinstance(field.type, pa.StructType):
302+
field_without_extension = pa.field(
303+
field.name,
304+
pa.struct([_cast_away_extension_type(nested_field) for nested_field in field.type]),
305+
)
306+
elif isinstance(field.type, pa.ListType):
307+
field_without_extension = pa.field(
308+
field.name, pa.list_(_cast_away_extension_type(field.type.value_field))
309+
)
310+
else:
311+
field_without_extension = field
312+
313+
return field_without_extension
314+
315+
316+
def _arrow_to_polars(arrow_table: pa.Table):
320317
"""Helper function that converts an Arrow Table to a Polars DataFrame.
321318
322319
Note: Polars lacks ExtensionTypes. We cast them to their base arrow classes.
323320
"""
324321
if pl is None:
325322
msg = "polars is not installed. Try pip install polars."
326323
raise ValueError(msg)
327-
arrow_table_without_extensions = _cast_away_extension_types_on_table(arrow_table)
324+
325+
schema_without_extensions = pa.schema(
326+
[_cast_away_extension_type(field) for field in arrow_table.schema]
327+
)
328+
arrow_table_without_extensions = arrow_table.cast(schema_without_extensions)
329+
328330
return pl.from_arrow(arrow_table_without_extensions)
329331

330332

bindings/python/test/test_polars.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,8 @@ def test_arrow_to_polars(self):
228228
"str": [str(i) for i in range(2)],
229229
"int": [i for i in range(2)],
230230
"bool": [True, False],
231+
"struct": [{"objId": bson.ObjectId().binary, "str1": str(i)} for i in range(2)],
232+
"list": [[str(i), str(i + 1)] for i in range(2)],
231233
"Binary": [b"1", b"23"],
232234
"ObjectId": [bson.ObjectId().binary, bson.ObjectId().binary],
233235
"Decimal128": [bson.Decimal128(str(i)).bid for i in range(2)],
@@ -241,9 +243,9 @@ def test_arrow_to_polars(self):
241243
self.assertEqual(len(arrow_table_in), res.raw_result["insertedCount"])
242244
df_out = find_polars_all(self.coll, query={}, schema=Schema(arrow_schema))
243245

244-
# Sanity check: compare with cast_away_extension_types_on_table
245-
arrow_cast = api._cast_away_extension_types_on_table(arrow_table_in)
246-
assert_frame_equal(df_out, pl.from_arrow(arrow_cast))
246+
# Sanity check: compare with _arrow_to_polars
247+
df_actual_output = api._arrow_to_polars(arrow_table_in)
248+
assert_frame_equal(df_out, df_actual_output)
247249

248250
def test_exceptions_for_unsupported_polar_types(self):
249251
"""Confirm exceptions thrown are expected.

0 commit comments

Comments
 (0)