Skip to content

PYTHON-4590 - Add type guards to async API methods #1820

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pymongo/asynchronous/bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ async def write_command(
client: AsyncMongoClient,
) -> dict[str, Any]:
"""A proxy for SocketInfo.write_command that handles event publishing."""

cmd[bwc.field] = docs
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
Expand Down Expand Up @@ -324,6 +325,7 @@ async def unack_write(
client: AsyncMongoClient,
) -> Optional[Mapping[str, Any]]:
"""A proxy for AsyncConnection.unack_write that handles event publishing."""

if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
Expand Down
6 changes: 6 additions & 0 deletions pymongo/asynchronous/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,12 @@ def __init__(
)
if not isinstance(name, str):
raise TypeError("name must be an instance of str")
from pymongo.asynchronous.database import AsyncDatabase

if not isinstance(database, AsyncDatabase):
raise TypeError(
f"AsyncCollection requires an AsyncDatabase, {database} is an instance of {type(database)}"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should make the same {client} error message change to all the Database ones as well since Database's repr includes the client.

)

if not name or ".." in name:
raise InvalidName("collection names cannot be empty")
Expand Down
7 changes: 7 additions & 0 deletions pymongo/asynchronous/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,16 @@ def __init__(
read_concern or client.read_concern,
)

from pymongo.asynchronous.mongo_client import AsyncMongoClient

if not isinstance(name, str):
raise TypeError("name must be an instance of str")

if not isinstance(client, AsyncMongoClient):
raise TypeError(
f"AsyncMongoClient required but {client} is an instance of {type(client)}"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to include both the repr(client) and the type(client)? It could be safer to only add the type to avoid adding topology/hostname info into these errors.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I think type(client) makes more sense in this instance.

)

if name != "$external":
_check_name(name)

Expand Down
15 changes: 12 additions & 3 deletions pymongo/asynchronous/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,7 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
# Wrap I/O errors in PyMongo exceptions.
_raise_connection_failure((host, port), error)

async def collection_info(
self, database: AsyncDatabase[Mapping[str, Any]], filter: bytes
) -> Optional[bytes]:
async def collection_info(self, database: str, filter: bytes) -> Optional[bytes]:
"""Get the collection info for a namespace.

The returned collection info is passed to libmongocrypt which reads
Expand Down Expand Up @@ -339,6 +337,7 @@ def __init__(self, client: AsyncMongoClient[_DocumentTypeArg], opts: AutoEncrypt
:param client: The encrypted AsyncMongoClient.
:param opts: The encrypted client's :class:`AutoEncryptionOpts`.
"""

if opts._schema_map is None:
schema_map = None
else:
Expand Down Expand Up @@ -598,6 +597,11 @@ def __init__(
if not isinstance(codec_options, CodecOptions):
raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions")

if not isinstance(key_vault_client, AsyncMongoClient):
raise TypeError(
f"AsyncMongoClient required but {key_vault_client} is an instance of {type(key_vault_client)}"
)

self._kms_providers = kms_providers
self._key_vault_namespace = key_vault_namespace
self._key_vault_client = key_vault_client
Expand Down Expand Up @@ -683,6 +687,11 @@ async def create_encrypted_collection(
https://mongodb.com/docs/manual/reference/command/create

"""
if not isinstance(database, AsyncDatabase):
raise TypeError(
f"create_encrypted_collection() requires an AsyncDatabase, but {database} is an instance of {type(database)}"
)

encrypted_fields = deepcopy(encrypted_fields)
for i, field in enumerate(encrypted_fields["fields"]):
if isinstance(field, dict) and field.get("keyId") is None:
Expand Down
5 changes: 5 additions & 0 deletions pymongo/asynchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2419,6 +2419,11 @@ class _MongoClientErrorHandler:
def __init__(
self, client: AsyncMongoClient, server: Server, session: Optional[AsyncClientSession]
):
if not isinstance(client, AsyncMongoClient):
raise TypeError(
f"AsyncMongoClient required but {client} is an instance of {type(client)}"
)

self.client = client
self.server_address = server.description.address
self.session = session
Expand Down
2 changes: 2 additions & 0 deletions pymongo/asynchronous/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ async def run_operation(
:param unpack_res: A callable that decodes the wire protocol response.
:param client: An AsyncMongoClient instance.
"""

assert listeners is not None

publish = listeners.enabled_for_commands
start = datetime.now()

Expand Down
2 changes: 2 additions & 0 deletions pymongo/synchronous/bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ def write_command(
client: MongoClient,
) -> dict[str, Any]:
"""A proxy for SocketInfo.write_command that handles event publishing."""

cmd[bwc.field] = docs
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
Expand Down Expand Up @@ -324,6 +325,7 @@ def unack_write(
client: MongoClient,
) -> Optional[Mapping[str, Any]]:
"""A proxy for Connection.unack_write that handles event publishing."""

if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
Expand Down
6 changes: 6 additions & 0 deletions pymongo/synchronous/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,12 @@ def __init__(
)
if not isinstance(name, str):
raise TypeError("name must be an instance of str")
from pymongo.synchronous.database import Database

if not isinstance(database, Database):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm paranoid this might break apps that Mock pymongo classes but it should be fine. Anything that's mocking should inherit from our base classes anyway (instead of trying to ducktype).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, we'll have to watch out for bug reports with this.

raise TypeError(
f"Collection requires a Database, {database} is an instance of {type(database)}"
)

if not name or ".." in name:
raise InvalidName("collection names cannot be empty")
Expand Down
5 changes: 5 additions & 0 deletions pymongo/synchronous/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,14 @@ def __init__(
read_concern or client.read_concern,
)

from pymongo.synchronous.mongo_client import MongoClient

if not isinstance(name, str):
raise TypeError("name must be an instance of str")

if not isinstance(client, MongoClient):
raise TypeError(f"MongoClient required but {client} is an instance of {type(client)}")

if name != "$external":
_check_name(name)

Expand Down
15 changes: 12 additions & 3 deletions pymongo/synchronous/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,7 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
# Wrap I/O errors in PyMongo exceptions.
_raise_connection_failure((host, port), error)

def collection_info(
self, database: Database[Mapping[str, Any]], filter: bytes
) -> Optional[bytes]:
def collection_info(self, database: str, filter: bytes) -> Optional[bytes]:
"""Get the collection info for a namespace.

The returned collection info is passed to libmongocrypt which reads
Expand Down Expand Up @@ -337,6 +335,7 @@ def __init__(self, client: MongoClient[_DocumentTypeArg], opts: AutoEncryptionOp
:param client: The encrypted MongoClient.
:param opts: The encrypted client's :class:`AutoEncryptionOpts`.
"""

if opts._schema_map is None:
schema_map = None
else:
Expand Down Expand Up @@ -596,6 +595,11 @@ def __init__(
if not isinstance(codec_options, CodecOptions):
raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions")

if not isinstance(key_vault_client, MongoClient):
raise TypeError(
f"MongoClient required but {key_vault_client} is an instance of {type(key_vault_client)}"
)

self._kms_providers = kms_providers
self._key_vault_namespace = key_vault_namespace
self._key_vault_client = key_vault_client
Expand Down Expand Up @@ -681,6 +685,11 @@ def create_encrypted_collection(
https://mongodb.com/docs/manual/reference/command/create

"""
if not isinstance(database, Database):
raise TypeError(
f"create_encrypted_collection() requires a Database, but {database} is an instance of {type(database)}"
)

encrypted_fields = deepcopy(encrypted_fields)
for i, field in enumerate(encrypted_fields["fields"]):
if isinstance(field, dict) and field.get("keyId") is None:
Expand Down
3 changes: 3 additions & 0 deletions pymongo/synchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2406,6 +2406,9 @@ class _MongoClientErrorHandler:
)

def __init__(self, client: MongoClient, server: Server, session: Optional[ClientSession]):
if not isinstance(client, MongoClient):
raise TypeError(f"MongoClient required but {client} is an instance of {type(client)}")

self.client = client
self.server_address = server.description.address
self.session = session
Expand Down
2 changes: 2 additions & 0 deletions pymongo/synchronous/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ def run_operation(
:param unpack_res: A callable that decodes the wire protocol response.
:param client: A MongoClient instance.
"""

assert listeners is not None

publish = listeners.enabled_for_commands
start = datetime.now()

Expand Down
10 changes: 0 additions & 10 deletions test/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,13 @@
HAVE_IPADDRESS = True
except ImportError:
HAVE_IPADDRESS = False
from contextlib import contextmanager
from functools import wraps
from test.version import Version
from typing import Any, Callable, Dict, Generator, no_type_check
from unittest import SkipTest
from urllib.parse import quote_plus

import pymongo
import pymongo.errors
from bson.son import SON
from pymongo import common, message
from pymongo.common import partition_node
from pymongo.hello import HelloCompat
from pymongo.server_api import ServerApi
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
from pymongo.synchronous.database import Database
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.uri_parser import parse_uri

if HAVE_SSL:
Expand Down
Loading