1
1
from __future__ import annotations
2
2
3
+ import sys
3
4
from typing import TYPE_CHECKING , Union
4
5
5
6
import pytest
6
7
7
8
import trio
8
- from trio import EndOfChannel , background_with_channel , open_memory_channel
9
+ from trio import EndOfChannel , as_safe_channel , open_memory_channel
9
10
10
11
from ..testing import Matcher , RaisesGroup , assert_checkpoints , wait_all_tasks_blocked
11
12
13
+ if sys .version_info < (3 , 11 ):
14
+ from exceptiongroup import ExceptionGroup
15
+
12
16
if TYPE_CHECKING :
13
17
from collections .abc import AsyncGenerator
14
18
@@ -416,8 +420,8 @@ async def do_send(s: trio.MemorySendChannel[int], v: int) -> None:
416
420
r .receive_nowait ()
417
421
418
422
419
- async def test_background_with_channel_exhaust () -> None :
420
- @background_with_channel
423
+ async def test_as_safe_channel_exhaust () -> None :
424
+ @as_safe_channel
421
425
async def agen () -> AsyncGenerator [int ]:
422
426
yield 1
423
427
@@ -426,8 +430,8 @@ async def agen() -> AsyncGenerator[int]:
426
430
assert x == 1
427
431
428
432
429
- async def test_background_with_channel_broken_resource () -> None :
430
- @background_with_channel
433
+ async def test_as_safe_channel_broken_resource () -> None :
434
+ @as_safe_channel
431
435
async def agen () -> AsyncGenerator [int ]:
432
436
yield 1
433
437
yield 2
@@ -445,10 +449,10 @@ async def agen() -> AsyncGenerator[int]:
445
449
# but we don't get an error on exit of the cm
446
450
447
451
448
- async def test_background_with_channel_cancelled () -> None :
452
+ async def test_as_safe_channel_cancelled () -> None :
449
453
with trio .CancelScope () as cs :
450
454
451
- @background_with_channel
455
+ @as_safe_channel
452
456
async def agen () -> AsyncGenerator [None ]: # pragma: no cover
453
457
raise AssertionError (
454
458
"cancel before consumption means generator should not be iterated"
@@ -459,12 +463,12 @@ async def agen() -> AsyncGenerator[None]: # pragma: no cover
459
463
cs .cancel ()
460
464
461
465
462
- async def test_background_with_channel_recv_closed (
466
+ async def test_as_safe_channel_recv_closed (
463
467
autojump_clock : trio .testing .MockClock ,
464
468
) -> None :
465
469
event = trio .Event ()
466
470
467
- @background_with_channel
471
+ @as_safe_channel
468
472
async def agen () -> AsyncGenerator [int ]:
469
473
await event .wait ()
470
474
yield 1
@@ -476,10 +480,10 @@ async def agen() -> AsyncGenerator[int]:
476
480
await trio .sleep (1 )
477
481
478
482
479
- async def test_background_with_channel_no_race () -> None :
483
+ async def test_as_safe_channel_no_race () -> None :
480
484
# this previously led to a race condition due to
481
485
# https://github.com/python-trio/trio/issues/1559
482
- @background_with_channel
486
+ @as_safe_channel
483
487
async def agen () -> AsyncGenerator [int ]:
484
488
yield 1
485
489
raise ValueError ("oae" )
@@ -490,10 +494,10 @@ async def agen() -> AsyncGenerator[int]:
490
494
assert x == 1
491
495
492
496
493
- async def test_background_with_channel_buffer_size_too_small (
497
+ async def test_as_safe_channel_buffer_size_too_small (
494
498
autojump_clock : trio .testing .MockClock ,
495
499
) -> None :
496
- @background_with_channel
500
+ @as_safe_channel
497
501
async def agen () -> AsyncGenerator [int ]:
498
502
yield 1
499
503
raise AssertionError (
@@ -507,8 +511,8 @@ async def agen() -> AsyncGenerator[int]:
507
511
await trio .sleep_forever ()
508
512
509
513
510
- async def test_background_with_channel_no_interleave () -> None :
511
- @background_with_channel
514
+ async def test_as_safe_channel_no_interleave () -> None :
515
+ @as_safe_channel
512
516
async def agen () -> AsyncGenerator [int ]:
513
517
yield 1
514
518
raise AssertionError # pragma: no cover
@@ -518,10 +522,10 @@ async def agen() -> AsyncGenerator[int]:
518
522
await trio .lowlevel .checkpoint ()
519
523
520
524
521
- async def test_background_with_channel_genexit_finally () -> None :
525
+ async def test_as_safe_channel_genexit_finally () -> None :
522
526
events : list [str ] = []
523
527
524
- @background_with_channel
528
+ @as_safe_channel
525
529
async def agen (stuff : list [str ]) -> AsyncGenerator [int ]:
526
530
try :
527
531
yield 1
@@ -532,24 +536,23 @@ async def agen(stuff: list[str]) -> AsyncGenerator[int]:
532
536
stuff .append ("finally" )
533
537
raise ValueError ("agen" )
534
538
535
- with pytest .raises (
536
- RuntimeError ,
537
- match = r"^Encountered exception during cleanup of generator object, as well as exception in the contextmanager body.$" ,
538
- ) as excinfo :
539
+ with RaisesGroup (
540
+ RaisesGroup (
541
+ Matcher (ValueError , match = "^agen$" ),
542
+ Matcher (TypeError , match = "^iterator$" ),
543
+ ),
544
+ match = r"^Encountered exception during cleanup of generator object, as well as exception in the contextmanager body - unable to unwrap.$" ,
545
+ ):
539
546
async with agen (events ) as recv_chan :
540
547
async for i in recv_chan : # pragma: no branch
541
548
assert i == 1
542
549
raise TypeError ("iterator" )
543
550
544
551
assert events == ["GeneratorExit()" , "finally" ]
545
- RaisesGroup (
546
- Matcher (ValueError , match = "^agen$" ),
547
- Matcher (TypeError , match = "^iterator$" ),
548
- ).matches (excinfo .value .__cause__ )
549
552
550
553
551
- async def test_background_with_channel_nested_loop () -> None :
552
- @background_with_channel
554
+ async def test_as_safe_channel_nested_loop () -> None :
555
+ @as_safe_channel
553
556
async def agen () -> AsyncGenerator [int ]:
554
557
for i in range (2 ):
555
558
yield i
@@ -565,15 +568,49 @@ async def agen() -> AsyncGenerator[int]:
565
568
ii += 1
566
569
567
570
568
- async def test_doesnt_leak_cancellation () -> None :
569
- @background_with_channel
570
- async def agenfn () -> AsyncGenerator [None ]:
571
+ async def test_as_safe_channel_doesnt_leak_cancellation () -> None :
572
+ @as_safe_channel
573
+ async def agen () -> AsyncGenerator [None ]:
571
574
with trio .CancelScope () as cscope :
572
575
cscope .cancel ()
573
576
yield
574
577
575
578
with pytest .raises (AssertionError ):
576
- async with agenfn () as recv_chan :
579
+ async with agen () as recv_chan :
577
580
async for _ in recv_chan :
578
581
pass
579
582
raise AssertionError ("should be reachable" )
583
+
584
+
585
+ async def test_as_safe_channel_dont_unwrap_user_exceptiongroup () -> None :
586
+ @as_safe_channel
587
+ async def agen () -> AsyncGenerator [None ]:
588
+ yield
589
+
590
+ with RaisesGroup (Matcher (ValueError , match = "bar" ), match = "foo" ):
591
+ async with agen () as _ :
592
+ raise ExceptionGroup ("foo" , [ValueError ("bar" )])
593
+
594
+
595
+ async def test_as_safe_channel_multiple_receiver () -> None :
596
+ event = trio .Event ()
597
+
598
+ @as_safe_channel
599
+ async def agen () -> AsyncGenerator [int ]:
600
+ await event .wait ()
601
+ for i in range (2 ):
602
+ yield i
603
+
604
+ async def handle_value (
605
+ recv_chan : trio .abc .ReceiveChannel [int ],
606
+ value : int ,
607
+ task_status : trio .TaskStatus ,
608
+ ) -> None :
609
+ task_status .started ()
610
+ assert await recv_chan .receive () == value
611
+
612
+ async with agen () as recv_chan :
613
+ async with trio .open_nursery () as nursery :
614
+ await nursery .start (handle_value , recv_chan , 0 )
615
+ await nursery .start (handle_value , recv_chan , 1 )
616
+ event .set ()
0 commit comments