|
1 |
| -"""Tornado handlers for WebSocket <-> ZMQ sockets.""" |
2 |
| -# Copyright (c) Jupyter Development Team. |
3 |
| -# Distributed under the terms of the Modified BSD License. |
4 |
| -import json |
5 |
| -import re |
6 |
| -import struct |
7 |
| -import sys |
8 |
| -from typing import Optional, no_type_check |
9 |
| -from urllib.parse import urlparse |
10 |
| - |
11 |
| -import tornado |
12 |
| - |
13 |
| -try: |
14 |
| - from jupyter_client.jsonutil import json_default |
15 |
| -except ImportError: |
16 |
| - from jupyter_client.jsonutil import date_default as json_default |
17 |
| - |
18 |
| -from jupyter_client.jsonutil import extract_dates |
19 |
| -from jupyter_client.session import Session |
20 |
| -from tornado import ioloop, web |
21 |
| -from tornado.iostream import IOStream |
22 |
| -from tornado.websocket import WebSocketClosedError, WebSocketHandler |
23 |
| - |
24 |
| -from .handlers import JupyterHandler |
25 |
| - |
26 |
| - |
27 |
| -def serialize_binary_message(msg): |
28 |
| - """serialize a message as a binary blob |
29 |
| -
|
30 |
| - Header: |
31 |
| -
|
32 |
| - 4 bytes: number of msg parts (nbufs) as 32b int |
33 |
| - 4 * nbufs bytes: offset for each buffer as integer as 32b int |
34 |
| -
|
35 |
| - Offsets are from the start of the buffer, including the header. |
36 |
| -
|
37 |
| - Returns |
38 |
| - ------- |
39 |
| - The message serialized to bytes. |
40 |
| -
|
41 |
| - """ |
42 |
| - # don't modify msg or buffer list in-place |
43 |
| - msg = msg.copy() |
44 |
| - buffers = list(msg.pop("buffers")) |
45 |
| - if sys.version_info < (3, 4): |
46 |
| - buffers = [x.tobytes() for x in buffers] |
47 |
| - bmsg = json.dumps(msg, default=json_default).encode("utf8") |
48 |
| - buffers.insert(0, bmsg) |
49 |
| - nbufs = len(buffers) |
50 |
| - offsets = [4 * (nbufs + 1)] |
51 |
| - for buf in buffers[:-1]: |
52 |
| - offsets.append(offsets[-1] + len(buf)) |
53 |
| - offsets_buf = struct.pack("!" + "I" * (nbufs + 1), nbufs, *offsets) |
54 |
| - buffers.insert(0, offsets_buf) |
55 |
| - return b"".join(buffers) |
56 |
| - |
57 |
| - |
58 |
| -def deserialize_binary_message(bmsg): |
59 |
| - """deserialize a message from a binary blog |
60 |
| -
|
61 |
| - Header: |
62 |
| -
|
63 |
| - 4 bytes: number of msg parts (nbufs) as 32b int |
64 |
| - 4 * nbufs bytes: offset for each buffer as integer as 32b int |
65 |
| -
|
66 |
| - Offsets are from the start of the buffer, including the header. |
67 |
| -
|
68 |
| - Returns |
69 |
| - ------- |
70 |
| - message dictionary |
71 |
| - """ |
72 |
| - nbufs = struct.unpack("!i", bmsg[:4])[0] |
73 |
| - offsets = list(struct.unpack("!" + "I" * nbufs, bmsg[4 : 4 * (nbufs + 1)])) |
74 |
| - offsets.append(None) |
75 |
| - bufs = [] |
76 |
| - for start, stop in zip(offsets[:-1], offsets[1:]): |
77 |
| - bufs.append(bmsg[start:stop]) |
78 |
| - msg = json.loads(bufs[0].decode("utf8")) |
79 |
| - msg["header"] = extract_dates(msg["header"]) |
80 |
| - msg["parent_header"] = extract_dates(msg["parent_header"]) |
81 |
| - msg["buffers"] = bufs[1:] |
82 |
| - return msg |
83 |
| - |
84 |
| - |
85 |
| -def serialize_msg_to_ws_v1(msg_or_list, channel, pack=None): |
86 |
| - if pack: |
87 |
| - msg_list = [ |
88 |
| - pack(msg_or_list["header"]), |
89 |
| - pack(msg_or_list["parent_header"]), |
90 |
| - pack(msg_or_list["metadata"]), |
91 |
| - pack(msg_or_list["content"]), |
92 |
| - ] |
93 |
| - else: |
94 |
| - msg_list = msg_or_list |
95 |
| - channel = channel.encode("utf-8") |
96 |
| - offsets: list = [] |
97 |
| - offsets.append(8 * (1 + 1 + len(msg_list) + 1)) |
98 |
| - offsets.append(len(channel) + offsets[-1]) |
99 |
| - for msg in msg_list: |
100 |
| - offsets.append(len(msg) + offsets[-1]) |
101 |
| - offset_number = len(offsets).to_bytes(8, byteorder="little") |
102 |
| - offsets = [offset.to_bytes(8, byteorder="little") for offset in offsets] |
103 |
| - bin_msg = b"".join([offset_number] + offsets + [channel] + msg_list) |
104 |
| - return bin_msg |
105 |
| - |
106 |
| - |
107 |
| -def deserialize_msg_from_ws_v1(ws_msg): |
108 |
| - offset_number = int.from_bytes(ws_msg[:8], "little") |
109 |
| - offsets = [ |
110 |
| - int.from_bytes(ws_msg[8 * (i + 1) : 8 * (i + 2)], "little") for i in range(offset_number) |
111 |
| - ] |
112 |
| - channel = ws_msg[offsets[0] : offsets[1]].decode("utf-8") |
113 |
| - msg_list = [ws_msg[offsets[i] : offsets[i + 1]] for i in range(1, offset_number - 1)] |
114 |
| - return channel, msg_list |
115 |
| - |
116 |
| - |
117 |
| -# ping interval for keeping websockets alive (30 seconds) |
118 |
| -WS_PING_INTERVAL = 30000 |
119 |
| - |
120 |
| - |
121 |
| -class WebSocketMixin: |
122 |
| - """Mixin for common websocket options""" |
123 |
| - |
124 |
| - ping_callback = None |
125 |
| - last_ping = 0.0 |
126 |
| - last_pong = 0.0 |
127 |
| - stream = None # type: Optional[IOStream] |
128 |
| - |
129 |
| - @property |
130 |
| - def ping_interval(self): |
131 |
| - """The interval for websocket keep-alive pings. |
132 |
| -
|
133 |
| - Set ws_ping_interval = 0 to disable pings. |
134 |
| - """ |
135 |
| - return self.settings.get("ws_ping_interval", WS_PING_INTERVAL) # type:ignore[attr-defined] |
136 |
| - |
137 |
| - @property |
138 |
| - def ping_timeout(self): |
139 |
| - """If no ping is received in this many milliseconds, |
140 |
| - close the websocket connection (VPNs, etc. can fail to cleanly close ws connections). |
141 |
| - Default is max of 3 pings or 30 seconds. |
142 |
| - """ |
143 |
| - return self.settings.get( # type:ignore[attr-defined] |
144 |
| - "ws_ping_timeout", max(3 * self.ping_interval, WS_PING_INTERVAL) |
145 |
| - ) |
146 |
| - |
147 |
| - @no_type_check |
148 |
| - def check_origin(self, origin: Optional[str] = None) -> bool: |
149 |
| - """Check Origin == Host or Access-Control-Allow-Origin. |
150 |
| -
|
151 |
| - Tornado >= 4 calls this method automatically, raising 403 if it returns False. |
152 |
| - """ |
153 |
| - |
154 |
| - if self.allow_origin == "*" or ( |
155 |
| - hasattr(self, "skip_check_origin") and self.skip_check_origin() |
156 |
| - ): |
157 |
| - return True |
158 |
| - |
159 |
| - host = self.request.headers.get("Host") |
160 |
| - if origin is None: |
161 |
| - origin = self.get_origin() |
162 |
| - |
163 |
| - # If no origin or host header is provided, assume from script |
164 |
| - if origin is None or host is None: |
165 |
| - return True |
166 |
| - |
167 |
| - origin = origin.lower() |
168 |
| - origin_host = urlparse(origin).netloc |
169 |
| - |
170 |
| - # OK if origin matches host |
171 |
| - if origin_host == host: |
172 |
| - return True |
173 |
| - |
174 |
| - # Check CORS headers |
175 |
| - if self.allow_origin: |
176 |
| - allow = self.allow_origin == origin |
177 |
| - elif self.allow_origin_pat: |
178 |
| - allow = bool(re.match(self.allow_origin_pat, origin)) |
179 |
| - else: |
180 |
| - # No CORS headers deny the request |
181 |
| - allow = False |
182 |
| - if not allow: |
183 |
| - self.log.warning( |
184 |
| - "Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s", |
185 |
| - origin, |
186 |
| - host, |
187 |
| - ) |
188 |
| - return allow |
189 |
| - |
190 |
| - def clear_cookie(self, *args, **kwargs): |
191 |
| - """meaningless for websockets""" |
192 |
| - pass |
193 |
| - |
194 |
| - @no_type_check |
195 |
| - def open(self, *args, **kwargs): |
196 |
| - self.log.debug("Opening websocket %s", self.request.path) |
197 |
| - |
198 |
| - # start the pinging |
199 |
| - if self.ping_interval > 0: |
200 |
| - loop = ioloop.IOLoop.current() |
201 |
| - self.last_ping = loop.time() # Remember time of last ping |
202 |
| - self.last_pong = self.last_ping |
203 |
| - self.ping_callback = ioloop.PeriodicCallback( |
204 |
| - self.send_ping, |
205 |
| - self.ping_interval, |
206 |
| - ) |
207 |
| - self.ping_callback.start() |
208 |
| - return super().open(*args, **kwargs) |
209 |
| - |
210 |
| - @no_type_check |
211 |
| - def send_ping(self): |
212 |
| - """send a ping to keep the websocket alive""" |
213 |
| - if self.ws_connection is None and self.ping_callback is not None: |
214 |
| - self.ping_callback.stop() |
215 |
| - return |
216 |
| - |
217 |
| - if self.ws_connection.client_terminated: |
218 |
| - self.close() |
219 |
| - return |
220 |
| - |
221 |
| - # check for timeout on pong. Make sure that we really have sent a recent ping in |
222 |
| - # case the machine with both server and client has been suspended since the last ping. |
223 |
| - now = ioloop.IOLoop.current().time() |
224 |
| - since_last_pong = 1e3 * (now - self.last_pong) |
225 |
| - since_last_ping = 1e3 * (now - self.last_ping) |
226 |
| - if since_last_ping < 2 * self.ping_interval and since_last_pong > self.ping_timeout: |
227 |
| - self.log.warning("WebSocket ping timeout after %i ms.", since_last_pong) |
228 |
| - self.close() |
229 |
| - return |
230 |
| - |
231 |
| - self.ping(b"") |
232 |
| - self.last_ping = now |
233 |
| - |
234 |
| - def on_pong(self, data): |
235 |
| - self.last_pong = ioloop.IOLoop.current().time() |
236 |
| - |
237 |
| - |
238 |
| -class ZMQStreamHandler(WebSocketMixin, WebSocketHandler): |
239 |
| - |
240 |
| - if tornado.version_info < (4, 1): |
241 |
| - """Backport send_error from tornado 4.1 to 4.0""" |
242 |
| - |
243 |
| - def send_error(self, *args, **kwargs): |
244 |
| - if self.stream is None: |
245 |
| - super(WebSocketHandler, self).send_error(*args, **kwargs) |
246 |
| - else: |
247 |
| - # If we get an uncaught exception during the handshake, |
248 |
| - # we have no choice but to abruptly close the connection. |
249 |
| - # TODO: for uncaught exceptions after the handshake, |
250 |
| - # we can close the connection more gracefully. |
251 |
| - self.stream.close() |
252 |
| - |
253 |
| - def _reserialize_reply(self, msg_or_list, channel=None): |
254 |
| - """Reserialize a reply message using JSON. |
255 |
| -
|
256 |
| - msg_or_list can be an already-deserialized msg dict or the zmq buffer list. |
257 |
| - If it is the zmq list, it will be deserialized with self.session. |
258 |
| -
|
259 |
| - This takes the msg list from the ZMQ socket and serializes the result for the websocket. |
260 |
| - This method should be used by self._on_zmq_reply to build messages that can |
261 |
| - be sent back to the browser. |
262 |
| -
|
263 |
| - """ |
264 |
| - if isinstance(msg_or_list, dict): |
265 |
| - # already unpacked |
266 |
| - msg = msg_or_list |
267 |
| - else: |
268 |
| - idents, msg_list = self.session.feed_identities(msg_or_list) |
269 |
| - msg = self.session.deserialize(msg_list) |
270 |
| - if channel: |
271 |
| - msg["channel"] = channel |
272 |
| - if msg["buffers"]: |
273 |
| - buf = serialize_binary_message(msg) |
274 |
| - return buf |
275 |
| - else: |
276 |
| - return json.dumps(msg, default=json_default) |
277 |
| - |
278 |
| - def select_subprotocol(self, subprotocols): |
279 |
| - preferred_protocol = self.settings.get("kernel_ws_protocol") |
280 |
| - if preferred_protocol is None: |
281 |
| - preferred_protocol = "v1.kernel.websocket.jupyter.org" |
282 |
| - elif preferred_protocol == "": |
283 |
| - preferred_protocol = None |
284 |
| - selected_subprotocol = preferred_protocol if preferred_protocol in subprotocols else None |
285 |
| - # None is the default, "legacy" protocol |
286 |
| - return selected_subprotocol |
287 |
| - |
288 |
| - def _on_zmq_reply(self, stream, msg_list): |
289 |
| - # Sometimes this gets triggered when the on_close method is scheduled in the |
290 |
| - # eventloop but hasn't been called. |
291 |
| - if self.ws_connection is None or stream.closed(): |
292 |
| - self.log.warning("zmq message arrived on closed channel") |
293 |
| - self.close() |
294 |
| - return |
295 |
| - channel = getattr(stream, "channel", None) |
296 |
| - if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org": |
297 |
| - bin_msg = serialize_msg_to_ws_v1(msg_list, channel) |
298 |
| - self.write_message(bin_msg, binary=True) |
299 |
| - else: |
300 |
| - try: |
301 |
| - msg = self._reserialize_reply(msg_list, channel=channel) |
302 |
| - except Exception: |
303 |
| - self.log.critical("Malformed message: %r" % msg_list, exc_info=True) |
304 |
| - else: |
305 |
| - try: |
306 |
| - self.write_message(msg, binary=isinstance(msg, bytes)) |
307 |
| - except WebSocketClosedError as e: |
308 |
| - self.log.warning(str(e)) |
309 |
| - |
310 |
| - |
311 |
| -class AuthenticatedZMQStreamHandler(ZMQStreamHandler, JupyterHandler): |
312 |
| - def set_default_headers(self): |
313 |
| - """Undo the set_default_headers in JupyterHandler |
314 |
| -
|
315 |
| - which doesn't make sense for websockets |
316 |
| - """ |
317 |
| - pass |
318 |
| - |
319 |
| - def pre_get(self): |
320 |
| - """Run before finishing the GET request |
321 |
| -
|
322 |
| - Extend this method to add logic that should fire before |
323 |
| - the websocket finishes completing. |
324 |
| - """ |
325 |
| - # authenticate the request before opening the websocket |
326 |
| - user = self.current_user |
327 |
| - if user is None: |
328 |
| - self.log.warning("Couldn't authenticate WebSocket connection") |
329 |
| - raise web.HTTPError(403) |
330 |
| - |
331 |
| - # authorize the user. |
332 |
| - if not self.authorizer.is_authorized(self, user, "execute", "kernels"): |
333 |
| - raise web.HTTPError(403) |
334 |
| - |
335 |
| - if self.get_argument("session_id", None): |
336 |
| - self.session.session = self.get_argument("session_id") |
337 |
| - else: |
338 |
| - self.log.warning("No session ID specified") |
339 |
| - |
340 |
| - async def get(self, *args, **kwargs): |
341 |
| - # pre_get can be a coroutine in subclasses |
342 |
| - # assign and yield in two step to avoid tornado 3 issues |
343 |
| - res = self.pre_get() |
344 |
| - await res |
345 |
| - res = super().get(*args, **kwargs) |
346 |
| - await res |
347 |
| - |
348 |
| - def initialize(self): |
349 |
| - self.log.debug("Initializing websocket connection %s", self.request.path) |
350 |
| - self.session = Session(config=self.config) |
351 |
| - |
352 |
| - def get_compression_options(self): |
353 |
| - return self.settings.get("websocket_compression_options", None) |
| 1 | +"""Add deprecation warning here. |
| 2 | +""" |
| 3 | +from jupyter_server.services.kernels.connection.base import ( |
| 4 | + deserialize_binary_message, |
| 5 | + deserialize_msg_from_ws_v1, |
| 6 | + serialize_binary_message, |
| 7 | + serialize_msg_to_ws_v1, |
| 8 | +) |
| 9 | +from jupyter_server.services.kernels.websocket import WebSocketMixin |
0 commit comments