Skip to content

Commit ce4c1d1

Browse files
authored
ARROW-241 Allow list in Schema as an alias for pa.list_ (#223)
1 parent 1c5cdd2 commit ce4c1d1

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

bindings/python/pymongoarrow/types.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -288,7 +288,10 @@ def _normalize_typeid(typeid, field_name):
288288
fields.append((sub_field_name, _normalize_typeid(sub_typeid, sub_field_name)))
289289
return struct(fields)
290290
if isinstance(typeid, list):
291-
return list_(_normalize_typeid(type(typeid[0]), "0"))
291+
if len(typeid) != 1:
292+
msg = f"list field in schema must contain exactly one element, not {len(typeid)}"
293+
raise ValueError(msg)
294+
return list_(_normalize_typeid(typeid[0], "0"))
292295
if _is_typeid_supported(typeid):
293296
normalizer = _TYPE_NORMALIZER_FACTORY[typeid]
294297
return normalizer(typeid)

bindings/python/test/test_schema.py

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
from datetime import datetime
1515
from unittest import TestCase
1616

17+
import pytest
1718
from bson import Binary, Code, Decimal128, Int64, ObjectId
1819
from pyarrow import Table, field, float64, int64, list_, struct, timestamp
1920
from pyarrow import schema as ArrowSchema
@@ -94,3 +95,22 @@ def test_list_of_list_projection(self):
9495
}
9596
)
9697
self.assertEqual(schema._get_projection(), {"_id": True, "list": {"a": True, "b": True}})
98+
99+
def test_py_list_projection(self):
100+
schema = Schema(
101+
{"_id": ObjectId, "list": [(struct([field("a", int64()), field("b", float64())]))]}
102+
)
103+
104+
self.assertEqual(schema._get_projection(), {"_id": True, "list": {"a": True, "b": True}})
105+
106+
def test_py_list_with_multiple_fields_raises(self):
107+
with pytest.raises(
108+
ValueError, match="list field in schema must contain exactly one element, not 2"
109+
):
110+
_ = Schema({"_id": ObjectId, "list": [([field("a", int64()), field("b", float64())])]})
111+
112+
def test_py_empty_list_raises(self):
113+
with pytest.raises(
114+
ValueError, match="list field in schema must contain exactly one element, not 0"
115+
):
116+
_ = Schema({"_id": ObjectId, "list": []})

0 commit comments

Comments
 (0)