Skip to content

Commit 2650341

Browse files
committed
all loop to ensure kernel is alive before connecting working unit tests
1 parent bfc4153 commit 2650341

File tree

3 files changed

+23
-8
lines changed

3 files changed

+23
-8
lines changed

jupyter_server/services/kernels/connection/channels.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import json
3+
import time
34
import weakref
45
from concurrent.futures import Future
56
from textwrap import dedent
@@ -16,6 +17,8 @@
1617
except ImportError:
1718
from jupyter_client.jsonutil import date_default as json_default
1819

20+
from jupyter_client.utils import ensure_async
21+
1922
from jupyter_server.transutils import _i18n
2023

2124
from .abc import KernelWebsocketConnectionABC
@@ -282,7 +285,7 @@ async def _register_session(self):
282285
if (
283286
self.kernel_id in self.multi_kernel_manager
284287
): # only update open sessions if kernel is actively managed
285-
self._open_sessions[self.session_key] = self
288+
self._open_sessions[self.session_key] = self.websocket_handler
286289

287290
async def prepare(self):
288291
# check session collision:
@@ -302,6 +305,12 @@ async def prepare(self):
302305
self.kernel_manager.reason = str(e)
303306
raise web.HTTPError(500, str(e)) from e
304307

308+
t0 = time.time()
309+
while not await ensure_async(self.kernel_manager.is_alive()):
310+
await asyncio.sleep(0.1)
311+
if time.time() - t0 > self.multi_kernel_manager.kernel_info_timeout:
312+
raise TimeoutError("Kernel never reached an 'alive' state.")
313+
305314
self.session.key = self.kernel_manager.session.key
306315
future = self.request_kernel_info()
307316

@@ -360,7 +369,7 @@ def replay(value):
360369
for _, stream in self.channels.items():
361370
if not stream.closed():
362371
stream.close()
363-
self.close()
372+
self.disconnect()
364373
return
365374

366375
self.multi_kernel_manager.add_restart_callback(self.kernel_id, self.on_kernel_restarted)
@@ -376,6 +385,9 @@ def subscribe(value):
376385
ZMQChannelsWebsocketConnection._open_sockets.add(self)
377386
return connected
378387

388+
def close(self):
389+
return self.disconnect()
390+
379391
def disconnect(self):
380392
self.log.debug("Websocket closed %s", self.session_key)
381393
# unregister myself as an open session (only if it's really me)
@@ -536,7 +548,7 @@ def _on_zmq_reply(self, stream, msg_list):
536548
# eventloop but hasn't been called.
537549
if stream.closed():
538550
self.log.warning("zmq message arrived on closed channel")
539-
self.close()
551+
self.disconnect()
540552
return
541553
channel = getattr(stream, "channel", None)
542554
if self.subprotocol == "v1.kernel.websocket.jupyter.org":

jupyter_server/services/kernels/websocket.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,7 @@ async def pre_get(self):
178178

179179
kernel = self.kernel_manager.get_kernel(self.kernel_id)
180180
self.connection = self.kernel_websocket_connection_class(
181-
parent=kernel,
182-
websocket_handler=self,
181+
parent=kernel, websocket_handler=self, config=self.config
183182
)
184183

185184
if self.get_argument("session_id", None):

tests/services/sessions/test_api.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from ...utils import expected_http_error
2222

23-
TEST_TIMEOUT = 60
23+
TEST_TIMEOUT = 10
2424

2525

2626
@pytest.fixture(autouse=True)
@@ -556,9 +556,13 @@ async def test_restart_kernel(session_client, jp_base_url, jp_fetch, jp_ws_fetch
556556
model = json.loads(r.body.decode())
557557
assert model["connections"] == 0
558558

559-
# Open a websocket connection.
560-
await jp_ws_fetch("api", "kernels", kid, "channels")
559+
# Add a delay to give the kernel enough time to restart.
560+
# time.sleep(2)
561561

562+
# Open a websocket connection.
563+
ws2 = await jp_ws_fetch("api", "kernels", kid, "channels")
564+
# Close/open websocket
565+
ws2.close()
562566
r = await jp_fetch("api", "kernels", kid, method="GET")
563567
model = json.loads(r.body.decode())
564568
assert model["connections"] == 1

0 commit comments

Comments
 (0)