Skip to content

Commit 7d7f233

Browse files
authored
INTPYTHON-549 Rework the projection logic to allow reading list-of-struct data structures using a specified schema (#285)
1 parent 577f38d commit 7d7f233

File tree

3 files changed

+57
-11
lines changed

3 files changed

+57
-11
lines changed

bindings/python/pymongoarrow/schema.py

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -68,19 +68,21 @@ def _normalize_mapping(mapping):
6868
def _get_projection(self):
6969
projection = {"_id": False}
7070
for fname, ftype in self.typemap.items():
71-
projection[fname] = self._get_field_projection_value(ftype)
71+
projection = self._get_field_projection_value(fname, ftype, projection)
7272
return projection
7373

74-
def _get_field_projection_value(self, ftype):
74+
def _get_field_projection_value(self, fname, ftype, projection):
7575
value = True
7676
if isinstance(ftype, pa.ListType):
77-
return self._get_field_projection_value(ftype.value_field.type)
77+
return self._get_field_projection_value(fname, ftype.value_field.type, projection)
7878
if isinstance(ftype, pa.StructType):
79-
projection = {}
8079
for nested_ftype in ftype:
81-
projection[nested_ftype.name] = True
82-
value = projection
83-
return value
80+
projection = self._get_field_projection_value(
81+
fname + "." + nested_ftype.name, nested_ftype.type, projection
82+
)
83+
return projection
84+
projection[fname] = value
85+
return projection
8486

8587
def __eq__(self, other):
8688
if isinstance(other, type(self)):

bindings/python/test/test_arrow.py

Lines changed: 44 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@
3232
date64,
3333
decimal256,
3434
field,
35+
float64,
3536
int32,
3637
int64,
3738
large_list,
@@ -415,6 +416,49 @@ def inner(i):
415416
raw_data["nested"] = [inner(i) for i in range(3)]
416417
return schema, Table.from_pydict(raw_data, ArrowSchema(schema))
417418

419+
def test_write_nested_schema_validation(self):
420+
raw_data = {
421+
"_id": [1, 2],
422+
"top": [
423+
{
424+
"middle": {
425+
"value": "string_1",
426+
"bottom": [
427+
{"event": datetime(2012, 1, 1), "value": 1.1},
428+
{"event": datetime(2014, 1, 1), "value": 1.2},
429+
],
430+
}
431+
},
432+
{
433+
"middle": {
434+
"value": "string_2",
435+
"bottom": [
436+
{"event": datetime(2013, 1, 1), "value": 1.2},
437+
{"event": datetime(2019, 1, 1), "value": 1.5},
438+
],
439+
}
440+
},
441+
],
442+
}
443+
444+
schema = {
445+
"_id": int64(),
446+
"top": struct(
447+
{
448+
"middle": struct(
449+
{
450+
"value": string(),
451+
"bottom": list_(struct({"event": timestamp("ms"), "value": float64()})),
452+
}
453+
)
454+
}
455+
),
456+
}
457+
458+
data = Table.from_pydict(raw_data, ArrowSchema(schema))
459+
460+
self.round_trip(data, Schema(schema))
461+
418462
def test_parquet(self):
419463
schema, data = self._create_nested_data()
420464
with tempfile.NamedTemporaryFile(suffix=".parquet") as f:

bindings/python/test/test_schema.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -76,7 +76,7 @@ def test_from_arrow_units(self):
7676

7777
def test_nested_projection(self):
7878
schema = Schema({"_id": int64(), "obj": {"a": int64(), "b": int64()}})
79-
self.assertEqual(schema._get_projection(), {"_id": True, "obj": {"a": True, "b": True}})
79+
self.assertEqual(schema._get_projection(), {"_id": True, "obj.a": True, "obj.b": True})
8080

8181
def test_list_projection(self):
8282
schema = Schema(
@@ -85,7 +85,7 @@ def test_list_projection(self):
8585
"list": list_(struct([field("a", int64()), field("b", int64())])),
8686
}
8787
)
88-
self.assertEqual(schema._get_projection(), {"_id": True, "list": {"a": True, "b": True}})
88+
self.assertEqual(schema._get_projection(), {"_id": True, "list.a": True, "list.b": True})
8989

9090
def test_list_of_list_projection(self):
9191
schema = Schema(
@@ -94,14 +94,14 @@ def test_list_of_list_projection(self):
9494
"list": list_(list_(struct([field("a", int64()), field("b", int64())]))),
9595
}
9696
)
97-
self.assertEqual(schema._get_projection(), {"_id": True, "list": {"a": True, "b": True}})
97+
self.assertEqual(schema._get_projection(), {"_id": True, "list.a": True, "list.b": True})
9898

9999
def test_py_list_projection(self):
100100
schema = Schema(
101101
{"_id": ObjectId, "list": [(struct([field("a", int64()), field("b", float64())]))]}
102102
)
103103

104-
self.assertEqual(schema._get_projection(), {"_id": True, "list": {"a": True, "b": True}})
104+
self.assertEqual(schema._get_projection(), {"_id": True, "list.a": True, "list.b": True})
105105

106106
def test_py_list_with_multiple_fields_raises(self):
107107
with pytest.raises(

0 commit comments

Comments
 (0)