|
5 | 5 | import logging
|
6 | 6 | import os
|
7 | 7 | import typing as ty
|
| 8 | +from datetime import datetime |
| 9 | +from email.utils import parsedate_to_datetime |
| 10 | +from http.cookies import Morsel, SimpleCookie |
8 | 11 | from socket import gaierror
|
9 | 12 |
|
10 | 13 | from tornado import web
|
@@ -276,6 +279,9 @@ def __init__(self, **kwargs):
|
276 | 279 | super().__init__(**kwargs)
|
277 | 280 | self._static_args = {} # initialized on first use
|
278 | 281 |
|
| 282 | + # store of cookies with store time |
| 283 | + self._cookies = {} # type: ty.Dict[str, ty.Tuple[Morsel, datetime]] |
| 284 | + |
279 | 285 | env_whitelist_default_value = ""
|
280 | 286 | env_whitelist_env = "JUPYTER_GATEWAY_ENV_WHITELIST"
|
281 | 287 | env_whitelist = Unicode(
|
@@ -363,6 +369,23 @@ def launch_timeout_pad_default(self):
|
363 | 369 | )
|
364 | 370 | )
|
365 | 371 |
|
| 372 | + accept_cookies_value = False |
| 373 | + accept_cookies_env = "JUPYTER_GATEWAY_ACCEPT_COOKIES" |
| 374 | + accept_cookies = Bool( |
| 375 | + default_value=accept_cookies_value, |
| 376 | + config=True, |
| 377 | + help="""Accept and manage cookies sent by the service side. This is often useful |
| 378 | + for load balancers to decide which backend node to use. |
| 379 | + (JUPYTER_GATEWAY_ACCEPT_COOKIES env var)""", |
| 380 | + ) |
| 381 | + |
| 382 | + @default("accept_cookies") |
| 383 | + def accept_cookies_default(self): |
| 384 | + return bool( |
| 385 | + os.environ.get(self.accept_cookies_env, str(self.accept_cookies_value).lower()) |
| 386 | + not in ["no", "false"] |
| 387 | + ) |
| 388 | + |
366 | 389 | @property
|
367 | 390 | def gateway_enabled(self):
|
368 | 391 | return bool(self.url is not None and len(self.url) > 0)
|
@@ -424,8 +447,65 @@ def load_connection_args(self, **kwargs):
|
424 | 447 | else:
|
425 | 448 | kwargs[arg] = static_value
|
426 | 449 |
|
| 450 | + if self.accept_cookies: |
| 451 | + self._update_cookie_header(kwargs) |
| 452 | + |
427 | 453 | return kwargs
|
428 | 454 |
|
| 455 | + def update_cookies(self, cookie: SimpleCookie) -> None: |
| 456 | + """Update cookies from existing requests for load balancers""" |
| 457 | + if not self.accept_cookies: |
| 458 | + return |
| 459 | + |
| 460 | + store_time = datetime.now() |
| 461 | + for key, item in cookie.items(): |
| 462 | + # Convert "expires" arg into "max-age" to facilitate expiration management. |
| 463 | + # As "max-age" has precedence, ignore "expires" when "max-age" exists. |
| 464 | + if item.get("expires") and not item.get("max-age"): |
| 465 | + expire_timedelta = parsedate_to_datetime(item["expires"]) - store_time |
| 466 | + item["max-age"] = str(expire_timedelta.total_seconds()) |
| 467 | + |
| 468 | + self._cookies[key] = (item, store_time) |
| 469 | + |
| 470 | + def _clear_expired_cookies(self) -> None: |
| 471 | + check_time = datetime.now() |
| 472 | + expired_keys = [] |
| 473 | + |
| 474 | + for key, (morsel, store_time) in self._cookies.items(): |
| 475 | + cookie_max_age = morsel.get("max-age") |
| 476 | + if not cookie_max_age: |
| 477 | + continue |
| 478 | + expired_timedelta = check_time - store_time |
| 479 | + if expired_timedelta.total_seconds() > float(cookie_max_age): |
| 480 | + expired_keys.append(key) |
| 481 | + |
| 482 | + for key in expired_keys: |
| 483 | + self._cookies.pop(key) |
| 484 | + |
| 485 | + def _update_cookie_header(self, connection_args: dict) -> None: |
| 486 | + self._clear_expired_cookies() |
| 487 | + |
| 488 | + gateway_cookie_values = "; ".join( |
| 489 | + f"{name}={morsel.coded_value}" for name, (morsel, _time) in self._cookies.items() |
| 490 | + ) |
| 491 | + if gateway_cookie_values: |
| 492 | + headers = connection_args.get("headers", {}) |
| 493 | + |
| 494 | + # As headers are case-insensitive, we get existing name of cookie header, |
| 495 | + # or use "Cookie" by default. |
| 496 | + cookie_header_name = next( |
| 497 | + (header_key for header_key in headers if header_key.lower() == "cookie"), |
| 498 | + "Cookie", |
| 499 | + ) |
| 500 | + existing_cookie = headers.get(cookie_header_name) |
| 501 | + |
| 502 | + # merge gateway-managed cookies with cookies already in arguments |
| 503 | + if existing_cookie: |
| 504 | + gateway_cookie_values = existing_cookie + "; " + gateway_cookie_values |
| 505 | + headers[cookie_header_name] = gateway_cookie_values |
| 506 | + |
| 507 | + connection_args["headers"] = headers |
| 508 | + |
429 | 509 |
|
430 | 510 | class RetryableHTTPClient:
|
431 | 511 | """
|
@@ -524,4 +604,11 @@ async def gateway_request(endpoint: str, **kwargs: ty.Any) -> HTTPResponse:
|
524 | 604 | f"appear to be valid. Ensure gateway url is valid and the Gateway instance is running.",
|
525 | 605 | ) from e
|
526 | 606 |
|
| 607 | + if GatewayClient.instance().accept_cookies: |
| 608 | + # Update cookies on GatewayClient from server if configured. |
| 609 | + cookie_values = response.headers.get("Set-Cookie") |
| 610 | + if cookie_values: |
| 611 | + cookie: SimpleCookie = SimpleCookie() |
| 612 | + cookie.load(cookie_values) |
| 613 | + GatewayClient.instance().update_cookies(cookie) |
527 | 614 | return response
|
0 commit comments