Skip to content

Commit 67773ce

Browse files
committed
deal with optioanal async type
mypy understands `isinstance(obj, Awaitable)`, but not `isawaitable(obj)`
1 parent ff1fa6e commit 67773ce

File tree

2 files changed

+31
-13
lines changed

2 files changed

+31
-13
lines changed

jupyter_server/auth/identity.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import sys
1414
import uuid
1515
from dataclasses import asdict, dataclass
16-
from typing import TYPE_CHECKING, Any, cast
16+
from typing import TYPE_CHECKING, Any, Awaitable, cast
1717

1818
from tornado import web
1919
from traitlets import Bool, Dict, Type, Unicode, default
@@ -175,7 +175,7 @@ def _token_default(self):
175175

176176
need_token = Bool(True)
177177

178-
def get_user(self, handler: JupyterHandler) -> User | None:
178+
async def get_user(self, handler: JupyterHandler) -> User | None:
179179
"""Get the authenticated user for a request
180180
181181
Must return a :class:`.jupyter_server.auth.User`,
@@ -188,11 +188,18 @@ def get_user(self, handler: JupyterHandler) -> User | None:
188188
if getattr(handler, "_jupyter_current_user", None):
189189
# already authenticated
190190
return handler._jupyter_current_user
191-
token_user = self.get_user_token(handler)
192-
cookie_user = self.get_user_cookie(handler)
191+
_token_user: User | None | Awaitable[User | None] = self.get_user_token(handler)
192+
if isinstance(_token_user, Awaitable):
193+
_token_user = await _token_user
194+
token_user: User | None = _token_user # need second variable name to collapse type
195+
_cookie_user = self.get_user_cookie(handler)
196+
if isinstance(_cookie_user, Awaitable):
197+
_cookie_user = await _cookie_user
198+
cookie_user: User | None = _cookie_user
193199
# prefer token to cookie if both given,
194200
# because token is always explicit
195201
user = token_user or cookie_user
202+
196203
if token_user:
197204
# if token-authenticated, persist user_id in cookie
198205
# if it hasn't already been stored there
@@ -258,16 +265,16 @@ def set_login_cookie(self, handler: JupyterHandler, user: User) -> None:
258265
cookie_options.setdefault("path", handler.base_url)
259266
handler.set_secure_cookie(handler.cookie_name, self.user_to_cookie(user), **cookie_options)
260267

261-
def get_user_cookie(self, handler: JupyterHandler) -> User | None:
268+
def get_user_cookie(self, handler: JupyterHandler) -> User | None | Awaitable[User | None]:
262269
"""Get user from a cookie
263270
264271
Calls user_from_cookie to deserialize cookie value
265272
"""
266273
get_secure_cookie_kwargs = handler.settings.get("get_secure_cookie_kwargs", {})
267-
user_cookie = handler.get_secure_cookie(handler.cookie_name, **get_secure_cookie_kwargs)
268-
if not user_cookie:
274+
_user_cookie = handler.get_secure_cookie(handler.cookie_name, **get_secure_cookie_kwargs)
275+
if not _user_cookie:
269276
return None
270-
user_cookie = user_cookie.decode()
277+
user_cookie = _user_cookie.decode()
271278
# TODO: try/catch in case of change in config?
272279
try:
273280
return self.user_from_cookie(user_cookie)
@@ -296,7 +303,7 @@ def get_token(self, handler: JupyterHandler) -> str | None:
296303
user_token = m.group(2)
297304
return user_token
298305

299-
def get_user_token(self, handler: JupyterHandler) -> User | None:
306+
async def get_user_token(self, handler: JupyterHandler) -> User | None:
300307
"""Identify the user based on a token in the URL or Authorization header
301308
302309
Returns:
@@ -321,7 +328,10 @@ def get_user_token(self, handler: JupyterHandler) -> User | None:
321328
# token does not correspond to user-id,
322329
# which is stored in a cookie.
323330
# still check the cookie for the user id
324-
user = self.get_user_cookie(handler)
331+
_user = self.get_user_cookie(handler)
332+
if isinstance(_user, Awaitable):
333+
_user = await _user
334+
user: User | None = _user
325335
if user is None:
326336
user = self.generate_anonymous_user(handler)
327337
return user

jupyter_server/base/handlers.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Base Tornado handlers for the Jupyter server."""
22
# Copyright (c) Jupyter Development Team.
33
# Distributed under the terms of the Modified BSD License.
4+
from __future__ import annotations
5+
46
import datetime
57
import functools
68
import inspect
@@ -14,6 +16,7 @@
1416
import warnings
1517
from http.client import responses
1618
from http.cookies import Morsel
19+
from typing import TYPE_CHECKING, Awaitable
1720
from urllib.parse import urlparse
1821

1922
import prometheus_client
@@ -37,6 +40,9 @@
3740
urldecode_unix_socket_path,
3841
)
3942

43+
if TYPE_CHECKING:
44+
from jupyter_server.auth.identity import User
45+
4046
# -----------------------------------------------------------------------------
4147
# Top-level handlers
4248
# -----------------------------------------------------------------------------
@@ -581,6 +587,7 @@ async def prepare(self):
581587

582588
mod_obj = inspect.getmodule(self.get_current_user)
583589
assert mod_obj is not None
590+
user: User | None = None
584591

585592
if type(self.identity_provider) is IdentityProvider and mod_obj.__name__ != __name__:
586593
# check for overridden get_current_user + default IdentityProvider
@@ -594,10 +601,11 @@ async def prepare(self):
594601
)
595602
user = self.get_current_user()
596603
else:
597-
user = self.identity_provider.get_user(self)
598-
if inspect.isawaitable(user):
604+
_user = self.identity_provider.get_user(self)
605+
if isinstance(_user, Awaitable):
599606
# IdentityProvider.get_user _may_ be async
600-
user = await user
607+
_user = await _user
608+
user = _user
601609

602610
# self.current_user for tornado's @web.authenticated
603611
# self._jupyter_current_user for backward-compat in deprecated get_current_user calls

0 commit comments

Comments
 (0)