Skip to content

Commit ad7e578

Browse files
dpwatrousRena Chen
authored andcommitted
Async test helpers now preserve type info
1 parent 01596fa commit ad7e578

File tree

4 files changed

+155
-161
lines changed

4 files changed

+155
-161
lines changed
Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
11
from collections.abc import AsyncIterable
22
import inspect
3+
from typing import Any, cast, Coroutine, Iterable, TypeVar, Union
34

5+
T = TypeVar("T")
46

5-
# wrapper to handle async and sync objects
6-
async def async_wrapper(obj):
7+
async def wrap_result(result: Union[T, Coroutine[Any, Any, T]]) -> T:
8+
"""Handle an non-list operation result and await it if it's a coroutine"""
9+
if inspect.iscoroutine(result):
10+
result = await result
11+
return await wrap_result(result)
12+
return cast(T, result)
713

8-
if isinstance(obj, AsyncIterable):
9-
items = []
10-
async for item in obj:
14+
async def wrap_list_result(result: Union[Iterable[T], AsyncIterable[T]]) -> Iterable[T]:
15+
"""Handle a list operation result and convert to a list if it's an AsyncIterable"""
16+
if isinstance(result, AsyncIterable):
17+
items: Iterable[T] = []
18+
async for item in result:
1119
items.append(item)
1220
return items
13-
14-
if inspect.iscoroutine(obj):
15-
waited = await obj
16-
# wrap again to handle nested coroutines or async generators
17-
return await async_wrapper(waited)
18-
19-
return obj
21+
22+
return result

sdk/batch/azure-batch/tests/decorators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
stop_record_or_playback,
2020
)
2121

22-
from async_wrapper import async_wrapper
22+
from async_wrapper import wrap_result
2323

2424

2525
# A modified version of devtools_testutils.aio.recorded_by_proxy_async
@@ -149,6 +149,6 @@ async def wrapper(self, BatchClient, **kwargs):
149149
except Exception as err:
150150
raise err
151151
finally:
152-
await async_wrapper(client.close())
152+
await wrap_result(client.close())
153153

154154
return wrapper
Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from async_wrapper import async_wrapper
1+
from async_wrapper import wrap_list_result, wrap_result
22
import pytest
33

44

@@ -8,7 +8,7 @@ async def test_iscoroutine(self):
88
async def func():
99
return 1
1010

11-
result = await async_wrapper(func())
11+
result = await wrap_result(func())
1212
assert result == 1
1313

1414
@pytest.mark.asyncio
@@ -17,19 +17,7 @@ async def func():
1717
for i in range(3):
1818
yield i
1919

20-
result = await async_wrapper(func())
21-
assert result == [0, 1, 2]
22-
23-
@pytest.mark.asyncio
24-
async def test_isNestedAsyncIterable(self):
25-
async def func():
26-
async def nested():
27-
for i in range(3):
28-
yield i
29-
30-
return nested()
31-
32-
result = await async_wrapper(func())
20+
result = await wrap_list_result(func())
3321
assert result == [0, 1, 2]
3422

3523
@pytest.mark.asyncio
@@ -40,7 +28,7 @@ async def nested():
4028

4129
return nested()
4230

43-
result = await async_wrapper(func())
31+
result = await wrap_result(func())
4432
assert result == 2
4533

4634
@pytest.mark.asyncio
@@ -50,12 +38,12 @@ def func():
5038
yield i
5139

5240
iterable = func()
53-
result = await async_wrapper(iterable)
41+
result = await wrap_list_result(iterable)
5442
assert result == iterable
5543

5644
@pytest.mark.asyncio
5745
async def test_isSync(self):
5846
def func():
5947
return 1
6048

61-
assert await async_wrapper(func()) == 1
49+
assert await wrap_result(func()) == 1

0 commit comments

Comments
 (0)