1
- import asyncio
1
+ import functools
2
2
from contextlib import (
3
3
asynccontextmanager , AbstractAsyncContextManager ,
4
4
AsyncExitStack , nullcontext , aclosing , contextmanager )
8
8
9
9
from test .test_contextlib import TestBaseExitStack
10
10
11
- support .requires_working_socket (module = True )
12
11
13
- def tearDownModule ():
14
- asyncio ._set_event_loop_policy (None )
12
+ def _run_async_fn (async_fn , / , * args , ** kwargs ):
13
+ coro = async_fn (* args , ** kwargs )
14
+ gen = type (coro ).__await__ (coro )
15
+ try :
16
+ gen .send (None )
17
+ except StopIteration as e :
18
+ return e .value
19
+ else :
20
+ raise AssertionError ("coroutine did not stop" )
21
+ finally :
22
+ gen .close ()
15
23
16
24
17
- class TestAbstractAsyncContextManager (unittest .IsolatedAsyncioTestCase ):
25
+ def _async_test (async_fn ):
26
+ """Decorator to turn an async function into a test case."""
27
+ @functools .wraps (async_fn )
28
+ def wrapper (* args , ** kwargs ):
29
+ return _run_async_fn (async_fn , * args , ** kwargs )
18
30
31
+ return wrapper
32
+
33
+
34
+ class TestAbstractAsyncContextManager (unittest .TestCase ):
35
+
36
+ @_async_test
19
37
async def test_enter (self ):
20
38
class DefaultEnter (AbstractAsyncContextManager ):
21
39
async def __aexit__ (self , * args ):
@@ -27,6 +45,7 @@ async def __aexit__(self, *args):
27
45
async with manager as context :
28
46
self .assertIs (manager , context )
29
47
48
+ @_async_test
30
49
async def test_slots (self ):
31
50
class DefaultAsyncContextManager (AbstractAsyncContextManager ):
32
51
__slots__ = ()
@@ -38,6 +57,7 @@ async def __aexit__(self, *args):
38
57
manager = DefaultAsyncContextManager ()
39
58
manager .var = 42
40
59
60
+ @_async_test
41
61
async def test_async_gen_propagates_generator_exit (self ):
42
62
# A regression test for https://bugs.python.org/issue33786.
43
63
@@ -88,8 +108,9 @@ class NoneAexit(ManagerFromScratch):
88
108
self .assertFalse (issubclass (NoneAexit , AbstractAsyncContextManager ))
89
109
90
110
91
- class AsyncContextManagerTestCase (unittest .IsolatedAsyncioTestCase ):
111
+ class AsyncContextManagerTestCase (unittest .TestCase ):
92
112
113
+ @_async_test
93
114
async def test_contextmanager_plain (self ):
94
115
state = []
95
116
@asynccontextmanager
@@ -103,6 +124,7 @@ async def woohoo():
103
124
state .append (x )
104
125
self .assertEqual (state , [1 , 42 , 999 ])
105
126
127
+ @_async_test
106
128
async def test_contextmanager_finally (self ):
107
129
state = []
108
130
@asynccontextmanager
@@ -120,6 +142,7 @@ async def woohoo():
120
142
raise ZeroDivisionError ()
121
143
self .assertEqual (state , [1 , 42 , 999 ])
122
144
145
+ @_async_test
123
146
async def test_contextmanager_traceback (self ):
124
147
@asynccontextmanager
125
148
async def f ():
@@ -175,6 +198,7 @@ class StopAsyncIterationSubclass(StopAsyncIteration):
175
198
self .assertEqual (frames [0 ].name , 'test_contextmanager_traceback' )
176
199
self .assertEqual (frames [0 ].line , 'raise stop_exc' )
177
200
201
+ @_async_test
178
202
async def test_contextmanager_no_reraise (self ):
179
203
@asynccontextmanager
180
204
async def whee ():
@@ -184,6 +208,7 @@ async def whee():
184
208
# Calling __aexit__ should not result in an exception
185
209
self .assertFalse (await ctx .__aexit__ (TypeError , TypeError ("foo" ), None ))
186
210
211
+ @_async_test
187
212
async def test_contextmanager_trap_yield_after_throw (self ):
188
213
@asynccontextmanager
189
214
async def whoo ():
@@ -199,6 +224,7 @@ async def whoo():
199
224
# The "gen" attribute is an implementation detail.
200
225
self .assertFalse (ctx .gen .ag_suspended )
201
226
227
+ @_async_test
202
228
async def test_contextmanager_trap_no_yield (self ):
203
229
@asynccontextmanager
204
230
async def whoo ():
@@ -208,6 +234,7 @@ async def whoo():
208
234
with self .assertRaises (RuntimeError ):
209
235
await ctx .__aenter__ ()
210
236
237
+ @_async_test
211
238
async def test_contextmanager_trap_second_yield (self ):
212
239
@asynccontextmanager
213
240
async def whoo ():
@@ -221,6 +248,7 @@ async def whoo():
221
248
# The "gen" attribute is an implementation detail.
222
249
self .assertFalse (ctx .gen .ag_suspended )
223
250
251
+ @_async_test
224
252
async def test_contextmanager_non_normalised (self ):
225
253
@asynccontextmanager
226
254
async def whoo ():
@@ -234,6 +262,7 @@ async def whoo():
234
262
with self .assertRaises (SyntaxError ):
235
263
await ctx .__aexit__ (RuntimeError , None , None )
236
264
265
+ @_async_test
237
266
async def test_contextmanager_except (self ):
238
267
state = []
239
268
@asynccontextmanager
@@ -251,6 +280,7 @@ async def woohoo():
251
280
raise ZeroDivisionError (999 )
252
281
self .assertEqual (state , [1 , 42 , 999 ])
253
282
283
+ @_async_test
254
284
async def test_contextmanager_except_stopiter (self ):
255
285
@asynccontextmanager
256
286
async def woohoo ():
@@ -277,6 +307,7 @@ class StopAsyncIterationSubclass(StopAsyncIteration):
277
307
else :
278
308
self .fail (f'{ stop_exc } was suppressed' )
279
309
310
+ @_async_test
280
311
async def test_contextmanager_wrap_runtimeerror (self ):
281
312
@asynccontextmanager
282
313
async def woohoo ():
@@ -321,12 +352,14 @@ def test_contextmanager_doc_attrib(self):
321
352
self .assertEqual (baz .__doc__ , "Whee!" )
322
353
323
354
@support .requires_docstrings
355
+ @_async_test
324
356
async def test_instance_docstring_given_cm_docstring (self ):
325
357
baz = self ._create_contextmanager_attribs ()(None )
326
358
self .assertEqual (baz .__doc__ , "Whee!" )
327
359
async with baz :
328
360
pass # suppress warning
329
361
362
+ @_async_test
330
363
async def test_keywords (self ):
331
364
# Ensure no keyword arguments are inhibited
332
365
@asynccontextmanager
@@ -335,6 +368,7 @@ async def woohoo(self, func, args, kwds):
335
368
async with woohoo (self = 11 , func = 22 , args = 33 , kwds = 44 ) as target :
336
369
self .assertEqual (target , (11 , 22 , 33 , 44 ))
337
370
371
+ @_async_test
338
372
async def test_recursive (self ):
339
373
depth = 0
340
374
ncols = 0
@@ -361,6 +395,7 @@ async def recursive():
361
395
self .assertEqual (ncols , 10 )
362
396
self .assertEqual (depth , 0 )
363
397
398
+ @_async_test
364
399
async def test_decorator (self ):
365
400
entered = False
366
401
@@ -379,6 +414,7 @@ async def test():
379
414
await test ()
380
415
self .assertFalse (entered )
381
416
417
+ @_async_test
382
418
async def test_decorator_with_exception (self ):
383
419
entered = False
384
420
@@ -401,6 +437,7 @@ async def test():
401
437
await test ()
402
438
self .assertFalse (entered )
403
439
440
+ @_async_test
404
441
async def test_decorating_method (self ):
405
442
406
443
@asynccontextmanager
@@ -435,14 +472,15 @@ async def method(self, a, b, c=None):
435
472
self .assertEqual (test .b , 2 )
436
473
437
474
438
- class AclosingTestCase (unittest .IsolatedAsyncioTestCase ):
475
+ class AclosingTestCase (unittest .TestCase ):
439
476
440
477
@support .requires_docstrings
441
478
def test_instance_docs (self ):
442
479
cm_docstring = aclosing .__doc__
443
480
obj = aclosing (None )
444
481
self .assertEqual (obj .__doc__ , cm_docstring )
445
482
483
+ @_async_test
446
484
async def test_aclosing (self ):
447
485
state = []
448
486
class C :
@@ -454,6 +492,7 @@ async def aclose(self):
454
492
self .assertEqual (x , y )
455
493
self .assertEqual (state , [1 ])
456
494
495
+ @_async_test
457
496
async def test_aclosing_error (self ):
458
497
state = []
459
498
class C :
@@ -467,6 +506,7 @@ async def aclose(self):
467
506
1 / 0
468
507
self .assertEqual (state , [1 ])
469
508
509
+ @_async_test
470
510
async def test_aclosing_bpo41229 (self ):
471
511
state = []
472
512
@@ -492,45 +532,27 @@ async def agenfunc():
492
532
self .assertEqual (state , [1 ])
493
533
494
534
495
- class TestAsyncExitStack (TestBaseExitStack , unittest .IsolatedAsyncioTestCase ):
535
+ class TestAsyncExitStack (TestBaseExitStack , unittest .TestCase ):
496
536
class SyncAsyncExitStack (AsyncExitStack ):
497
- @staticmethod
498
- def run_coroutine (coro ):
499
- loop = asyncio .new_event_loop ()
500
- t = loop .create_task (coro )
501
- t .add_done_callback (lambda f : loop .stop ())
502
- loop .run_forever ()
503
-
504
- exc = t .exception ()
505
- if not exc :
506
- return t .result ()
507
- else :
508
- context = exc .__context__
509
-
510
- try :
511
- raise exc
512
- except :
513
- exc .__context__ = context
514
- raise exc
515
537
516
538
def close (self ):
517
- return self . run_coroutine (self .aclose () )
539
+ return _run_async_fn (self .aclose )
518
540
519
541
def __enter__ (self ):
520
- return self . run_coroutine (self .__aenter__ () )
542
+ return _run_async_fn (self .__aenter__ )
521
543
522
544
def __exit__ (self , * exc_details ):
523
- return self . run_coroutine (self .__aexit__ ( * exc_details ) )
545
+ return _run_async_fn (self .__aexit__ , * exc_details )
524
546
525
547
exit_stack = SyncAsyncExitStack
526
548
callback_error_internal_frames = [
527
- ('__exit__' , 'return self.run_coroutine(self.__aexit__(*exc_details))' ),
528
- ('run_coroutine' , 'raise exc' ),
529
- ('run_coroutine' , 'raise exc' ),
549
+ ('__exit__' , 'return _run_async_fn(self.__aexit__, *exc_details)' ),
550
+ ('_run_async_fn' , 'gen.send(None)' ),
530
551
('__aexit__' , 'raise exc' ),
531
552
('__aexit__' , 'cb_suppress = cb(*exc_details)' ),
532
553
]
533
554
555
+ @_async_test
534
556
async def test_async_callback (self ):
535
557
expected = [
536
558
((), {}),
@@ -573,6 +595,7 @@ async def _exit(*args, **kwds):
573
595
stack .push_async_callback (callback = _exit , arg = 3 )
574
596
self .assertEqual (result , [])
575
597
598
+ @_async_test
576
599
async def test_async_push (self ):
577
600
exc_raised = ZeroDivisionError
578
601
async def _expect_exc (exc_type , exc , exc_tb ):
@@ -608,6 +631,7 @@ async def __aexit__(self, *exc_details):
608
631
self .assertIs (stack ._exit_callbacks [- 1 ][1 ], _expect_exc )
609
632
1 / 0
610
633
634
+ @_async_test
611
635
async def test_enter_async_context (self ):
612
636
class TestCM (object ):
613
637
async def __aenter__ (self ):
@@ -629,6 +653,7 @@ async def _exit():
629
653
630
654
self .assertEqual (result , [1 , 2 , 3 , 4 ])
631
655
656
+ @_async_test
632
657
async def test_enter_async_context_errors (self ):
633
658
class LacksEnterAndExit :
634
659
pass
@@ -648,6 +673,7 @@ async def __aenter__(self):
648
673
await stack .enter_async_context (LacksExit ())
649
674
self .assertFalse (stack ._exit_callbacks )
650
675
676
+ @_async_test
651
677
async def test_async_exit_exception_chaining (self ):
652
678
# Ensure exception chaining matches the reference behaviour
653
679
async def raise_exc (exc ):
@@ -679,6 +705,7 @@ async def suppress_exc(*exc_details):
679
705
self .assertIsInstance (inner_exc , ValueError )
680
706
self .assertIsInstance (inner_exc .__context__ , ZeroDivisionError )
681
707
708
+ @_async_test
682
709
async def test_async_exit_exception_explicit_none_context (self ):
683
710
# Ensure AsyncExitStack chaining matches actual nested `with` statements
684
711
# regarding explicit __context__ = None.
@@ -713,6 +740,7 @@ async def my_cm_with_exit_stack():
713
740
else :
714
741
self .fail ("Expected IndexError, but no exception was raised" )
715
742
743
+ @_async_test
716
744
async def test_instance_bypass_async (self ):
717
745
class Example (object ): pass
718
746
cm = Example ()
@@ -725,7 +753,8 @@ class Example(object): pass
725
753
self .assertIs (stack ._exit_callbacks [- 1 ][1 ], cm )
726
754
727
755
728
- class TestAsyncNullcontext (unittest .IsolatedAsyncioTestCase ):
756
+ class TestAsyncNullcontext (unittest .TestCase ):
757
+ @_async_test
729
758
async def test_async_nullcontext (self ):
730
759
class C :
731
760
pass
0 commit comments