Skip to content

Commit 8d24dd9

Browse files
committed
clean up reauth handling
1 parent b13427e commit 8d24dd9

File tree

8 files changed

+209
-49
lines changed

8 files changed

+209
-49
lines changed

pymongo/auth.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -589,8 +589,6 @@ def authenticate(credentials, sock_info, reauthenticate=False):
589589
"""Authenticate sock_info."""
590590
mechanism = credentials.mechanism
591591
auth_func = _AUTH_MAP[mechanism]
592-
if reauthenticate:
593-
sock_info.handle_reauthenticate()
594592
if mechanism == "MONGODB-OIDC":
595593
_authenticate_oidc(credentials, sock_info, reauthenticate)
596594
else:

pymongo/auth_oidc.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ class _OIDCAuthenticator:
100100

101101
def get_current_token(self, use_callbacks=True):
102102
properties = self.properties
103-
principal_name = self.username
104103

105104
request_cb = properties.request_token_callback
106105
refresh_cb = properties.refresh_token_callback
@@ -238,29 +237,26 @@ def clear(self):
238237

239238
def run_command(self, sock_info, cmd):
240239
try:
241-
return sock_info.command("$external", cmd)
240+
return sock_info.command("$external", cmd, no_reauth=True)
242241
except OperationFailure as exc:
243242
self.clear()
244243
if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
245244
if "jwt" in bson.decode(cmd["payload"]):
246245
if self.idp_info_gen_id > self.reauth_gen_id:
247246
raise
248-
self.handle_reauth(sock_info)
249-
return self.authenticate(sock_info)
247+
return self.authenticate(sock_info, reauthenticate=True)
250248
raise
251249

252-
def handle_reauth(self, sock_info):
253-
prev_id = getattr(sock_info, "oidc_token_gen_id", None)
254-
if prev_id != self.token_gen_id:
255-
# No need to preemptively clear, we've already changed tokens.
256-
return
257-
258-
self.reauth_gen_id = self.idp_info_gen_id
259-
self.token_exp_utc = None
260-
if not self.properties.refresh_token_callback:
261-
self.clear()
250+
def authenticate(self, sock_info, reauthenticate=False):
251+
if reauthenticate:
252+
prev_id = getattr(sock_info, "oidc_token_gen_id", None)
253+
# Check if we've already changed tokens.
254+
if prev_id == self.token_gen_id:
255+
self.reauth_gen_id = self.idp_info_gen_id
256+
self.token_exp_utc = None
257+
if not self.properties.refresh_token_callback:
258+
self.clear()
262259

263-
def authenticate(self, sock_info):
264260
ctx = sock_info.auth_ctx
265261
cmd = None
266262

@@ -300,6 +296,4 @@ def authenticate(self, sock_info):
300296
def _authenticate_oidc(credentials, sock_info, reauthenticate):
301297
"""Authenticate using MONGODB-OIDC."""
302298
authenticator = _get_authenticator(credentials, sock_info.address)
303-
if reauthenticate:
304-
authenticator.handle_reauth(sock_info)
305-
return authenticator.authenticate(sock_info)
299+
return authenticator.authenticate(sock_info, reauthenticate=reauthenticate)

pymongo/helpers.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,3 +270,35 @@ def _handle_exception():
270270
pass
271271
finally:
272272
del einfo
273+
274+
275+
def _handle_reauth(func):
276+
def inner(*args, **kwargs):
277+
no_reauth = kwargs.pop("no_reauth", False)
278+
from pymongo.pool import SocketInfo
279+
280+
try:
281+
return func(*args, **kwargs)
282+
except OperationFailure as exc:
283+
if no_reauth:
284+
raise
285+
if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
286+
# Look for an argument that either is a SocketInfo
287+
# or has a socket_info attribute, so we can trigger
288+
# a reauth.
289+
sock_info = None
290+
for arg in args:
291+
if isinstance(arg, SocketInfo):
292+
sock_info = arg
293+
break
294+
if hasattr(arg, "sock_info"):
295+
sock_info = arg.sock_info
296+
break
297+
if sock_info:
298+
sock_info.authenticate(reauthenticate=True)
299+
else:
300+
raise
301+
return func(*args, **kwargs)
302+
raise
303+
304+
return inner

pymongo/message.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
ProtocolError,
5555
)
5656
from pymongo.hello import HelloCompat
57+
from pymongo.helpers import _handle_reauth
5758
from pymongo.read_preferences import ReadPreference
5859
from pymongo.write_concern import WriteConcern
5960

@@ -909,6 +910,7 @@ def unack_write(self, cmd, request_id, msg, max_doc_size, docs):
909910
self.start_time = datetime.datetime.now()
910911
return result
911912

913+
@_handle_reauth
912914
def write_command(self, cmd, request_id, msg, docs):
913915
"""A proxy for SocketInfo.write_command that handles event publishing."""
914916
if self.publish:

pymongo/mongo_client.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1397,14 +1397,7 @@ def is_retrying():
13971397
assert last_error is not None
13981398
raise last_error
13991399
retryable = False
1400-
# Handle re-authentication.
1401-
try:
1402-
return func(session, sock_info, retryable)
1403-
except OperationFailure as exc:
1404-
if exc.code == helpers._REAUTHENTICATION_REQUIRED_CODE:
1405-
sock_info.authenticate(reauthenticate=True)
1406-
return func(session, sock_info, retryable)
1407-
raise
1400+
return func(session, sock_info, retryable)
14081401
except ServerSelectionTimeoutError:
14091402
if is_retrying():
14101403
# The application may think the write was never attempted
@@ -1468,14 +1461,7 @@ def _retryable_read(self, func, read_pref, session, address=None, retryable=True
14681461
# not support retryable reads, raise the last error.
14691462
assert last_error is not None
14701463
raise last_error
1471-
# Handle re-authentication.
1472-
try:
1473-
return func(session, server, sock_info, read_pref)
1474-
except OperationFailure as exc:
1475-
if exc.code == helpers._REAUTHENTICATION_REQUIRED_CODE:
1476-
sock_info.authenticate(reauthenticate=True)
1477-
return func(session, server, sock_info, read_pref)
1478-
raise
1464+
return func(session, server, sock_info, read_pref)
14791465
except ServerSelectionTimeoutError:
14801466
if retrying:
14811467
# The application may think the write was never attempted

pymongo/pool.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
_CertificateError,
5757
)
5858
from pymongo.hello import Hello, HelloCompat
59+
from pymongo.helpers import _handle_reauth
5960
from pymongo.lock import _create_lock
6061
from pymongo.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason
6162
from pymongo.network import command, receive_message
@@ -704,6 +705,7 @@ def _next_reply(self):
704705
helpers._check_command_response(response_doc, self.max_wire_version)
705706
return response_doc
706707

708+
@_handle_reauth
707709
def command(
708710
self,
709711
dbname,
@@ -788,7 +790,7 @@ def command(
788790
exhaust_allowed=exhaust_allowed,
789791
write_concern=write_concern,
790792
)
791-
except (OperationFailure, NotPrimaryError):
793+
except (OperationFailure, NotPrimaryError) as exc:
792794
raise
793795
# Catch socket.error, KeyboardInterrupt, etc. and close ourselves.
794796
except BaseException as error:
@@ -864,7 +866,12 @@ def authenticate(self, reauthenticate=False):
864866
"""
865867
# CMAP spec says to publish the ready event only after authenticating
866868
# the connection.
867-
if not self.ready or reauthenticate:
869+
if reauthenticate:
870+
if self.performed_handshake:
871+
# Existing auth_ctx is stale, remove it.
872+
self.auth_ctx = None
873+
self.ready = False
874+
if not self.ready:
868875
creds = self.opts._credentials
869876
if creds:
870877
auth.authenticate(creds, self, reauthenticate=reauthenticate)
@@ -927,12 +934,6 @@ def idle_time_seconds(self):
927934
"""Seconds since this socket was last checked into its pool."""
928935
return time.monotonic() - self.last_checkin_time
929936

930-
def handle_reauthenticate(self):
931-
"""Handle a reauthentication."""
932-
if self.performed_handshake:
933-
# Existing auth_ctx is stale, remove it.
934-
self.auth_ctx = None
935-
936937
def _raise_connection_failure(self, error):
937938
# Catch *all* exceptions from socket methods and close the socket. In
938939
# regular Python, socket operations only raise socket.error, even if

pymongo/server.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from bson import _decode_all_selective
2020
from pymongo.errors import NotPrimaryError, OperationFailure
21-
from pymongo.helpers import _check_command_response
21+
from pymongo.helpers import _check_command_response, _handle_reauth
2222
from pymongo.message import _convert_exception, _OpMsg
2323
from pymongo.response import PinnedResponse, Response
2424

@@ -73,6 +73,7 @@ def request_check(self):
7373
"""Check the server's state soon."""
7474
self._monitor.request_check()
7575

76+
@_handle_reauth
7677
def run_operation(self, sock_info, operation, read_preference, listeners, unpack_res):
7778
"""Run a _Query or _GetMore operation and return a Response object.
7879

test/auth_aws/test_auth_oidc.py

Lines changed: 150 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,10 @@
2828

2929
from bson import SON
3030
from pymongo import MongoClient
31-
from pymongo.auth import MongoCredential
3231
from pymongo.auth_oidc import _CACHE as _oidc_cache
32+
from pymongo.cursor import CursorType
3333
from pymongo.errors import ConfigurationError, OperationFailure
34+
from pymongo.operations import InsertOne
3435

3536

3637
class TestAuthOIDC(unittest.TestCase):
@@ -496,8 +497,8 @@ def test_reauthenticate_succeeds(self):
496497

497498
with self.fail_point(
498499
{
499-
"mode": {"times": 2},
500-
"data": {"failCommands": ["find", "saslStart"], "errorCode": 391},
500+
"mode": {"times": 1},
501+
"data": {"failCommands": ["find"], "errorCode": 391},
501502
}
502503
):
503504
# Perform a find operation.
@@ -529,7 +530,152 @@ def test_reauthenticate_succeeds(self):
529530
self.assertEqual(self.refresh_called, 1)
530531
client.close()
531532

532-
def test_reauthenticate_retries_and_succees_with_cache(self):
533+
def test_reauthenticate_succeeds_bulk_write(self):
534+
request_cb = self.create_request_cb()
535+
refresh_cb = self.create_refresh_cb()
536+
537+
# Create a client with the callbacks.
538+
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
539+
client = MongoClient(self.uri_single, authmechanismproperties=props)
540+
541+
# Perform a find operation.
542+
client.test.test.find_one()
543+
544+
# Assert that the refresh callback has not been called.
545+
self.assertEqual(self.refresh_called, 0)
546+
547+
with self.fail_point(
548+
{
549+
"mode": {"times": 1},
550+
"data": {"failCommands": ["insert"], "errorCode": 391},
551+
}
552+
):
553+
# Perform a bulk write operation.
554+
client.test.test.bulk_write([InsertOne({})])
555+
556+
# Assert that the refresh callback has been called.
557+
self.assertEqual(self.refresh_called, 1)
558+
client.close()
559+
560+
def test_reauthenticate_succeeds_bulk_read(self):
561+
request_cb = self.create_request_cb()
562+
refresh_cb = self.create_refresh_cb()
563+
564+
# Create a client with the callbacks.
565+
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
566+
client = MongoClient(self.uri_single, authmechanismproperties=props)
567+
568+
# Perform a find operation.
569+
client.test.test.find_one()
570+
571+
# Perform a bulk write operation.
572+
client.test.test.bulk_write([InsertOne({})])
573+
574+
# Assert that the refresh callback has not been called.
575+
self.assertEqual(self.refresh_called, 0)
576+
577+
with self.fail_point(
578+
{
579+
"mode": {"times": 1},
580+
"data": {"failCommands": ["find"], "errorCode": 391},
581+
}
582+
):
583+
# Perform a bulk read operation.
584+
cursor = client.test.test.find_raw_batches({})
585+
list(cursor)
586+
587+
# Assert that the refresh callback has been called.
588+
self.assertEqual(self.refresh_called, 1)
589+
client.close()
590+
591+
def test_reauthenticate_succeeds_cursor(self):
592+
request_cb = self.create_request_cb()
593+
refresh_cb = self.create_refresh_cb()
594+
595+
# Create a client with the callbacks.
596+
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
597+
client = MongoClient(self.uri_single, authmechanismproperties=props)
598+
599+
# Perform an insert operation.
600+
client.test.test.insert_one({"a": 1})
601+
602+
# Assert that the refresh callback has not been called.
603+
self.assertEqual(self.refresh_called, 0)
604+
605+
with self.fail_point(
606+
{
607+
"mode": {"times": 1},
608+
"data": {"failCommands": ["find"], "errorCode": 391},
609+
}
610+
):
611+
# Perform a find operation.
612+
cursor = client.test.test.find({"a": 1})
613+
self.assertGreaterEqual(len(list(cursor)), 1)
614+
615+
# Assert that the refresh callback has been called.
616+
self.assertEqual(self.refresh_called, 1)
617+
client.close()
618+
619+
def test_reauthenticate_succeeds_get_more(self):
620+
request_cb = self.create_request_cb()
621+
refresh_cb = self.create_refresh_cb()
622+
623+
# Create a client with the callbacks.
624+
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
625+
client = MongoClient(self.uri_single, authmechanismproperties=props)
626+
627+
# Perform an insert operation.
628+
client.test.test.insert_many([{"a": 1}, {"a": 1}])
629+
630+
# Assert that the refresh callback has not been called.
631+
self.assertEqual(self.refresh_called, 0)
632+
633+
with self.fail_point(
634+
{
635+
"mode": {"times": 1},
636+
"data": {"failCommands": ["find"], "errorCode": 391},
637+
}
638+
):
639+
# Perform a find operation.
640+
cursor = client.test.test.find({"a": 1}, batch_size=1, cursor_type=CursorType.EXHAUST)
641+
self.assertGreaterEqual(len(list(cursor)), 1)
642+
643+
# Assert that the refresh callback has been called.
644+
self.assertEqual(self.refresh_called, 1)
645+
client.close()
646+
647+
def test_reauthenticate_succeeds_command(self):
648+
request_cb = self.create_request_cb()
649+
refresh_cb = self.create_refresh_cb()
650+
651+
# Create a client with the callbacks.
652+
props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb)
653+
654+
print("start of test")
655+
client = MongoClient(self.uri_single, authmechanismproperties=props)
656+
657+
# Perform an insert operation.
658+
client.test.test.insert_one({"a": 1})
659+
660+
# Assert that the refresh callback has not been called.
661+
self.assertEqual(self.refresh_called, 0)
662+
663+
with self.fail_point(
664+
{
665+
"mode": {"times": 1},
666+
"data": {"failCommands": ["count"], "errorCode": 391},
667+
}
668+
):
669+
# Perform a count operation.
670+
cursor = client.test.command(dict(count="test"))
671+
672+
self.assertGreaterEqual(len(list(cursor)), 1)
673+
674+
# Assert that the refresh callback has been called.
675+
self.assertEqual(self.refresh_called, 1)
676+
client.close()
677+
678+
def test_reauthenticate_retries_and_succeeds_with_cache(self):
533679
listener = EventListener()
534680

535681
# Create request and refresh callbacks that return valid credentials

0 commit comments

Comments
 (0)