Skip to content

Commit 625866a

Browse files
committed
new websocket connection API
1 parent af0a059 commit 625866a

File tree

9 files changed

+479
-519
lines changed

9 files changed

+479
-519
lines changed

jupyter_server/base/zmqhandlers.py

Lines changed: 9 additions & 353 deletions
Original file line numberDiff line numberDiff line change
@@ -1,353 +1,9 @@
1-
"""Tornado handlers for WebSocket <-> ZMQ sockets."""
2-
# Copyright (c) Jupyter Development Team.
3-
# Distributed under the terms of the Modified BSD License.
4-
import json
5-
import re
6-
import struct
7-
import sys
8-
from typing import Optional, no_type_check
9-
from urllib.parse import urlparse
10-
11-
import tornado
12-
13-
try:
14-
from jupyter_client.jsonutil import json_default
15-
except ImportError:
16-
from jupyter_client.jsonutil import date_default as json_default
17-
18-
from jupyter_client.jsonutil import extract_dates
19-
from jupyter_client.session import Session
20-
from tornado import ioloop, web
21-
from tornado.iostream import IOStream
22-
from tornado.websocket import WebSocketClosedError, WebSocketHandler
23-
24-
from .handlers import JupyterHandler
25-
26-
27-
def serialize_binary_message(msg):
    """Serialize a message dict with binary buffers into a single binary blob.

    Header:

    4 bytes: number of msg parts (nbufs) as 32b int
    4 * nbufs bytes: offset for each buffer as integer as 32b int

    Offsets are from the start of the buffer, including the header.
    The first part is the JSON encoding of ``msg`` without its ``buffers``
    entry; the remaining parts are the buffers, in order.

    Parameters
    ----------
    msg : dict
        Message dictionary. Must contain a ``"buffers"`` key holding an
        iterable of bytes-like objects.

    Returns
    -------
    The message serialized to bytes.

    Raises
    ------
    KeyError
        If ``msg`` has no ``"buffers"`` entry.
    """
    # don't modify msg or buffer list in-place
    msg = msg.copy()
    buffers = list(msg.pop("buffers"))
    bmsg = json.dumps(msg, default=json_default).encode("utf8")
    buffers.insert(0, bmsg)
    nbufs = len(buffers)
    offsets = [4 * (nbufs + 1)]
    for buf in buffers[:-1]:
        # Use the size in bytes: for a memoryview with a multi-byte item
        # format, len() counts items, which would misplace later offsets.
        offsets.append(offsets[-1] + memoryview(buf).nbytes)
    offsets_buf = struct.pack("!" + "I" * (nbufs + 1), nbufs, *offsets)
    buffers.insert(0, offsets_buf)
    return b"".join(buffers)
56-
57-
58-
def deserialize_binary_message(bmsg):
    """Deserialize a message from a binary blob.

    Header:

    4 bytes: number of msg parts (nbufs) as 32b int
    4 * nbufs bytes: offset for each buffer as integer as 32b int

    Offsets are from the start of the buffer, including the header.
    The first part is decoded as UTF-8 JSON and becomes the message; the
    remaining parts are attached to it as ``msg["buffers"]``.

    Parameters
    ----------
    bmsg : bytes
        The binary blob, laid out as described above.

    Returns
    -------
    message dictionary

    Raises
    ------
    struct.error
        If the blob is too short to hold the header it announces.
    """
    # The serializer writes the part count as unsigned ("I"); read it the
    # same way so both ends agree on the wire format.
    nbufs = struct.unpack("!I", bmsg[:4])[0]
    header_size = 4 * (nbufs + 1)
    # Check the announced header against the actual size before building a
    # format string whose length is controlled by the sender.
    if header_size > len(bmsg):
        raise struct.error(
            "binary message header needs %d bytes, got %d" % (header_size, len(bmsg))
        )
    offsets = list(struct.unpack("!" + "I" * nbufs, bmsg[4:header_size]))
    # None as the final stop makes the last slice run to the end of the blob
    offsets.append(None)
    bufs = [bmsg[start:stop] for start, stop in zip(offsets[:-1], offsets[1:])]
    msg = json.loads(bufs[0].decode("utf8"))
    msg["header"] = extract_dates(msg["header"])
    msg["parent_header"] = extract_dates(msg["parent_header"])
    msg["buffers"] = bufs[1:]
    return msg
83-
84-
85-
def serialize_msg_to_ws_v1(msg_or_list, channel, pack=None):
    """Serialize a message for the ``v1.kernel.websocket.jupyter.org`` protocol.

    Layout (all integers are 8 bytes, little-endian):

    - the number of offsets
    - the offsets: start of the channel name, start of each frame, and the
      end of the last frame, counted from the start of the blob
    - the UTF-8 encoded channel name
    - the frames, back to back

    Parameters
    ----------
    msg_or_list : dict or sequence of bytes-like
        With ``pack``: a message dict whose ``header``, ``parent_header``,
        ``metadata`` and ``content`` entries are packed, in that order.
        Without ``pack``: the already-serialized frames (any sequence).
    channel : str
        Name of the channel the message belongs to.
    pack : callable, optional
        Function turning one message part into bytes.

    Returns
    -------
    The message serialized to bytes.
    """
    if pack:
        frames = [
            pack(msg_or_list[part])
            for part in ("header", "parent_header", "metadata", "content")
        ]
    else:
        # Copy into a list so tuples and other sequences of frames work too
        frames = list(msg_or_list)
    channel_bytes = channel.encode("utf-8")
    # Header size: one count word, plus one offset per frame, plus one for
    # the channel start and one for the end of the last frame.
    offsets = [8 * (1 + 1 + len(frames) + 1)]
    offsets.append(offsets[-1] + len(channel_bytes))
    for frame in frames:
        offsets.append(offsets[-1] + len(frame))
    header = [len(offsets).to_bytes(8, byteorder="little")]
    header.extend(offset.to_bytes(8, byteorder="little") for offset in offsets)
    return b"".join(header + [channel_bytes] + frames)
105-
106-
107-
def deserialize_msg_from_ws_v1(ws_msg):
    """Split a ``v1.kernel.websocket.jupyter.org`` blob into channel and frames.

    The blob starts with an 8-byte little-endian offset count, followed by
    that many 8-byte little-endian offsets. The first two offsets delimit the
    UTF-8 channel name; each following pair of consecutive offsets delimits
    one frame.

    Parameters
    ----------
    ws_msg : bytes
        The binary websocket message.

    Returns
    -------
    (channel, msg_list) : the channel name as str and the list of frames.

    Raises
    ------
    ValueError
        If the message is too short to hold the offset table it announces.
    IndexError
        If fewer than two offsets are present.
    """
    offset_number = int.from_bytes(ws_msg[:8], "little")
    # The count is supplied by the peer. Make sure the offset table fits in
    # the message before looping over it, so a forged count cannot make us
    # build an enormous list.
    if 8 * (offset_number + 1) > len(ws_msg):
        raise ValueError(
            "websocket message too short for %d offsets" % offset_number
        )
    offsets = [
        int.from_bytes(ws_msg[8 * (i + 1) : 8 * (i + 2)], "little")
        for i in range(offset_number)
    ]
    channel = ws_msg[offsets[0] : offsets[1]].decode("utf-8")
    msg_list = [ws_msg[start:stop] for start, stop in zip(offsets[1:-1], offsets[2:])]
    return channel, msg_list
115-
116-
117-
# ping interval for keeping websockets alive, in milliseconds (30 seconds)
118-
WS_PING_INTERVAL = 30000
119-
120-
121-
class WebSocketMixin:
    """Mixin for common websocket options.

    Adds keep-alive pings, pong-timeout detection and an origin check.
    The methods here rely on attributes supplied by the class this is mixed
    into: ``settings``, ``request``, ``log``, ``allow_origin``,
    ``allow_origin_pat``, ``get_origin``, ``ws_connection``, ``ping`` and
    ``close``.
    """

    # PeriodicCallback driving send_ping; created in open() when pings are on
    ping_callback = None
    # Event-loop timestamps of the last ping sent and the last pong received
    last_ping = 0.0
    last_pong = 0.0
    stream = None  # type: Optional[IOStream]

    @property
    def ping_interval(self):
        """The interval for websocket keep-alive pings, in milliseconds.

        Set ws_ping_interval = 0 to disable pings.
        """
        return self.settings.get("ws_ping_interval", WS_PING_INTERVAL)  # type:ignore[attr-defined]

    @property
    def ping_timeout(self):
        """If no pong is received in this many milliseconds,
        close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
        Default is max of 3 pings or 30 seconds.
        """
        return self.settings.get(  # type:ignore[attr-defined]
            "ws_ping_timeout", max(3 * self.ping_interval, WS_PING_INTERVAL)
        )

    @no_type_check
    def check_origin(self, origin: Optional[str] = None) -> bool:
        """Check Origin == Host or Access-Control-Allow-Origin.

        Tornado >= 4 calls this method automatically, raising 403 if it returns False.

        Returns True when any origin is allowed, when the origin check is
        skipped, when no Origin or Host header is available, when the origin's
        host matches the Host header, or when the origin matches
        ``allow_origin`` / ``allow_origin_pat``. Otherwise logs a warning and
        returns False.
        """

        if self.allow_origin == "*" or (
            hasattr(self, "skip_check_origin") and self.skip_check_origin()
        ):
            return True

        host = self.request.headers.get("Host")
        if origin is None:
            origin = self.get_origin()

        # If no origin or host header is provided, assume from script
        if origin is None or host is None:
            return True

        origin = origin.lower()
        origin_host = urlparse(origin).netloc

        # OK if origin matches host
        if origin_host == host:
            return True

        # Check CORS headers
        if self.allow_origin:
            allow = self.allow_origin == origin
        elif self.allow_origin_pat:
            allow = bool(re.match(self.allow_origin_pat, origin))
        else:
            # No CORS headers deny the request
            allow = False
        if not allow:
            self.log.warning(
                "Blocking Cross Origin WebSocket Attempt.  Origin: %s, Host: %s",
                origin,
                host,
            )
        return allow

    def clear_cookie(self, *args, **kwargs):
        """meaningless for websockets"""
        pass

    @no_type_check
    def open(self, *args, **kwargs):
        """Log the opening websocket and start the periodic keep-alive pings."""
        self.log.debug("Opening websocket %s", self.request.path)

        # start the pinging
        if self.ping_interval > 0:
            loop = ioloop.IOLoop.current()
            self.last_ping = loop.time()  # Remember time of last ping
            self.last_pong = self.last_ping
            self.ping_callback = ioloop.PeriodicCallback(
                self.send_ping,
                self.ping_interval,
            )
            self.ping_callback.start()
        return super().open(*args, **kwargs)

    @no_type_check
    def send_ping(self):
        """send a ping to keep the websocket alive"""
        if self.ws_connection is None:
            # Connection is gone: there is nothing to ping. Stop the periodic
            # callback if one is running, and never touch the missing
            # connection object below.
            if self.ping_callback is not None:
                self.ping_callback.stop()
            return

        if self.ws_connection.client_terminated:
            self.close()
            return

        # check for timeout on pong.  Make sure that we really have sent a recent ping in
        # case the machine with both server and client has been suspended since the last ping.
        now = ioloop.IOLoop.current().time()
        since_last_pong = 1e3 * (now - self.last_pong)
        since_last_ping = 1e3 * (now - self.last_ping)
        if since_last_ping < 2 * self.ping_interval and since_last_pong > self.ping_timeout:
            self.log.warning("WebSocket ping timeout after %i ms.", since_last_pong)
            self.close()
            return

        self.ping(b"")
        self.last_ping = now

    def on_pong(self, data):
        """Record the event-loop time at which the latest pong arrived."""
        self.last_pong = ioloop.IOLoop.current().time()
236-
237-
238-
class ZMQStreamHandler(WebSocketMixin, WebSocketHandler):
    """Websocket handler that forwards ZMQ messages to the websocket client."""

    if tornado.version_info < (4, 1):
        # Backport send_error from tornado 4.1 to 4.0

        def send_error(self, *args, **kwargs):
            if self.stream is None:
                super(WebSocketHandler, self).send_error(*args, **kwargs)
            else:
                # If we get an uncaught exception during the handshake,
                # we have no choice but to abruptly close the connection.
                # TODO: for uncaught exceptions after the handshake,
                # we can close the connection more gracefully.
                self.stream.close()

    def _reserialize_reply(self, msg_or_list, channel=None):
        """Reserialize a reply message using JSON.

        msg_or_list can be an already-deserialized msg dict or the zmq buffer list.
        If it is the zmq list, it will be deserialized with self.session.

        This takes the msg list from the ZMQ socket and serializes the result for the websocket.
        This method should be used by self._on_zmq_reply to build messages that can
        be sent back to the browser.

        Returns bytes (binary framing) when the message carries buffers,
        otherwise a JSON string.
        """
        if isinstance(msg_or_list, dict):
            # already unpacked
            msg = msg_or_list
        else:
            idents, msg_list = self.session.feed_identities(msg_or_list)
            msg = self.session.deserialize(msg_list)
        if channel:
            msg["channel"] = channel
        # .get(): an already-unpacked dict may not carry a "buffers" entry
        if msg.get("buffers"):
            return serialize_binary_message(msg)
        return json.dumps(msg, default=json_default)

    def select_subprotocol(self, subprotocols):
        """Pick the websocket subprotocol to use for this connection.

        The ``kernel_ws_protocol`` setting names the preferred protocol;
        unset means the v1 protocol and an empty string means none. The
        preferred protocol is selected only if the client offered it.
        """
        preferred_protocol = self.settings.get("kernel_ws_protocol")
        if preferred_protocol is None:
            preferred_protocol = "v1.kernel.websocket.jupyter.org"
        elif preferred_protocol == "":
            preferred_protocol = None
        selected_subprotocol = preferred_protocol if preferred_protocol in subprotocols else None
        # None is the default, "legacy" protocol
        return selected_subprotocol

    def _on_zmq_reply(self, stream, msg_list):
        """Forward one multipart ZMQ message from ``stream`` to the websocket."""
        # Sometimes this gets triggered when the on_close method is scheduled in the
        # eventloop but hasn't been called.
        if self.ws_connection is None or stream.closed():
            self.log.warning("zmq message arrived on closed channel")
            self.close()
            return
        channel = getattr(stream, "channel", None)
        if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org":
            msg = serialize_msg_to_ws_v1(msg_list, channel)
            binary = True
        else:
            try:
                msg = self._reserialize_reply(msg_list, channel=channel)
            except Exception:
                # Pass msg_list as a logging argument: "%r" % msg_list would
                # itself raise if msg_list happened to be a tuple.
                self.log.critical("Malformed message: %r", msg_list, exc_info=True)
                return
            binary = isinstance(msg, bytes)
        # Both protocols can race with the socket closing, so guard the write
        # the same way for each.
        try:
            self.write_message(msg, binary=binary)
        except WebSocketClosedError as e:
            self.log.warning(str(e))
309-
310-
311-
class AuthenticatedZMQStreamHandler(ZMQStreamHandler, JupyterHandler):
    """ZMQ stream handler that authenticates and authorizes before opening."""

    def set_default_headers(self):
        """Undo the set_default_headers in JupyterHandler

        which doesn't make sense for websockets
        """
        pass

    def pre_get(self):
        """Run before finishing the GET request

        Extend this method to add logic that should fire before
        the websocket finishes completing.

        Raises ``web.HTTPError(403)`` when the request has no authenticated
        user or the user is not authorized to execute on kernels.
        """
        # authenticate the request before opening the websocket
        user = self.current_user
        if user is None:
            self.log.warning("Couldn't authenticate WebSocket connection")
            raise web.HTTPError(403)

        # authorize the user.
        if not self.authorizer.is_authorized(self, user, "execute", "kernels"):
            raise web.HTTPError(403)

        session_id = self.get_argument("session_id", None)
        if session_id:
            self.session.session = session_id
        else:
            self.log.warning("No session ID specified")

    async def get(self, *args, **kwargs):
        """Run ``pre_get`` and then the parent class's ``get``."""
        import inspect

        # pre_get can be a coroutine in subclasses, but the implementation
        # in this class is synchronous and returns None; only await results
        # that are actually awaitable.
        res = self.pre_get()
        if inspect.isawaitable(res):
            await res
        res = super().get(*args, **kwargs)
        if inspect.isawaitable(res):
            await res

    def initialize(self):
        """Create the per-connection ``Session`` from the handler's config."""
        self.log.debug("Initializing websocket connection %s", self.request.path)
        self.session = Session(config=self.config)

    def get_compression_options(self):
        """Return the ``websocket_compression_options`` setting, or None if unset."""
        return self.settings.get("websocket_compression_options", None)
1+
"""Add deprecation warning here.
2+
"""
3+
from jupyter_server.services.kernels.connection.base import (
4+
deserialize_binary_message,
5+
deserialize_msg_from_ws_v1,
6+
serialize_binary_message,
7+
serialize_msg_to_ws_v1,
8+
)
9+
from jupyter_server.services.kernels.websocket import WebSocketMixin

0 commit comments

Comments
 (0)