Skip to content

Commit 9a2708e

Browse files
authored
Accept and manage cookies when requesting gateways (#969)
* Add support for stickiness cookies on load balancers * Add case * Simplify arguments * Add tests for arguments * Fix according to comments
1 parent 4c7bbfa commit 9a2708e

File tree

2 files changed

+139
-1
lines changed

2 files changed

+139
-1
lines changed

jupyter_server/gateway/gateway_client.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
import logging
66
import os
77
import typing as ty
8+
from datetime import datetime
9+
from email.utils import parsedate_to_datetime
10+
from http.cookies import Morsel, SimpleCookie
811
from socket import gaierror
912

1013
from tornado import web
@@ -276,6 +279,9 @@ def __init__(self, **kwargs):
276279
super().__init__(**kwargs)
277280
self._static_args = {} # initialized on first use
278281

282+
# store of cookies with store time
283+
self._cookies = {} # type: ty.Dict[str, ty.Tuple[Morsel, datetime]]
284+
279285
env_whitelist_default_value = ""
280286
env_whitelist_env = "JUPYTER_GATEWAY_ENV_WHITELIST"
281287
env_whitelist = Unicode(
@@ -363,6 +369,23 @@ def launch_timeout_pad_default(self):
363369
)
364370
)
365371

372+
accept_cookies_value = False
373+
accept_cookies_env = "JUPYTER_GATEWAY_ACCEPT_COOKIES"
374+
accept_cookies = Bool(
375+
default_value=accept_cookies_value,
376+
config=True,
377+
help="""Accept and manage cookies sent by the service side. This is often useful
378+
for load balancers to decide which backend node to use.
379+
(JUPYTER_GATEWAY_ACCEPT_COOKIES env var)""",
380+
)
381+
382+
@default("accept_cookies")
383+
def accept_cookies_default(self):
384+
return bool(
385+
os.environ.get(self.accept_cookies_env, str(self.accept_cookies_value).lower())
386+
not in ["no", "false"]
387+
)
388+
366389
@property
367390
def gateway_enabled(self):
368391
return bool(self.url is not None and len(self.url) > 0)
@@ -424,8 +447,65 @@ def load_connection_args(self, **kwargs):
424447
else:
425448
kwargs[arg] = static_value
426449

450+
if self.accept_cookies:
451+
self._update_cookie_header(kwargs)
452+
427453
return kwargs
428454

455+
def update_cookies(self, cookie: SimpleCookie) -> None:
456+
"""Update cookies from existing requests for load balancers"""
457+
if not self.accept_cookies:
458+
return
459+
460+
store_time = datetime.now()
461+
for key, item in cookie.items():
462+
# Convert "expires" arg into "max-age" to facilitate expiration management.
463+
# As "max-age" has precedence, ignore "expires" when "max-age" exists.
464+
if item.get("expires") and not item.get("max-age"):
465+
expire_timedelta = parsedate_to_datetime(item["expires"]) - store_time
466+
item["max-age"] = str(expire_timedelta.total_seconds())
467+
468+
self._cookies[key] = (item, store_time)
469+
470+
def _clear_expired_cookies(self) -> None:
471+
check_time = datetime.now()
472+
expired_keys = []
473+
474+
for key, (morsel, store_time) in self._cookies.items():
475+
cookie_max_age = morsel.get("max-age")
476+
if not cookie_max_age:
477+
continue
478+
expired_timedelta = check_time - store_time
479+
if expired_timedelta.total_seconds() > float(cookie_max_age):
480+
expired_keys.append(key)
481+
482+
for key in expired_keys:
483+
self._cookies.pop(key)
484+
485+
def _update_cookie_header(self, connection_args: dict) -> None:
486+
self._clear_expired_cookies()
487+
488+
gateway_cookie_values = "; ".join(
489+
f"{name}={morsel.coded_value}" for name, (morsel, _time) in self._cookies.items()
490+
)
491+
if gateway_cookie_values:
492+
headers = connection_args.get("headers", {})
493+
494+
# As headers are case-insensitive, we get existing name of cookie header,
495+
# or use "Cookie" by default.
496+
cookie_header_name = next(
497+
(header_key for header_key in headers if header_key.lower() == "cookie"),
498+
"Cookie",
499+
)
500+
existing_cookie = headers.get(cookie_header_name)
501+
502+
# merge gateway-managed cookies with cookies already in arguments
503+
if existing_cookie:
504+
gateway_cookie_values = existing_cookie + "; " + gateway_cookie_values
505+
headers[cookie_header_name] = gateway_cookie_values
506+
507+
connection_args["headers"] = headers
508+
429509

430510
class RetryableHTTPClient:
431511
"""
@@ -524,4 +604,11 @@ async def gateway_request(endpoint: str, **kwargs: ty.Any) -> HTTPResponse:
524604
f"appear to be valid. Ensure gateway url is valid and the Gateway instance is running.",
525605
) from e
526606

607+
if GatewayClient.instance().accept_cookies:
608+
# Update cookies on GatewayClient from server if configured.
609+
cookie_values = response.headers.get("Set-Cookie")
610+
if cookie_values:
611+
cookie: SimpleCookie = SimpleCookie()
612+
cookie.load(cookie_values)
613+
GatewayClient.instance().update_cookies(cookie)
527614
return response

tests/test_gateway.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
import logging
55
import os
66
import uuid
7-
from datetime import datetime
7+
from datetime import datetime, timedelta
8+
from email.utils import format_datetime
9+
from http.cookies import SimpleCookie
810
from io import BytesIO
911
from queue import Empty
1012
from unittest.mock import MagicMock, patch
@@ -187,6 +189,7 @@ def init_gateway(monkeypatch):
187189
monkeypatch.setenv("JUPYTER_GATEWAY_REQUEST_TIMEOUT", "44.4")
188190
monkeypatch.setenv("JUPYTER_GATEWAY_CONNECT_TIMEOUT", "44.4")
189191
monkeypatch.setenv("JUPYTER_GATEWAY_LAUNCH_TIMEOUT_PAD", "1.1")
192+
monkeypatch.setenv("JUPYTER_GATEWAY_ACCEPT_COOKIES", "false")
190193
yield
191194
GatewayClient.clear_instance()
192195

@@ -200,6 +203,7 @@ async def test_gateway_env_options(init_gateway, jp_serverapp):
200203
)
201204
assert jp_serverapp.gateway_config.connect_timeout == 44.4
202205
assert jp_serverapp.gateway_config.launch_timeout_pad == 1.1
206+
assert jp_serverapp.gateway_config.accept_cookies is False
203207

204208
GatewayClient.instance().init_static_args()
205209
assert GatewayClient.instance().KERNEL_LAUNCH_TIMEOUT == 43
@@ -259,6 +263,53 @@ async def test_gateway_request_timeout_pad_option(
259263
GatewayClient.clear_instance()
260264

261265

266+
cookie_expire_time = format_datetime(datetime.now() + timedelta(seconds=180))
267+
268+
269+
@pytest.mark.parametrize(
270+
"accept_cookies,expire_arg,expire_param,existing_cookies,cookie_exists",
271+
[
272+
(False, None, None, "EXISTING=1", False),
273+
(True, None, None, "EXISTING=1", True),
274+
(True, "Expires", cookie_expire_time, None, True),
275+
(True, "Max-Age", "-360", "EXISTING=1", False),
276+
],
277+
)
278+
async def test_gateway_request_with_expiring_cookies(
279+
jp_configurable_serverapp,
280+
accept_cookies,
281+
expire_arg,
282+
expire_param,
283+
existing_cookies,
284+
cookie_exists,
285+
):
286+
argv = [f"--GatewayClient.accept_cookies={accept_cookies}"]
287+
288+
GatewayClient.clear_instance()
289+
jp_configurable_serverapp(argv=argv)
290+
291+
cookie: SimpleCookie = SimpleCookie()
292+
cookie.load("SERVERID=1234567; Path=/")
293+
if expire_arg:
294+
cookie["SERVERID"][expire_arg] = expire_param
295+
296+
GatewayClient.instance().update_cookies(cookie)
297+
298+
args = {}
299+
if existing_cookies:
300+
args["headers"] = {"Cookie": existing_cookies}
301+
connection_args = GatewayClient.instance().load_connection_args(**args)
302+
303+
if not cookie_exists:
304+
assert "SERVERID" not in (connection_args["headers"].get("Cookie") or "")
305+
else:
306+
assert "SERVERID" in connection_args["headers"].get("Cookie")
307+
if existing_cookies:
308+
assert "EXISTING" in connection_args["headers"].get("Cookie")
309+
310+
GatewayClient.clear_instance()
311+
312+
262313
async def test_gateway_class_mappings(init_gateway, jp_serverapp):
263314
# Ensure appropriate class mappings are in place.
264315
assert jp_serverapp.kernel_manager_class.__name__ == "GatewayMappingKernelManager"

0 commit comments

Comments
 (0)