Skip to content

Commit 046442d

Browse files
friedlisroach
authored andcommitted
bpo-38857: AsyncMock fix for awaitable values and StopIteration fix [3.8] (GH-17269)
1 parent e5d1f73 commit 046442d

File tree

5 files changed

+103
-42
lines changed

5 files changed

+103
-42
lines changed

Doc/library/unittest.mock.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -873,7 +873,7 @@ object::
873873
exception,
874874
- if ``side_effect`` is an iterable, the async function will return the
875875
next value of the iterable, however, if the sequence of result is
876-
exhausted, ``StopIteration`` is raised immediately,
876+
exhausted, ``StopAsyncIteration`` is raised immediately,
877877
- if ``side_effect`` is not defined, the async function will return the
878878
value defined by ``return_value``, hence, by default, the async function
879879
returns a new :class:`AsyncMock` object.

Lib/unittest/mock.py

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1139,8 +1139,8 @@ def _increment_mock_call(self, /, *args, **kwargs):
11391139
_new_parent = _new_parent._mock_new_parent
11401140

11411141
def _execute_mock_call(self, /, *args, **kwargs):
1142-
# seperate from _increment_mock_call so that awaited functions are
1143-
# executed seperately from their call
1142+
# separate from _increment_mock_call so that awaited functions are
1143+
# executed separately from their call, also AsyncMock overrides this method
11441144

11451145
effect = self.side_effect
11461146
if effect is not None:
@@ -2136,29 +2136,45 @@ def __init__(self, /, *args, **kwargs):
21362136
code_mock.co_flags = inspect.CO_COROUTINE
21372137
self.__dict__['__code__'] = code_mock
21382138

2139-
async def _mock_call(self, /, *args, **kwargs):
2140-
try:
2141-
result = super()._mock_call(*args, **kwargs)
2142-
except (BaseException, StopIteration) as e:
2143-
side_effect = self.side_effect
2144-
if side_effect is not None and not callable(side_effect):
2145-
raise
2146-
return await _raise(e)
2139+
async def _execute_mock_call(self, /, *args, **kwargs):
2140+
# This is nearly just like super(), except for sepcial handling
2141+
# of coroutines
21472142

21482143
_call = self.call_args
2144+
self.await_count += 1
2145+
self.await_args = _call
2146+
self.await_args_list.append(_call)
21492147

2150-
async def proxy():
2151-
try:
2152-
if inspect.isawaitable(result):
2153-
return await result
2154-
else:
2155-
return result
2156-
finally:
2157-
self.await_count += 1
2158-
self.await_args = _call
2159-
self.await_args_list.append(_call)
2148+
effect = self.side_effect
2149+
if effect is not None:
2150+
if _is_exception(effect):
2151+
raise effect
2152+
elif not _callable(effect):
2153+
try:
2154+
result = next(effect)
2155+
except StopIteration:
2156+
# It is impossible to propogate a StopIteration
2157+
# through coroutines because of PEP 479
2158+
raise StopAsyncIteration
2159+
if _is_exception(result):
2160+
raise result
2161+
elif asyncio.iscoroutinefunction(effect):
2162+
result = await effect(*args, **kwargs)
2163+
else:
2164+
result = effect(*args, **kwargs)
21602165

2161-
return await proxy()
2166+
if result is not DEFAULT:
2167+
return result
2168+
2169+
if self._mock_return_value is not DEFAULT:
2170+
return self.return_value
2171+
2172+
if self._mock_wraps is not None:
2173+
if asyncio.iscoroutinefunction(self._mock_wraps):
2174+
return await self._mock_wraps(*args, **kwargs)
2175+
return self._mock_wraps(*args, **kwargs)
2176+
2177+
return self.return_value
21622178

21632179
def assert_awaited(self):
21642180
"""
@@ -2864,10 +2880,6 @@ def seal(mock):
28642880
seal(m)
28652881

28662882

2867-
async def _raise(exception):
2868-
raise exception
2869-
2870-
28712883
class _AsyncIterator:
28722884
"""
28732885
Wraps an iterator in an asynchronous iterator.

Lib/unittest/test/testmock/testasync.py

Lines changed: 58 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -358,42 +358,84 @@ def test_magicmock_lambda_spec(self):
358358
self.assertIsInstance(cm, MagicMock)
359359

360360

361-
class AsyncArguments(unittest.TestCase):
362-
def test_add_return_value(self):
361+
class AsyncArguments(unittest.IsolatedAsyncioTestCase):
362+
async def test_add_return_value(self):
363363
async def addition(self, var):
364364
return var + 1
365365

366366
mock = AsyncMock(addition, return_value=10)
367-
output = asyncio.run(mock(5))
367+
output = await mock(5)
368368

369369
self.assertEqual(output, 10)
370370

371-
def test_add_side_effect_exception(self):
371+
async def test_add_side_effect_exception(self):
372372
async def addition(var):
373373
return var + 1
374374
mock = AsyncMock(addition, side_effect=Exception('err'))
375375
with self.assertRaises(Exception):
376-
asyncio.run(mock(5))
376+
await mock(5)
377377

378-
def test_add_side_effect_function(self):
378+
async def test_add_side_effect_function(self):
379379
async def addition(var):
380380
return var + 1
381381
mock = AsyncMock(side_effect=addition)
382-
result = asyncio.run(mock(5))
382+
result = await mock(5)
383383
self.assertEqual(result, 6)
384384

385-
def test_add_side_effect_iterable(self):
385+
async def test_add_side_effect_iterable(self):
386386
vals = [1, 2, 3]
387387
mock = AsyncMock(side_effect=vals)
388388
for item in vals:
389-
self.assertEqual(item, asyncio.run(mock()))
390-
391-
with self.assertRaises(RuntimeError) as e:
392-
asyncio.run(mock())
393-
self.assertEqual(
394-
e.exception,
395-
RuntimeError('coroutine raised StopIteration')
396-
)
389+
self.assertEqual(item, await mock())
390+
391+
with self.assertRaises(StopAsyncIteration) as e:
392+
await mock()
393+
394+
async def test_return_value_AsyncMock(self):
395+
value = AsyncMock(return_value=10)
396+
mock = AsyncMock(return_value=value)
397+
result = await mock()
398+
self.assertIs(result, value)
399+
400+
async def test_return_value_awaitable(self):
401+
fut = asyncio.Future()
402+
fut.set_result(None)
403+
mock = AsyncMock(return_value=fut)
404+
result = await mock()
405+
self.assertIsInstance(result, asyncio.Future)
406+
407+
async def test_side_effect_awaitable_values(self):
408+
fut = asyncio.Future()
409+
fut.set_result(None)
410+
411+
mock = AsyncMock(side_effect=[fut])
412+
result = await mock()
413+
self.assertIsInstance(result, asyncio.Future)
414+
415+
with self.assertRaises(StopAsyncIteration):
416+
await mock()
417+
418+
async def test_side_effect_is_AsyncMock(self):
419+
effect = AsyncMock(return_value=10)
420+
mock = AsyncMock(side_effect=effect)
421+
422+
result = await mock()
423+
self.assertEqual(result, 10)
424+
425+
async def test_wraps_coroutine(self):
426+
value = asyncio.Future()
427+
428+
ran = False
429+
async def inner():
430+
nonlocal ran
431+
ran = True
432+
return value
433+
434+
mock = AsyncMock(wraps=inner)
435+
result = await mock()
436+
self.assertEqual(result, value)
437+
mock.assert_awaited()
438+
self.assertTrue(ran)
397439

398440
class AsyncMagicMethods(unittest.TestCase):
399441
def test_async_magic_methods_return_async_mocks(self):
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
AsyncMock fix for return values that are awaitable types. This also covers
2+
side_effect iterable values that happend to be awaitable, and wraps
3+
callables that return an awaitable type. Before these awaitables were being
4+
awaited instead of being returned as is.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
AsyncMock now returns StopAsyncIteration on the exaustion of a side_effects
2+
iterable. Since PEP-479 its Impossible to raise a StopIteration exception
3+
from a coroutine.

0 commit comments

Comments
 (0)