Skip to content

PYTHON-4590 - Add type guards to async API methods #1820

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pymongo/asynchronous/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,10 @@ def __init__(
)
if not isinstance(name, str):
raise TypeError("name must be an instance of str")
from pymongo.asynchronous.database import AsyncDatabase

if not isinstance(database, AsyncDatabase):
raise TypeError(f"AsyncCollection requires an AsyncDatabase but {type(database)} given")

if not name or ".." in name:
raise InvalidName("collection names cannot be empty")
Expand Down
5 changes: 5 additions & 0 deletions pymongo/asynchronous/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,14 @@ def __init__(
read_concern or client.read_concern,
)

from pymongo.asynchronous.mongo_client import AsyncMongoClient

if not isinstance(name, str):
raise TypeError("name must be an instance of str")

if not isinstance(client, AsyncMongoClient):
raise TypeError(f"AsyncMongoClient required but given {type(client)}")

if name != "$external":
_check_name(name)

Expand Down
12 changes: 9 additions & 3 deletions pymongo/asynchronous/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,7 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
# Wrap I/O errors in PyMongo exceptions.
_raise_connection_failure((host, port), error)

async def collection_info(
self, database: AsyncDatabase[Mapping[str, Any]], filter: bytes
) -> Optional[bytes]:
async def collection_info(self, database: str, filter: bytes) -> Optional[bytes]:
"""Get the collection info for a namespace.

The returned collection info is passed to libmongocrypt which reads
Expand Down Expand Up @@ -598,6 +596,9 @@ def __init__(
if not isinstance(codec_options, CodecOptions):
raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions")

if not isinstance(key_vault_client, AsyncMongoClient):
raise TypeError(f"AsyncMongoClient required but given {type(key_vault_client)}")

self._kms_providers = kms_providers
self._key_vault_namespace = key_vault_namespace
self._key_vault_client = key_vault_client
Expand Down Expand Up @@ -683,6 +684,11 @@ async def create_encrypted_collection(
https://mongodb.com/docs/manual/reference/command/create

"""
if not isinstance(database, AsyncDatabase):
raise TypeError(
f"create_encrypted_collection() requires an AsyncDatabase but {type(database)} given"
)

encrypted_fields = deepcopy(encrypted_fields)
for i, field in enumerate(encrypted_fields["fields"]):
if isinstance(field, dict) and field.get("keyId") is None:
Expand Down
3 changes: 3 additions & 0 deletions pymongo/asynchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2419,6 +2419,9 @@ class _MongoClientErrorHandler:
def __init__(
self, client: AsyncMongoClient, server: Server, session: Optional[AsyncClientSession]
):
if not isinstance(client, AsyncMongoClient):
raise TypeError(f"AsyncMongoClient required but given {type(client)}")

self.client = client
self.server_address = server.description.address
self.session = session
Expand Down
4 changes: 4 additions & 0 deletions pymongo/synchronous/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,10 @@ def __init__(
)
if not isinstance(name, str):
raise TypeError("name must be an instance of str")
from pymongo.synchronous.database import Database

if not isinstance(database, Database):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm paranoid this might break apps that Mock pymongo classes but it should be fine. Anything that's mocking should inherit from our base classes anyway (instead of trying to ducktype).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, we'll have to watch out for bug reports with this.

raise TypeError(f"Collection requires a Database but {type(database)} given")

if not name or ".." in name:
raise InvalidName("collection names cannot be empty")
Expand Down
5 changes: 5 additions & 0 deletions pymongo/synchronous/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,14 @@ def __init__(
read_concern or client.read_concern,
)

from pymongo.synchronous.mongo_client import MongoClient

if not isinstance(name, str):
raise TypeError("name must be an instance of str")

if not isinstance(client, MongoClient):
raise TypeError(f"MongoClient required but given {type(client)}")

if name != "$external":
_check_name(name)

Expand Down
12 changes: 9 additions & 3 deletions pymongo/synchronous/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,7 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
# Wrap I/O errors in PyMongo exceptions.
_raise_connection_failure((host, port), error)

def collection_info(
self, database: Database[Mapping[str, Any]], filter: bytes
) -> Optional[bytes]:
def collection_info(self, database: str, filter: bytes) -> Optional[bytes]:
"""Get the collection info for a namespace.

The returned collection info is passed to libmongocrypt which reads
Expand Down Expand Up @@ -596,6 +594,9 @@ def __init__(
if not isinstance(codec_options, CodecOptions):
raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions")

if not isinstance(key_vault_client, MongoClient):
raise TypeError(f"MongoClient required but given {type(key_vault_client)}")

self._kms_providers = kms_providers
self._key_vault_namespace = key_vault_namespace
self._key_vault_client = key_vault_client
Expand Down Expand Up @@ -681,6 +682,11 @@ def create_encrypted_collection(
https://mongodb.com/docs/manual/reference/command/create

"""
if not isinstance(database, Database):
raise TypeError(
f"create_encrypted_collection() requires a Database but {type(database)} given"
)

encrypted_fields = deepcopy(encrypted_fields)
for i, field in enumerate(encrypted_fields["fields"]):
if isinstance(field, dict) and field.get("keyId") is None:
Expand Down
3 changes: 3 additions & 0 deletions pymongo/synchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2406,6 +2406,9 @@ class _MongoClientErrorHandler:
)

def __init__(self, client: MongoClient, server: Server, session: Optional[ClientSession]):
if not isinstance(client, MongoClient):
raise TypeError(f"MongoClient required but given {type(client)}")

self.client = client
self.server_address = server.description.address
self.session = session
Expand Down
10 changes: 0 additions & 10 deletions test/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,13 @@
HAVE_IPADDRESS = True
except ImportError:
HAVE_IPADDRESS = False
from contextlib import contextmanager
from functools import wraps
from test.version import Version
from typing import Any, Callable, Dict, Generator, no_type_check
from unittest import SkipTest
from urllib.parse import quote_plus

import pymongo
import pymongo.errors
from bson.son import SON
from pymongo import common, message
from pymongo.common import partition_node
from pymongo.hello import HelloCompat
from pymongo.server_api import ServerApi
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
from pymongo.synchronous.database import Database
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.uri_parser import parse_uri

if HAVE_SSL:
Expand Down
Loading