Skip to content

Commit b2669c5

Browse files
Make ChannelQueue.get_msg true async (#892)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent dee8c7a commit b2669c5

File tree

2 files changed

+55
-4
lines changed

2 files changed

+55
-4
lines changed

jupyter_server/gateway/managers.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
# Copyright (c) Jupyter Development Team.
22
# Distributed under the terms of the Modified BSD License.
3+
import asyncio
34
import datetime
45
import json
56
import os
67
from logging import Logger
7-
from queue import Queue
8+
from queue import Empty, Queue
89
from threading import Thread
10+
from time import monotonic
911
from typing import Any, Dict, Optional
1012

1113
import websocket
@@ -503,9 +505,24 @@ def __init__(self, channel_name: str, channel_socket: websocket.WebSocket, log:
503505
self.channel_socket = channel_socket
504506
self.log = log
505507

508+
async def _async_get(self, timeout=None):
509+
if timeout is None:
510+
timeout = float("inf")
511+
elif timeout < 0:
512+
raise ValueError("'timeout' must be a non-negative number")
513+
end_time = monotonic() + timeout
514+
515+
while True:
516+
try:
517+
return self.get(block=False)
518+
except Empty:
519+
if monotonic() > end_time:
520+
raise
521+
await asyncio.sleep(0)
522+
506523
async def get_msg(self, *args: Any, **kwargs: Any) -> dict:
507524
timeout = kwargs.get("timeout", 1)
508-
msg = self.get(timeout=timeout)
525+
msg = await self._async_get(timeout=timeout)
509526
self.log.debug(
510527
"Received message on channel: {}, msg_id: {}, msg_type: {}".format(
511528
self.channel_name, msg["msg_id"], msg["msg_type"] if msg else "null"

tests/test_gateway.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
"""Test GatewayClient"""
2+
import asyncio
23
import json
4+
import logging
35
import os
46
import uuid
57
from datetime import datetime
68
from io import BytesIO
7-
from unittest.mock import patch
9+
from queue import Empty
10+
from unittest.mock import MagicMock, patch
811

912
import pytest
1013
import tornado
1114
from tornado.httpclient import HTTPRequest, HTTPResponse
1215
from tornado.web import HTTPError
1316

14-
from jupyter_server.gateway.managers import GatewayClient
17+
from jupyter_server.gateway.managers import ChannelQueue, GatewayClient
1518
from jupyter_server.utils import ensure_async
1619

1720
from .utils import expected_http_error
@@ -318,6 +321,37 @@ async def test_gateway_shutdown(init_gateway, jp_serverapp, jp_fetch, missing_ke
318321
assert await is_kernel_running(jp_fetch, k2) is False
319322

320323

324+
async def test_channel_queue_get_msg_with_invalid_timeout():
325+
queue = ChannelQueue("iopub", MagicMock(), logging.getLogger())
326+
327+
with pytest.raises(ValueError):
328+
await queue.get_msg(timeout=-1)
329+
330+
331+
async def test_channel_queue_get_msg_raises_empty_after_timeout():
332+
queue = ChannelQueue("iopub", MagicMock(), logging.getLogger())
333+
334+
with pytest.raises(Empty):
335+
await asyncio.wait_for(queue.get_msg(timeout=0.1), 2)
336+
337+
338+
async def test_channel_queue_get_msg_without_timeout():
339+
queue = ChannelQueue("iopub", MagicMock(), logging.getLogger())
340+
341+
with pytest.raises(asyncio.TimeoutError):
342+
await asyncio.wait_for(queue.get_msg(timeout=None), 1)
343+
344+
345+
async def test_channel_queue_get_msg_with_existing_item():
346+
sent_message = {"msg_id": 1, "msg_type": 2}
347+
queue = ChannelQueue("iopub", MagicMock(), logging.getLogger())
348+
queue.put_nowait(sent_message)
349+
350+
received_message = await asyncio.wait_for(queue.get_msg(timeout=None), 1)
351+
352+
assert received_message == sent_message
353+
354+
321355
#
322356
# Test methods below...
323357
#

0 commit comments

Comments
 (0)