Skip to content

Commit 79c834f

Browse files
authored
ARROW-233 Add an optional bool flag to the write function to skip writing null fields (#213)
add tests
1 parent bc28d5b commit 79c834f

File tree

3 files changed

+47
-7
lines changed

3 files changed

+47
-7
lines changed

bindings/python/pymongoarrow/api.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -383,16 +383,22 @@ def _transform_bwe(bwe, offset):
383383
}
384384

385385

386-
def _tabular_generator(tabular):
386+
def _tabular_generator(tabular, *, exclude_none=False):
387387
if isinstance(tabular, Table):
388388
for i in tabular.to_batches():
389389
for row in i.to_pylist():
390-
yield row
390+
if exclude_none:
391+
yield {k: v for k, v in row.items() if v is not None}
392+
else:
393+
yield row
391394
elif isinstance(tabular, pd.DataFrame):
392395
for row in tabular.to_dict("records"):
393-
yield row
396+
if exclude_none:
397+
yield {k: v for k, v in row.items() if not np.isnan(v)}
398+
else:
399+
yield row
394400
elif pl is not None and isinstance(tabular, pl.DataFrame):
395-
yield from _tabular_generator(tabular.to_arrow())
401+
yield from _tabular_generator(tabular.to_arrow(), exclude_none=exclude_none)
396402
elif isinstance(tabular, dict):
397403
iter_dict = {k: np.nditer(v) for k, v in tabular.items()}
398404
try:
@@ -414,13 +420,14 @@ def transform_python(self, _):
414420
return
415421

416422

417-
def write(collection, tabular):
423+
def write(collection, tabular, *, exclude_none: bool = False):
418424
"""Write data from `tabular` into the given MongoDB `collection`.
419425
420426
:Parameters:
421427
- `collection`: Instance of :class:`~pymongo.collection.Collection`.
422428
against which to run the operation.
423429
- `tabular`: A tabular data store to use for the write operation.
430+
- `exclude_none`: Whether to skip writing `null` fields in documents.
424431
425432
:Returns:
426433
An instance of :class:`result.ArrowWriteResult`.
@@ -464,7 +471,7 @@ def write(collection, tabular):
464471
)
465472
raise ValueError(msg)
466473

467-
tabular_gen = _tabular_generator(tabular)
474+
tabular_gen = _tabular_generator(tabular, exclude_none=exclude_none)
468475

469476
# Handle Pandas NA objects.
470477
codec_options = collection.codec_options

bindings/python/test/test_arrow.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def test_find_with_session(self):
137137

138138
with self.client.start_session() as session:
139139
self.assertIsNone(session.operation_time)
140-
_ = session._server_session.last_use
140+
_ = getattr(session._server_session, "last_use", None)
141141
expected = Table.from_pydict(
142142
{"_id": [1, 2, 3, 4], "data": [10, 20, 30, None]},
143143
ArrowSchema([("_id", int32()), ("data", int64())]),
@@ -787,6 +787,27 @@ def test_binary_types(self):
787787
self.assertTrue(table_out_schema.schema == table_in.schema)
788788
self.assertTrue(table_out_none.equals(table_out_schema))
789789

790+
def test_exclude_none(self):
791+
schema = {"a": int32(), "b": int32()}
792+
b_data = [i for i in range(10)] * 2
793+
b_data[2] = None
794+
data = Table.from_pydict(
795+
{
796+
"a": [i for i in range(10)] * 2,
797+
"b": b_data,
798+
},
799+
ArrowSchema(schema),
800+
)
801+
self.coll.drop()
802+
write(self.coll, data)
803+
col_data = list(self.coll.find({}))
804+
assert "b" in col_data[2]
805+
806+
self.coll.drop()
807+
write(self.coll, data, exclude_none=True)
808+
col_data = list(self.coll.find({}))
809+
assert "b" not in col_data[2]
810+
790811

791812
class TestArrowExplicitApi(ArrowApiTestMixin, unittest.TestCase):
792813
def run_find(self, *args, **kwargs):

bindings/python/test/test_pandas.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,18 @@ def test_csv(self):
329329
out = pd.read_csv(f.name)
330330
self._assert_frames_equal(data, out)
331331

332+
def test_exclude_none(self):
333+
df = pd.DataFrame(data={"a": [1, 2, 3, 4], "b": [20, 40, 60, None]})
334+
self.coll.drop()
335+
write(self.coll, df)
336+
col_data = list(self.coll.find({}))
337+
assert "b" in col_data[3]
338+
339+
self.coll.drop()
340+
write(self.coll, df, exclude_none=True)
341+
col_data = list(self.coll.find({}))
342+
assert "b" not in col_data[3]
343+
332344

333345
class TestBSONTypes(PandasTestBase):
334346
@classmethod

0 commit comments

Comments
 (0)