Skip to content

Commit 53380a7

Browse files
Fix kernel WebSocket protocol (#1110)
1 parent c2275a7 commit 53380a7

File tree

3 files changed

+19
-16
lines changed

3 files changed

+19
-16
lines changed

jupyter_server/services/kernels/connection/channels.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import weakref
55
from concurrent.futures import Future
66
from textwrap import dedent
7+
from typing import Dict as Dict_t
78
from typing import MutableSet
89

910
from jupyter_client import protocol_version as client_protocol_version
@@ -21,6 +22,7 @@
2122

2223
from jupyter_server.transutils import _i18n
2324

25+
from ..websocket import KernelWebsocketHandler
2426
from .abc import KernelWebsocketConnectionABC
2527
from .base import (
2628
BaseKernelWebsocketConnection,
@@ -103,7 +105,7 @@ def write_message(self):
103105
# class-level registry of open sessions
104106
# allows checking for conflict on session-id,
105107
# which is used as a zmq identity and must be unique.
106-
_open_sessions: dict = {}
108+
_open_sessions: Dict_t[str, KernelWebsocketHandler] = {}
107109
_open_sockets: MutableSet["ZMQChannelsWebsocketConnection"] = weakref.WeakSet()
108110

109111
_kernel_info_future: Future
@@ -391,7 +393,7 @@ def close(self):
391393
def disconnect(self):
392394
self.log.debug("Websocket closed %s", self.session_key)
393395
# unregister myself as an open session (only if it's really me)
394-
if self._open_sessions.get(self.session_key) is self:
396+
if self._open_sessions.get(self.session_key) is self.websocket_handler:
395397
self._open_sessions.pop(self.session_key)
396398

397399
if self.kernel_id in self.multi_kernel_manager:
@@ -536,16 +538,6 @@ def _reserialize_reply(self, msg_or_list, channel=None):
536538
else:
537539
return json.dumps(msg, default=json_default)
538540

539-
def select_subprotocol(self, subprotocols):
540-
preferred_protocol = self.kernel_ws_protocol
541-
if preferred_protocol is None:
542-
preferred_protocol = "v1.kernel.websocket.jupyter.org"
543-
elif preferred_protocol == "":
544-
preferred_protocol = None
545-
selected_subprotocol = preferred_protocol if preferred_protocol in subprotocols else None
546-
# None is the default, "legacy" protocol
547-
return selected_subprotocol
548-
549541
def _on_zmq_reply(self, stream, msg_list):
550542
# Sometimes this gets triggered when the on_close method is scheduled in the
551543
# eventloop but hasn't been called.

jupyter_server/services/kernels/websocket.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
class KernelWebsocketHandler(WebSocketMixin, WebSocketHandler, JupyterHandler):
17-
"""The kernels websocket should connecte"""
17+
"""The kernels websocket should connect"""
1818

1919
auth_resource = AUTH_RESOURCE
2020

@@ -75,6 +75,16 @@ def on_close(self):
7575
self.connection.disconnect()
7676
self.connection = None
7777

78+
def select_subprotocol(self, subprotocols):
79+
preferred_protocol = self.connection.kernel_ws_protocol
80+
if preferred_protocol is None:
81+
preferred_protocol = "v1.kernel.websocket.jupyter.org"
82+
elif preferred_protocol == "":
83+
preferred_protocol = None
84+
selected_subprotocol = preferred_protocol if preferred_protocol in subprotocols else None
85+
# None is the default, "legacy" protocol
86+
return selected_subprotocol
87+
7888

7989
default_handlers = [
8090
(r"/api/kernels/%s/channels" % _kernel_id_regex, KernelWebsocketHandler),

tests/services/kernels/test_connection.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
from jupyter_client.jsonutil import json_clean, json_default
66
from jupyter_client.session import Session
77
from tornado.httpserver import HTTPRequest
8-
from tornado.websocket import WebSocketHandler
98

109
from jupyter_server.serverapp import ServerApp
1110
from jupyter_server.services.kernels.connection.channels import (
1211
ZMQChannelsWebsocketConnection,
1312
)
13+
from jupyter_server.services.kernels.websocket import KernelWebsocketHandler
1414

1515

1616
async def test_websocket_connection(jp_serverapp):
@@ -19,10 +19,11 @@ async def test_websocket_connection(jp_serverapp):
1919
kernel = app.kernel_manager.get_kernel(kernel_id)
2020
request = HTTPRequest("foo", "GET")
2121
request.connection = MagicMock()
22-
handler = WebSocketHandler(app.web_app, request)
22+
handler = KernelWebsocketHandler(app.web_app, request)
2323
handler.ws_connection = MagicMock()
2424
handler.ws_connection.is_closing = lambda: False
2525
conn = ZMQChannelsWebsocketConnection(parent=kernel, websocket_handler=handler)
26+
handler.connection = conn
2627
await conn.prepare()
2728
conn.connect()
2829
await asyncio.wrap_future(conn.nudge())
@@ -37,7 +38,7 @@ async def test_websocket_connection(jp_serverapp):
3738
conn.handle_incoming_message(data)
3839
conn.handle_outgoing_message("iopub", session.serialize(msg))
3940
assert (
40-
conn.select_subprotocol(["v1.kernel.websocket.jupyter.org"])
41+
conn.websocket_handler.select_subprotocol(["v1.kernel.websocket.jupyter.org"])
4142
== "v1.kernel.websocket.jupyter.org"
4243
)
4344
conn.write_stderr("test", {})

0 commit comments

Comments
 (0)