Skip to content

Commit a6d2d35

Browse files
authored
Improve the busy/idle execution state tracking for kernels. (#1429)
1 parent 8fe5b58 commit a6d2d35

File tree

3 files changed

+281
-8
lines changed

3 files changed

+281
-8
lines changed

jupyter_server/services/kernels/kernelmanager.py

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -232,11 +232,7 @@ async def _async_start_kernel( # type:ignore[override]
232232
kwargs["kernel_id"] = kernel_id
233233
kernel_id = await self.pinned_superclass._async_start_kernel(self, **kwargs)
234234
self._kernel_connections[kernel_id] = 0
235-
task = asyncio.create_task(self._finish_kernel_start(kernel_id))
236-
if not getattr(self, "use_pending_kernels", None):
237-
await task
238-
else:
239-
self._pending_kernel_tasks[kernel_id] = task
235+
240236
# add busy/activity markers:
241237
kernel = self.get_kernel(kernel_id)
242238
kernel.execution_state = "starting" # type:ignore[attr-defined]
@@ -250,6 +246,12 @@ async def _async_start_kernel( # type:ignore[override]
250246
if env and isinstance(env, dict): # type:ignore[unreachable]
251247
self.log.debug("Kernel argument 'env' passed with: %r", list(env.keys())) # type:ignore[unreachable]
252248

249+
task = asyncio.create_task(self._finish_kernel_start(kernel_id))
250+
if not getattr(self, "use_pending_kernels", None):
251+
await task
252+
else:
253+
self._pending_kernel_tasks[kernel_id] = task
254+
253255
# Increase the metric of number of kernels running
254256
# for the relevant kernel type by 1
255257
KERNEL_CURRENTLY_RUNNING_TOTAL.labels(type=self._kernels[kernel_id].kernel_name).inc()
@@ -537,6 +539,40 @@ def _check_kernel_id(self, kernel_id):
537539
raise web.HTTPError(404, "Kernel does not exist: %s" % kernel_id)
538540

539541
# monitoring activity:
542+
untracked_message_types = List(
543+
trait=Unicode(),
544+
config=True,
545+
default_value=[
546+
"comm_info_request",
547+
"comm_info_reply",
548+
"kernel_info_request",
549+
"kernel_info_reply",
550+
"shutdown_request",
551+
"shutdown_reply",
552+
"interrupt_request",
553+
"interrupt_reply",
554+
"debug_request",
555+
"debug_reply",
556+
"stream",
557+
"display_data",
558+
"update_display_data",
559+
"execute_input",
560+
"execute_result",
561+
"error",
562+
"status",
563+
"clear_output",
564+
"debug_event",
565+
"input_request",
566+
"input_reply",
567+
],
568+
help="""List of kernel message types excluded from user activity tracking.
569+
570+
This should be a superset of the message types sent on any channel other
571+
than the shell channel.""",
572+
)
573+
574+
def track_message_type(self, message_type):
575+
return message_type not in self.untracked_message_types
540576

541577
def start_watching_activity(self, kernel_id):
542578
"""Start watching IOPub messages on a kernel for activity.
@@ -557,15 +593,27 @@ def start_watching_activity(self, kernel_id):
557593

558594
def record_activity(msg_list):
559595
"""Record an IOPub message arriving from a kernel"""
560-
self.last_kernel_activity = kernel.last_activity = utcnow()
561-
562596
idents, fed_msg_list = session.feed_identities(msg_list)
563597
msg = session.deserialize(fed_msg_list, content=False)
564598

565599
msg_type = msg["header"]["msg_type"]
600+
parent_msg_type = msg.get("parent_header", {}).get("msg_type", None)
601+
if (
602+
self.track_message_type(msg_type)
603+
or self.track_message_type(parent_msg_type)
604+
or kernel.execution_state == "busy"
605+
):
606+
self.last_kernel_activity = kernel.last_activity = utcnow()
566607
if msg_type == "status":
567608
msg = session.deserialize(fed_msg_list)
568-
kernel.execution_state = msg["content"]["execution_state"]
609+
execution_state = msg["content"]["execution_state"]
610+
if self.track_message_type(parent_msg_type):
611+
kernel.execution_state = execution_state
612+
elif kernel.execution_state == "starting" and execution_state != "starting":
613+
# We always normalize post-starting execution state to "idle"
614+
# unless we know that the status is in response to one of our
615+
# tracked message types.
616+
kernel.execution_state = "idle"
569617
self.log.debug(
570618
"activity on %s: %s (%s)",
571619
kernel_id,

tests/services/kernels/test_cull.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import asyncio
2+
import datetime
23
import json
34
import os
45
import platform
6+
import uuid
57
import warnings
68

79
import jupyter_client
@@ -94,6 +96,83 @@ async def test_cull_idle(jp_fetch, jp_ws_fetch):
9496
assert culled
9597

9698

99+
@pytest.mark.parametrize(
100+
"jp_server_config",
101+
[
102+
# Test the synchronous case
103+
Config(
104+
{
105+
"ServerApp": {
106+
"kernel_manager_class": "jupyter_server.services.kernels.kernelmanager.MappingKernelManager",
107+
"MappingKernelManager": {
108+
"cull_idle_timeout": CULL_TIMEOUT,
109+
"cull_interval": CULL_INTERVAL,
110+
"cull_connected": True,
111+
},
112+
}
113+
}
114+
),
115+
# Test the async case
116+
Config(
117+
{
118+
"ServerApp": {
119+
"kernel_manager_class": "jupyter_server.services.kernels.kernelmanager.AsyncMappingKernelManager",
120+
"AsyncMappingKernelManager": {
121+
"cull_idle_timeout": CULL_TIMEOUT,
122+
"cull_interval": CULL_INTERVAL,
123+
"cull_connected": True,
124+
},
125+
}
126+
}
127+
),
128+
],
129+
)
130+
async def test_cull_connected(jp_fetch, jp_ws_fetch):
131+
r = await jp_fetch("api", "kernels", method="POST", allow_nonstandard_methods=True)
132+
kernel = json.loads(r.body.decode())
133+
kid = kernel["id"]
134+
135+
# Open a websocket connection.
136+
ws = await jp_ws_fetch("api", "kernels", kid, "channels")
137+
session_id = uuid.uuid1().hex
138+
message_id = uuid.uuid1().hex
139+
await ws.write_message(
140+
json.dumps(
141+
{
142+
"channel": "shell",
143+
"header": {
144+
"date": datetime.datetime.now(tz=datetime.timezone.utc).isoformat(),
145+
"session": session_id,
146+
"msg_id": message_id,
147+
"msg_type": "execute_request",
148+
"username": "",
149+
"version": "5.2",
150+
},
151+
"parent_header": {},
152+
"metadata": {},
153+
"content": {
154+
"code": f"import time\ntime.sleep({CULL_TIMEOUT-1})",
155+
"silent": False,
156+
"allow_stdin": False,
157+
"stop_on_error": True,
158+
},
159+
"buffers": [],
160+
}
161+
)
162+
)
163+
164+
r = await jp_fetch("api", "kernels", kid, method="GET")
165+
model = json.loads(r.body.decode())
166+
assert model["connections"] == 1
167+
culled = await get_cull_status(
168+
kid, jp_fetch
169+
) # connected, but code cell still running. Should not be culled
170+
assert not culled
171+
culled = await get_cull_status(kid, jp_fetch) # still connected, but idle... should be culled
172+
assert culled
173+
ws.close()
174+
175+
97176
async def test_cull_idle_disable(jp_fetch, jp_ws_fetch, jp_kernelspec_with_metadata):
98177
r = await jp_fetch("api", "kernels", method="POST", allow_nonstandard_methods=True)
99178
kernel = json.loads(r.body.decode())
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
import asyncio
2+
import datetime
3+
import json
4+
import os
5+
import platform
6+
import time
7+
import uuid
8+
import warnings
9+
10+
import jupyter_client
11+
import pytest
12+
from flaky import flaky
13+
from tornado.httpclient import HTTPClientError
14+
from traitlets.config import Config
15+
16+
MAX_POLL_ATTEMPTS = 10
17+
POLL_INTERVAL = 1
18+
MINIMUM_CONSISTENT_COUNT = 4
19+
20+
21+
@flaky
22+
async def test_execution_state(jp_fetch, jp_ws_fetch):
23+
r = await jp_fetch("api", "kernels", method="POST", allow_nonstandard_methods=True)
24+
kernel = json.loads(r.body.decode())
25+
kid = kernel["id"]
26+
27+
# Open a websocket connection.
28+
ws = await jp_ws_fetch("api", "kernels", kid, "channels")
29+
session_id = uuid.uuid1().hex
30+
message_id = uuid.uuid1().hex
31+
await ws.write_message(
32+
json.dumps(
33+
{
34+
"channel": "shell",
35+
"header": {
36+
"date": datetime.datetime.now(tz=datetime.timezone.utc).isoformat(),
37+
"session": session_id,
38+
"msg_id": message_id,
39+
"msg_type": "execute_request",
40+
"username": "",
41+
"version": "5.2",
42+
},
43+
"parent_header": {},
44+
"metadata": {},
45+
"content": {
46+
"code": "while True:\n\tpass",
47+
"silent": False,
48+
"allow_stdin": False,
49+
"stop_on_error": True,
50+
},
51+
"buffers": [],
52+
}
53+
)
54+
)
55+
await poll_for_parent_message_status(kid, message_id, "busy", ws)
56+
es = await get_execution_state(kid, jp_fetch)
57+
assert es == "busy"
58+
59+
message_id_2 = uuid.uuid1().hex
60+
await ws.write_message(
61+
json.dumps(
62+
{
63+
"channel": "control",
64+
"header": {
65+
"date": datetime.datetime.now(tz=datetime.timezone.utc).isoformat(),
66+
"session": session_id,
67+
"msg_id": message_id_2,
68+
"msg_type": "debug_request",
69+
"username": "",
70+
"version": "5.2",
71+
},
72+
"parent_header": {},
73+
"metadata": {},
74+
"content": {
75+
"type": "request",
76+
"command": "debugInfo",
77+
},
78+
"buffers": [],
79+
}
80+
)
81+
)
82+
await poll_for_parent_message_status(kid, message_id_2, "idle", ws)
83+
es = await get_execution_state(kid, jp_fetch)
84+
85+
# Verify that the overall kernel status is still "busy" even though one
86+
# "idle" response was already seen for the second execute request.
87+
assert es == "busy"
88+
89+
await jp_fetch(
90+
"api",
91+
"kernels",
92+
kid,
93+
"interrupt",
94+
method="POST",
95+
allow_nonstandard_methods=True,
96+
)
97+
98+
await poll_for_parent_message_status(kid, message_id, "idle", ws)
99+
es = await get_execution_state(kid, jp_fetch)
100+
assert es == "idle"
101+
ws.close()
102+
103+
104+
async def get_execution_state(kid, jp_fetch):
105+
# There is an inherent race condition when getting the kernel execution status
106+
# where we might fetch the status right before an expected state change occurs.
107+
#
108+
# To work-around this, we don't return the status until we've been able to fetch
109+
# it twice in a row and get the same result both times.
110+
last_execution_states = []
111+
112+
for _ in range(MAX_POLL_ATTEMPTS):
113+
r = await jp_fetch("api", "kernels", kid, method="GET")
114+
model = json.loads(r.body.decode())
115+
execution_state = model["execution_state"]
116+
last_execution_states.append(execution_state)
117+
consistent_count = 0
118+
last_execution_state = None
119+
for es in last_execution_states:
120+
if es != last_execution_state:
121+
consistent_count = 0
122+
last_execution_state = es
123+
consistent_count += 1
124+
if consistent_count >= MINIMUM_CONSISTENT_COUNT:
125+
return es
126+
time.sleep(POLL_INTERVAL)
127+
128+
raise AssertionError("failed to get a consistent execution state")
129+
130+
131+
async def poll_for_parent_message_status(kid, parent_message_id, target_status, ws):
132+
while True:
133+
resp = await ws.read_message()
134+
resp_json = json.loads(resp)
135+
print(resp_json)
136+
parent_message = resp_json.get("parent_header", {}).get("msg_id", None)
137+
if parent_message != parent_message_id:
138+
continue
139+
140+
response_type = resp_json.get("header", {}).get("msg_type", None)
141+
if response_type != "status":
142+
continue
143+
144+
execution_state = resp_json.get("content", {}).get("execution_state", "")
145+
if execution_state == target_status:
146+
return

0 commit comments

Comments
 (0)