Skip to content

PYTHON-4669 - Update Async GridFS APIs for Motor Compatibility #1821

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Sep 4, 2024
73 changes: 46 additions & 27 deletions gridfs/asynchronous/grid_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -1194,19 +1194,9 @@ def __setattr__(self, name: str, value: Any) -> None:
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than having __setattr__ always raise an error in async, we should allow it as long as the file is not closed and only raise an error if the file is closed:

                # All other attributes are part of the document in db.fs.files.
                # Store them to be sent to server on close() or if closed, send
                # them now.
                self._file[name] = value
                if self._closed:        
                    if _IS_SYNC:
                        self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})
                    else:
                        raise AttributeError(
                            "AsyncGridIn does not support __setattr__ after being closed(). Set the attribute before closing the file or use AsyncGridIn.set() instead")


async def set(self, name: str, value: Any) -> None:
# For properties of this instance like _buffer, or descriptors set on
# the class like filename, use regular __setattr__
if name in self.__dict__ or name in self.__class__.__dict__:
object.__setattr__(self, name, value)
else:
# All other attributes are part of the document in db.fs.files.
# Store them to be sent to server on close() or if closed, send
# them now.
self._file[name] = value
if self._closed:
await self._coll.files.update_one(
{"_id": self._file["_id"]}, {"$set": {name: value}}
)
self._file[name] = value
if self._closed:
await self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})

async def _flush_data(self, data: Any, force: bool = False) -> None:
"""Flush `data` to a chunk."""
Expand Down Expand Up @@ -1400,7 +1390,11 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Any:
return False


class AsyncGridOut(io.IOBase):
GRIDOUT_BASE_CLASS = io.IOBase if _IS_SYNC else object # type: Any


class AsyncGridOut(GRIDOUT_BASE_CLASS): # type: ignore

"""Class to read data out of GridFS."""

def __init__(
Expand Down Expand Up @@ -1460,6 +1454,8 @@ def __init__(
self._position = 0
self._file = file_document
self._session = session
if not _IS_SYNC:
self.closed = False

_id: Any = _a_grid_out_property("_id", "The ``'_id'`` value for this file.")
filename: str = _a_grid_out_property("filename", "Name of this file.")
Expand All @@ -1486,16 +1482,17 @@ def __init__(
_file: Any
_chunk_iter: Any

async def __anext__(self) -> bytes:
return super().__next__()
if not _IS_SYNC:
closed: bool

def __next__(self) -> bytes: # noqa: F811, RUF100
if _IS_SYNC:
return super().__next__()
else:
raise TypeError(
"AsyncGridOut does not support synchronous iteration. Use `async for` instead"
)
async def __anext__(self) -> bytes:
line = await self.readline()
if line:
return line
raise StopAsyncIteration()

async def to_list(self) -> list[bytes]:
return [x async for x in self] # noqa: C416, RUF100

async def open(self) -> None:
if not self._file:
Expand Down Expand Up @@ -1616,18 +1613,37 @@ async def read(self, size: int = -1) -> bytes:
"""
return await self._read_size_or_line(size=size)

async def readline(self, size: int = -1) -> bytes: # type: ignore[override]
async def readline(self, size: int = -1) -> bytes:
"""Read one line or up to `size` bytes from the file.

:param size: the maximum number of bytes to read
"""
return await self._read_size_or_line(size=size, line=True)

async def readlines(self, size: int = -1) -> list[bytes]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be async only too since we intentionally get it for free from IOBase in the sync version.

"""Read one line or up to `size` bytes from the file.

:param size: the maximum number of bytes to read
"""
await self.open()
lines = []
remainder = int(self.length) - self._position
bytes_read = 0
while remainder > 0:
line = await self._read_size_or_line(line=True)
bytes_read += len(line)
lines.append(line)
remainder = int(self.length) - self._position
if 0 < size < bytes_read:
break

return lines

def tell(self) -> int:
"""Return the current position of this file."""
return self._position

async def seek(self, pos: int, whence: int = _SEEK_SET) -> int: # type: ignore[override]
async def seek(self, pos: int, whence: int = _SEEK_SET) -> int:
"""Set the current position of this file.

:param pos: the position (or offset if using relative
Expand Down Expand Up @@ -1690,12 +1706,15 @@ def __aiter__(self) -> AsyncGridOut:
"""
return self

async def close(self) -> None: # type: ignore[override]
async def close(self) -> None:
"""Make GridOut more generically file-like."""
if self._chunk_iter:
await self._chunk_iter.close()
self._chunk_iter = None
super().close()
if _IS_SYNC:
super().close()
else:
self.closed = True

def write(self, value: Any) -> NoReturn:
raise io.UnsupportedOperation("write")
Expand Down
69 changes: 46 additions & 23 deletions gridfs/synchronous/grid_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -1184,17 +1184,9 @@ def __setattr__(self, name: str, value: Any) -> None:
)

def set(self, name: str, value: Any) -> None:
# For properties of this instance like _buffer, or descriptors set on
# the class like filename, use regular __setattr__
if name in self.__dict__ or name in self.__class__.__dict__:
object.__setattr__(self, name, value)
else:
# All other attributes are part of the document in db.fs.files.
# Store them to be sent to server on close() or if closed, send
# them now.
self._file[name] = value
if self._closed:
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})
self._file[name] = value
if self._closed:
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})

def _flush_data(self, data: Any, force: bool = False) -> None:
"""Flush `data` to a chunk."""
Expand Down Expand Up @@ -1388,7 +1380,11 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Any:
return False


class GridOut(io.IOBase):
GRIDOUT_BASE_CLASS = io.IOBase if _IS_SYNC else object # type: Any


class GridOut(GRIDOUT_BASE_CLASS): # type: ignore

"""Class to read data out of GridFS."""

def __init__(
Expand Down Expand Up @@ -1448,6 +1444,8 @@ def __init__(
self._position = 0
self._file = file_document
self._session = session
if not _IS_SYNC:
self.closed = False

_id: Any = _grid_out_property("_id", "The ``'_id'`` value for this file.")
filename: str = _grid_out_property("filename", "Name of this file.")
Expand All @@ -1474,14 +1472,17 @@ def __init__(
_file: Any
_chunk_iter: Any

def __next__(self) -> bytes:
return super().__next__()
if not _IS_SYNC:
closed: bool

def __next__(self) -> bytes: # noqa: F811, RUF100
if _IS_SYNC:
return super().__next__()
else:
raise TypeError("GridOut does not support synchronous iteration. Use `for` instead")
def __next__(self) -> bytes:
line = self.readline()
if line:
return line
raise StopIteration()

def to_list(self) -> list[bytes]:
return [x for x in self] # noqa: C416, RUF100

def open(self) -> None:
if not self._file:
Expand Down Expand Up @@ -1602,18 +1603,37 @@ def read(self, size: int = -1) -> bytes:
"""
return self._read_size_or_line(size=size)

def readline(self, size: int = -1) -> bytes: # type: ignore[override]
def readline(self, size: int = -1) -> bytes:
"""Read one line or up to `size` bytes from the file.

:param size: the maximum number of bytes to read
"""
return self._read_size_or_line(size=size, line=True)

def readlines(self, size: int = -1) -> list[bytes]:
"""Read one line or up to `size` bytes from the file.

:param size: the maximum number of bytes to read
"""
self.open()
lines = []
remainder = int(self.length) - self._position
bytes_read = 0
while remainder > 0:
line = self._read_size_or_line(line=True)
bytes_read += len(line)
lines.append(line)
remainder = int(self.length) - self._position
if 0 < size < bytes_read:
break

return lines

def tell(self) -> int:
"""Return the current position of this file."""
return self._position

def seek(self, pos: int, whence: int = _SEEK_SET) -> int: # type: ignore[override]
def seek(self, pos: int, whence: int = _SEEK_SET) -> int:
"""Set the current position of this file.

:param pos: the position (or offset if using relative
Expand Down Expand Up @@ -1676,12 +1696,15 @@ def __iter__(self) -> GridOut:
"""
return self

def close(self) -> None: # type: ignore[override]
def close(self) -> None:
"""Make GridOut more generically file-like."""
if self._chunk_iter:
self._chunk_iter.close()
self._chunk_iter = None
super().close()
if _IS_SYNC:
super().close()
else:
self.closed = True

def write(self, value: Any) -> NoReturn:
raise io.UnsupportedOperation("write")
Expand Down
5 changes: 5 additions & 0 deletions pymongo/asynchronous/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,13 @@ async def inner(*args: Any, **kwargs: Any) -> Any:

if sys.version_info >= (3, 10):
anext = builtins.anext
aiter = builtins.aiter
else:

async def anext(cls: Any) -> Any:
"""Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#anext."""
return await cls.__anext__()

def aiter(cls: Any) -> Any:
"""Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#anext."""
return cls.__aiter__()
5 changes: 5 additions & 0 deletions pymongo/synchronous/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,13 @@ def inner(*args: Any, **kwargs: Any) -> Any:

if sys.version_info >= (3, 10):
next = builtins.next
iter = builtins.iter
else:

def next(cls: Any) -> Any:
"""Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#next."""
return cls.__next__()

def iter(cls: Any) -> Any:
"""Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#next."""
return cls.__iter__()
2 changes: 1 addition & 1 deletion pymongo/synchronous/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def _process_change(
if server:
server.pool.reset(interrupt_connections=interrupt_connections)

# Wake waiters in select_servers().
# Wake witers in select_servers().
self._condition.notify_all()

def on_change(
Expand Down
4 changes: 2 additions & 2 deletions test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,11 +947,11 @@ def tearDownClass(cls):

@classmethod
def _setup_class(cls):
cls._setup_class()
pass

@classmethod
def _tearDown_class(cls):
cls._tearDown_class()
pass


class IntegrationTest(PyMongoTestCase):
Expand Down
4 changes: 2 additions & 2 deletions test/asynchronous/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,11 +949,11 @@ def tearDownClass(cls):

@classmethod
async def _setup_class(cls):
await cls._setup_class()
pass

@classmethod
async def _tearDown_class(cls):
await cls._tearDown_class()
pass


class AsyncIntegrationTest(AsyncPyMongoTestCase):
Expand Down
Loading
Loading