Skip to content

Commit c313064

Browse files
committed
Add async test_grid_file
1 parent 66e104c commit c313064

File tree

6 files changed

+148
-59
lines changed

6 files changed

+148
-59
lines changed

gridfs/asynchronous/grid_file.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1454,6 +1454,8 @@ def __init__(
14541454
self._position = 0
14551455
self._file = file_document
14561456
self._session = session
1457+
if not _IS_SYNC:
1458+
self.closed = False
14571459

14581460
_id: Any = _a_grid_out_property("_id", "The ``'_id'`` value for this file.")
14591461
filename: str = _a_grid_out_property("filename", "Name of this file.")
@@ -1481,9 +1483,16 @@ def __init__(
14811483
_chunk_iter: Any
14821484

14831485
if not _IS_SYNC:
1486+
closed: bool
14841487

14851488
async def __anext__(self) -> bytes:
1486-
return await self.readline()
1489+
line = await self.readline()
1490+
if line:
1491+
return line
1492+
raise StopAsyncIteration()
1493+
1494+
async def to_list(self) -> list[bytes]:
1495+
return [x async for x in self] # noqa: C416, RUF100
14871496

14881497
async def open(self) -> None:
14891498
if not self._file:
@@ -1611,6 +1620,25 @@ async def readline(self, size: int = -1) -> bytes:
16111620
"""
16121621
return await self._read_size_or_line(size=size, line=True)
16131622

1623+
async def readlines(self, size: int = -1) -> list[bytes]:
1624+
"""Read one line or up to `size` bytes from the file.
1625+
1626+
:param size: the maximum number of bytes to read
1627+
"""
1628+
await self.open()
1629+
lines = []
1630+
remainder = int(self.length) - self._position
1631+
bytes_read = 0
1632+
while remainder > 0:
1633+
line = await self._read_size_or_line(line=True)
1634+
bytes_read += len(line)
1635+
lines.append(line)
1636+
remainder = int(self.length) - self._position
1637+
if 0 < size < bytes_read:
1638+
break
1639+
1640+
return lines
1641+
16141642
def tell(self) -> int:
16151643
"""Return the current position of this file."""
16161644
return self._position
@@ -1685,6 +1713,8 @@ async def close(self) -> None:
16851713
self._chunk_iter = None
16861714
if _IS_SYNC:
16871715
super().close()
1716+
else:
1717+
self.closed = True
16881718

16891719
def write(self, value: Any) -> NoReturn:
16901720
raise io.UnsupportedOperation("write")

gridfs/synchronous/grid_file.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1444,6 +1444,8 @@ def __init__(
14441444
self._position = 0
14451445
self._file = file_document
14461446
self._session = session
1447+
if not _IS_SYNC:
1448+
self.closed = False
14471449

14481450
_id: Any = _grid_out_property("_id", "The ``'_id'`` value for this file.")
14491451
filename: str = _grid_out_property("filename", "Name of this file.")
@@ -1471,9 +1473,16 @@ def __init__(
14711473
_chunk_iter: Any
14721474

14731475
if not _IS_SYNC:
1476+
closed: bool
14741477

14751478
def __next__(self) -> bytes:
1476-
return self.readline()
1479+
line = self.readline()
1480+
if line:
1481+
return line
1482+
raise StopIteration()
1483+
1484+
def to_list(self) -> list[bytes]:
1485+
return [x for x in self] # noqa: C416, RUF100
14771486

14781487
def open(self) -> None:
14791488
if not self._file:
@@ -1601,6 +1610,25 @@ def readline(self, size: int = -1) -> bytes:
16011610
"""
16021611
return self._read_size_or_line(size=size, line=True)
16031612

1613+
def readlines(self, size: int = -1) -> list[bytes]:
1614+
"""Read one line or up to `size` bytes from the file.
1615+
1616+
:param size: the maximum number of bytes to read
1617+
"""
1618+
self.open()
1619+
lines = []
1620+
remainder = int(self.length) - self._position
1621+
bytes_read = 0
1622+
while remainder > 0:
1623+
line = self._read_size_or_line(line=True)
1624+
bytes_read += len(line)
1625+
lines.append(line)
1626+
remainder = int(self.length) - self._position
1627+
if 0 < size < bytes_read:
1628+
break
1629+
1630+
return lines
1631+
16041632
def tell(self) -> int:
16051633
"""Return the current position of this file."""
16061634
return self._position
@@ -1675,6 +1703,8 @@ def close(self) -> None:
16751703
self._chunk_iter = None
16761704
if _IS_SYNC:
16771705
super().close()
1706+
else:
1707+
self.closed = True
16781708

16791709
def write(self, value: Any) -> NoReturn:
16801710
raise io.UnsupportedOperation("write")

pymongo/synchronous/topology.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ def _process_change(
521521
if server:
522522
server.pool.reset(interrupt_connections=interrupt_connections)
523523

524-
# Wake waiters in select_servers().
524+
# Wake witers in select_servers().
525525
self._condition.notify_all()
526526

527527
def on_change(

test/asynchronous/test_grid_file.py

Lines changed: 42 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -584,33 +584,47 @@ async def test_readlines(self):
584584
self.assertEqual([b"Hope all is well.\n"], await g.readlines(17))
585585
self.assertEqual(b"Bye", await g.readline())
586586

587-
# async def test_iterator(self):
588-
# f = AsyncGridIn(self.db.fs)
589-
# await f.close()
590-
# g = AsyncGridOut(self.db.fs, f._id)
591-
# self.assertEqual([], list(g))
592-
#
593-
# f = AsyncGridIn(self.db.fs)
594-
# await f.write(b"hello world\nhere are\nsome lines.")
595-
# await f.close()
596-
# g = AsyncGridOut(self.db.fs, f._id)
597-
# self.assertEqual([b"hello world\n", b"here are\n", b"some lines."], list(g))
598-
# self.assertEqual(b"", await g.read(5))
599-
# self.assertEqual([], list(g))
600-
#
601-
# g = AsyncGridOut(self.db.fs, f._id)
602-
# self.assertEqual(b"hello world\n", next(iter(g)))
603-
# self.assertEqual(b"here", await g.read(4))
604-
# self.assertEqual(b" are\n", next(iter(g)))
605-
# self.assertEqual(b"some lines", await g.read(10))
606-
# self.assertEqual(b".", next(iter(g)))
607-
# self.assertRaises(StopIteration, iter(g).__next__)
608-
#
609-
# f = AsyncGridIn(self.db.fs, chunk_size=2)
610-
# await f.write(b"hello world")
611-
# await f.close()
612-
# g = AsyncGridOut(self.db.fs, f._id)
613-
# self.assertEqual([b"hello world"], list(g))
587+
async def test_iterator(self):
588+
f = AsyncGridIn(self.db.fs)
589+
await f.close()
590+
g = AsyncGridOut(self.db.fs, f._id)
591+
if _IS_SYNC:
592+
self.assertEqual([], list(g))
593+
else:
594+
self.assertEqual([], await g.to_list())
595+
596+
f = AsyncGridIn(self.db.fs)
597+
await f.write(b"hello world\nhere are\nsome lines.")
598+
await f.close()
599+
g = AsyncGridOut(self.db.fs, f._id)
600+
if _IS_SYNC:
601+
self.assertEqual([b"hello world\n", b"here are\n", b"some lines."], list(g))
602+
else:
603+
self.assertEqual([b"hello world\n", b"here are\n", b"some lines."], await g.to_list())
604+
605+
self.assertEqual(b"", await g.read(5))
606+
if _IS_SYNC:
607+
self.assertEqual([], list(g))
608+
else:
609+
self.assertEqual([], await g.to_list())
610+
611+
g = AsyncGridOut(self.db.fs, f._id)
612+
self.assertEqual(b"hello world\n", await anext(aiter(g)))
613+
self.assertEqual(b"here", await g.read(4))
614+
self.assertEqual(b" are\n", await anext(aiter(g)))
615+
self.assertEqual(b"some lines", await g.read(10))
616+
self.assertEqual(b".", await anext(aiter(g)))
617+
with self.assertRaises(StopAsyncIteration):
618+
await aiter(g).__anext__()
619+
620+
f = AsyncGridIn(self.db.fs, chunk_size=2)
621+
await f.write(b"hello world")
622+
await f.close()
623+
g = AsyncGridOut(self.db.fs, f._id)
624+
if _IS_SYNC:
625+
self.assertEqual([b"hello world"], list(g))
626+
else:
627+
self.assertEqual([b"hello world"], await g.to_list())
614628

615629
async def test_read_unaligned_buffer_size(self):
616630
in_data = b"This is a text that doesn't quite fit in a single 16-byte chunk."
@@ -811,7 +825,7 @@ async def test_survive_cursor_not_found(self):
811825
assert await client.address is not None
812826
await client._close_cursor_now(
813827
outfile._chunk_iter._cursor.cursor_id,
814-
_CursorAddress(await client.address, db.fs.chunks.full_name),
828+
_CursorAddress(await client.address, db.fs.chunks.full_name), # type: ignore[arg-type]
815829
)
816830

817831
# Read the rest of the file without error.

test/test_grid_file.py

Lines changed: 42 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -582,33 +582,47 @@ def test_readlines(self):
582582
self.assertEqual([b"Hope all is well.\n"], g.readlines(17))
583583
self.assertEqual(b"Bye", g.readline())
584584

585-
# def test_iterator(self):
586-
# f = GridIn(self.db.fs)
587-
# f.close()
588-
# g = GridOut(self.db.fs, f._id)
589-
# self.assertEqual([], list(g))
590-
#
591-
# f = GridIn(self.db.fs)
592-
# f.write(b"hello world\nhere are\nsome lines.")
593-
# f.close()
594-
# g = GridOut(self.db.fs, f._id)
595-
# self.assertEqual([b"hello world\n", b"here are\n", b"some lines."], list(g))
596-
# self.assertEqual(b"", g.read(5))
597-
# self.assertEqual([], list(g))
598-
#
599-
# g = GridOut(self.db.fs, f._id)
600-
# self.assertEqual(b"hello world\n", next(iter(g)))
601-
# self.assertEqual(b"here", g.read(4))
602-
# self.assertEqual(b" are\n", next(iter(g)))
603-
# self.assertEqual(b"some lines", g.read(10))
604-
# self.assertEqual(b".", next(iter(g)))
605-
# self.assertRaises(StopIteration, iter(g).__next__)
606-
#
607-
# f = GridIn(self.db.fs, chunk_size=2)
608-
# f.write(b"hello world")
609-
# f.close()
610-
# g = GridOut(self.db.fs, f._id)
611-
# self.assertEqual([b"hello world"], list(g))
585+
def test_iterator(self):
586+
f = GridIn(self.db.fs)
587+
f.close()
588+
g = GridOut(self.db.fs, f._id)
589+
if _IS_SYNC:
590+
self.assertEqual([], list(g))
591+
else:
592+
self.assertEqual([], g.to_list())
593+
594+
f = GridIn(self.db.fs)
595+
f.write(b"hello world\nhere are\nsome lines.")
596+
f.close()
597+
g = GridOut(self.db.fs, f._id)
598+
if _IS_SYNC:
599+
self.assertEqual([b"hello world\n", b"here are\n", b"some lines."], list(g))
600+
else:
601+
self.assertEqual([b"hello world\n", b"here are\n", b"some lines."], g.to_list())
602+
603+
self.assertEqual(b"", g.read(5))
604+
if _IS_SYNC:
605+
self.assertEqual([], list(g))
606+
else:
607+
self.assertEqual([], g.to_list())
608+
609+
g = GridOut(self.db.fs, f._id)
610+
self.assertEqual(b"hello world\n", next(iter(g)))
611+
self.assertEqual(b"here", g.read(4))
612+
self.assertEqual(b" are\n", next(iter(g)))
613+
self.assertEqual(b"some lines", g.read(10))
614+
self.assertEqual(b".", next(iter(g)))
615+
with self.assertRaises(StopIteration):
616+
iter(g).__next__()
617+
618+
f = GridIn(self.db.fs, chunk_size=2)
619+
f.write(b"hello world")
620+
f.close()
621+
g = GridOut(self.db.fs, f._id)
622+
if _IS_SYNC:
623+
self.assertEqual([b"hello world"], list(g))
624+
else:
625+
self.assertEqual([b"hello world"], g.to_list())
612626

613627
def test_read_unaligned_buffer_size(self):
614628
in_data = b"This is a text that doesn't quite fit in a single 16-byte chunk."
@@ -809,7 +823,7 @@ def test_survive_cursor_not_found(self):
809823
assert client.address is not None
810824
client._close_cursor_now(
811825
outfile._chunk_iter._cursor.cursor_id,
812-
_CursorAddress(client.address, db.fs.chunks.full_name),
826+
_CursorAddress(client.address, db.fs.chunks.full_name), # type: ignore[arg-type]
813827
)
814828

815829
# Read the rest of the file without error.

tools/synchro.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
"asynchronous": "synchronous",
4848
"Asynchronous": "Synchronous",
4949
"anext": "next",
50+
"aiter": "iter",
5051
"_ALock": "_Lock",
5152
"_ACondition": "_Condition",
5253
"AsyncGridFS": "GridFS",

0 commit comments

Comments
 (0)