Skip to content

Commit 5a49ccc

Browse files
authored
PYTHON-4590 - Add type guards to async API methods (mongodb#1820)
1 parent 5a70039 commit 5a49ccc

File tree

9 files changed

+42
-16
lines changed

9 files changed

+42
-16
lines changed

pymongo/asynchronous/collection.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,10 @@ def __init__(
228228
)
229229
if not isinstance(name, str):
230230
raise TypeError("name must be an instance of str")
231+
from pymongo.asynchronous.database import AsyncDatabase
232+
233+
if not isinstance(database, AsyncDatabase):
234+
raise TypeError(f"AsyncCollection requires an AsyncDatabase but {type(database)} given")
231235

232236
if not name or ".." in name:
233237
raise InvalidName("collection names cannot be empty")

pymongo/asynchronous/database.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,14 @@ def __init__(
119119
read_concern or client.read_concern,
120120
)
121121

122+
from pymongo.asynchronous.mongo_client import AsyncMongoClient
123+
122124
if not isinstance(name, str):
123125
raise TypeError("name must be an instance of str")
124126

127+
if not isinstance(client, AsyncMongoClient):
128+
raise TypeError(f"AsyncMongoClient required but given {type(client)}")
129+
125130
if name != "$external":
126131
_check_name(name)
127132

pymongo/asynchronous/encryption.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,9 +194,7 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
194194
# Wrap I/O errors in PyMongo exceptions.
195195
_raise_connection_failure((host, port), error)
196196

197-
async def collection_info(
198-
self, database: AsyncDatabase[Mapping[str, Any]], filter: bytes
199-
) -> Optional[bytes]:
197+
async def collection_info(self, database: str, filter: bytes) -> Optional[bytes]:
200198
"""Get the collection info for a namespace.
201199
202200
The returned collection info is passed to libmongocrypt which reads
@@ -598,6 +596,9 @@ def __init__(
598596
if not isinstance(codec_options, CodecOptions):
599597
raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions")
600598

599+
if not isinstance(key_vault_client, AsyncMongoClient):
600+
raise TypeError(f"AsyncMongoClient required but given {type(key_vault_client)}")
601+
601602
self._kms_providers = kms_providers
602603
self._key_vault_namespace = key_vault_namespace
603604
self._key_vault_client = key_vault_client
@@ -683,6 +684,11 @@ async def create_encrypted_collection(
683684
https://mongodb.com/docs/manual/reference/command/create
684685
685686
"""
687+
if not isinstance(database, AsyncDatabase):
688+
raise TypeError(
689+
f"create_encrypted_collection() requires an AsyncDatabase but {type(database)} given"
690+
)
691+
686692
encrypted_fields = deepcopy(encrypted_fields)
687693
for i, field in enumerate(encrypted_fields["fields"]):
688694
if isinstance(field, dict) and field.get("keyId") is None:

pymongo/asynchronous/mongo_client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2419,6 +2419,9 @@ class _MongoClientErrorHandler:
24192419
def __init__(
24202420
self, client: AsyncMongoClient, server: Server, session: Optional[AsyncClientSession]
24212421
):
2422+
if not isinstance(client, AsyncMongoClient):
2423+
raise TypeError(f"AsyncMongoClient required but given {type(client)}")
2424+
24222425
self.client = client
24232426
self.server_address = server.description.address
24242427
self.session = session

pymongo/synchronous/collection.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,10 @@ def __init__(
231231
)
232232
if not isinstance(name, str):
233233
raise TypeError("name must be an instance of str")
234+
from pymongo.synchronous.database import Database
235+
236+
if not isinstance(database, Database):
237+
raise TypeError(f"Collection requires a Database but {type(database)} given")
234238

235239
if not name or ".." in name:
236240
raise InvalidName("collection names cannot be empty")

pymongo/synchronous/database.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,14 @@ def __init__(
119119
read_concern or client.read_concern,
120120
)
121121

122+
from pymongo.synchronous.mongo_client import MongoClient
123+
122124
if not isinstance(name, str):
123125
raise TypeError("name must be an instance of str")
124126

127+
if not isinstance(client, MongoClient):
128+
raise TypeError(f"MongoClient required but given {type(client)}")
129+
125130
if name != "$external":
126131
_check_name(name)
127132

pymongo/synchronous/encryption.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,9 +194,7 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
194194
# Wrap I/O errors in PyMongo exceptions.
195195
_raise_connection_failure((host, port), error)
196196

197-
def collection_info(
198-
self, database: Database[Mapping[str, Any]], filter: bytes
199-
) -> Optional[bytes]:
197+
def collection_info(self, database: str, filter: bytes) -> Optional[bytes]:
200198
"""Get the collection info for a namespace.
201199
202200
The returned collection info is passed to libmongocrypt which reads
@@ -596,6 +594,9 @@ def __init__(
596594
if not isinstance(codec_options, CodecOptions):
597595
raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions")
598596

597+
if not isinstance(key_vault_client, MongoClient):
598+
raise TypeError(f"MongoClient required but given {type(key_vault_client)}")
599+
599600
self._kms_providers = kms_providers
600601
self._key_vault_namespace = key_vault_namespace
601602
self._key_vault_client = key_vault_client
@@ -681,6 +682,11 @@ def create_encrypted_collection(
681682
https://mongodb.com/docs/manual/reference/command/create
682683
683684
"""
685+
if not isinstance(database, Database):
686+
raise TypeError(
687+
f"create_encrypted_collection() requires a Database but {type(database)} given"
688+
)
689+
684690
encrypted_fields = deepcopy(encrypted_fields)
685691
for i, field in enumerate(encrypted_fields["fields"]):
686692
if isinstance(field, dict) and field.get("keyId") is None:

pymongo/synchronous/mongo_client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2406,6 +2406,9 @@ class _MongoClientErrorHandler:
24062406
)
24072407

24082408
def __init__(self, client: MongoClient, server: Server, session: Optional[ClientSession]):
2409+
if not isinstance(client, MongoClient):
2410+
raise TypeError(f"MongoClient required but given {type(client)}")
2411+
24092412
self.client = client
24102413
self.server_address = server.description.address
24112414
self.session = session

test/helpers.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,23 +35,13 @@
3535
HAVE_IPADDRESS = True
3636
except ImportError:
3737
HAVE_IPADDRESS = False
38-
from contextlib import contextmanager
3938
from functools import wraps
40-
from test.version import Version
4139
from typing import Any, Callable, Dict, Generator, no_type_check
4240
from unittest import SkipTest
43-
from urllib.parse import quote_plus
4441

45-
import pymongo
46-
import pymongo.errors
4742
from bson.son import SON
4843
from pymongo import common, message
49-
from pymongo.common import partition_node
50-
from pymongo.hello import HelloCompat
51-
from pymongo.server_api import ServerApi
5244
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
53-
from pymongo.synchronous.database import Database
54-
from pymongo.synchronous.mongo_client import MongoClient
5545
from pymongo.uri_parser import parse_uri
5646

5747
if HAVE_SSL:

0 commit comments

Comments
 (0)