|
1 |
| -import asyncio |
2 | 1 | from contextlib import (
|
3 | 2 | asynccontextmanager, AbstractAsyncContextManager,
|
4 | 3 | AsyncExitStack, nullcontext, aclosing, contextmanager)
|
|
8 | 7 |
|
9 | 8 | from test.test_contextlib import TestBaseExitStack
|
10 | 9 |
|
11 |
| -support.requires_working_socket(module=True) |
12 | 10 |
|
13 |
| -def _async_test(func): |
| 11 | +def _run_async_fn(async_fn, /, *args, **kwargs): |
| 12 | + coro = async_fn(*args, **kwargs) |
| 13 | + gen = type(coro).__await__(coro) |
| 14 | + try: |
| 15 | + gen.send(None) |
| 16 | + except StopIteration as e: |
| 17 | + return e.value |
| 18 | + else: |
| 19 | + raise AssertionError("coroutine did not stop") |
| 20 | + finally: |
| 21 | + gen.close() |
| 22 | + |
| 23 | + |
| 24 | +def _async_test(async_fn): |
14 | 25 | """Decorator to turn an async function into a test case."""
|
15 |
| - @functools.wraps(func) |
| 26 | + @functools.wraps(async_fn) |
16 | 27 | def wrapper(*args, **kwargs):
|
17 |
| - coro = func(*args, **kwargs) |
18 |
| - asyncio.run(coro) |
19 |
| - return wrapper |
| 28 | + return _run_async_fn(async_fn, *args, **kwargs) |
20 | 29 |
|
21 |
| -def tearDownModule(): |
22 |
| - asyncio.set_event_loop_policy(None) |
| 30 | + return wrapper |
23 | 31 |
|
24 | 32 |
|
25 | 33 | class TestAbstractAsyncContextManager(unittest.TestCase):
|
26 | 34 |
|
| 35 | + def test_async_test_self_test(self): |
| 36 | + class _async_yield: |
| 37 | + def __init__(self, v): |
| 38 | + self.v = v |
| 39 | + |
| 40 | + def __await__(self): |
| 41 | + return (yield self.v) |
| 42 | + |
| 43 | + @_async_test |
| 44 | + async def do_not_stop_coro(): |
| 45 | + while True: |
| 46 | + await _async_yield(None) |
| 47 | + |
| 48 | + with self.assertRaisesRegex(AssertionError, "coroutine did not stop"): |
| 49 | + do_not_stop_coro() |
| 50 | + |
27 | 51 | @_async_test
|
28 | 52 | async def test_enter(self):
|
29 | 53 | class DefaultEnter(AbstractAsyncContextManager):
|
@@ -455,49 +479,24 @@ async def agenfunc():
|
455 | 479 |
|
456 | 480 | class TestAsyncExitStack(TestBaseExitStack, unittest.TestCase):
|
457 | 481 | class SyncAsyncExitStack(AsyncExitStack):
|
458 |
| - @staticmethod |
459 |
| - def run_coroutine(coro): |
460 |
| - loop = asyncio.get_event_loop_policy().get_event_loop() |
461 |
| - t = loop.create_task(coro) |
462 |
| - t.add_done_callback(lambda f: loop.stop()) |
463 |
| - loop.run_forever() |
464 |
| - |
465 |
| - exc = t.exception() |
466 |
| - if not exc: |
467 |
| - return t.result() |
468 |
| - else: |
469 |
| - context = exc.__context__ |
470 |
| - |
471 |
| - try: |
472 |
| - raise exc |
473 |
| - except: |
474 |
| - exc.__context__ = context |
475 |
| - raise exc |
476 | 482 |
|
477 | 483 | def close(self):
|
478 |
| - return self.run_coroutine(self.aclose()) |
| 484 | + return _run_async_fn(self.aclose) |
479 | 485 |
|
480 | 486 | def __enter__(self):
|
481 |
| - return self.run_coroutine(self.__aenter__()) |
| 487 | + return _run_async_fn(self.__aenter__) |
482 | 488 |
|
483 | 489 | def __exit__(self, *exc_details):
|
484 |
| - return self.run_coroutine(self.__aexit__(*exc_details)) |
| 490 | + return _run_async_fn(self.__aexit__, *exc_details) |
485 | 491 |
|
486 | 492 | exit_stack = SyncAsyncExitStack
|
487 | 493 | callback_error_internal_frames = [
|
488 |
| - ('__exit__', 'return self.run_coroutine(self.__aexit__(*exc_details))'), |
489 |
| - ('run_coroutine', 'raise exc'), |
490 |
| - ('run_coroutine', 'raise exc'), |
| 494 | + ('__exit__', 'return _run_async_fn(self.__aexit__, *exc_details)'), |
| 495 | + ('_run_async_fn', 'gen.send(None)'), |
491 | 496 | ('__aexit__', 'raise exc_details[1]'),
|
492 | 497 | ('__aexit__', 'cb_suppress = cb(*exc_details)'),
|
493 | 498 | ]
|
494 | 499 |
|
495 |
| - def setUp(self): |
496 |
| - self.loop = asyncio.new_event_loop() |
497 |
| - asyncio.set_event_loop(self.loop) |
498 |
| - self.addCleanup(self.loop.close) |
499 |
| - self.addCleanup(asyncio.set_event_loop_policy, None) |
500 |
| - |
501 | 500 | @_async_test
|
502 | 501 | async def test_async_callback(self):
|
503 | 502 | expected = [
|
|
0 commit comments