Skip to content

Commit 7ea173c

Browse files
authored
Fix @patch when new is missing (#10459)
1 parent 1d7f0d0 commit 7ea173c

File tree

2 files changed

+33
-11
lines changed

2 files changed

+33
-11
lines changed

stdlib/unittest/mock.pyi

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,8 @@ class _patch(Generic[_T]):
234234
def copy(self) -> _patch[_T]: ...
235235
@overload
236236
def __call__(self, func: _TT) -> _TT: ...
237+
# If new==DEFAULT, this should add a MagicMock parameter to the function
238+
# arguments. See the _patch_default_new class below for this functionality.
237239
@overload
238240
def __call__(self, func: Callable[_P, _R]) -> Callable[_P, _R]: ...
239241
if sys.version_info >= (3, 8):
@@ -257,6 +259,22 @@ class _patch(Generic[_T]):
257259
def start(self) -> _T: ...
258260
def stop(self) -> None: ...
259261

262+
if sys.version_info >= (3, 8):
263+
_Mock: TypeAlias = MagicMock | AsyncMock
264+
else:
265+
_Mock: TypeAlias = MagicMock
266+
267+
# This class does not exist at runtime, it's a hack to make this work:
268+
# @patch("foo")
269+
# def bar(..., mock: MagicMock) -> None: ...
270+
class _patch_default_new(_patch[_Mock]):
271+
@overload
272+
def __call__(self, func: _TT) -> _TT: ...
273+
# Can't use the following as ParamSpec is only allowed as last parameter:
274+
# def __call__(self, func: Callable[_P, _R]) -> Callable[Concatenate[_P, MagicMock], _R]: ...
275+
@overload
276+
def __call__(self, func: Callable[..., _R]) -> Callable[..., _R]: ...
277+
260278
class _patch_dict:
261279
in_dict: Any
262280
values: Any
@@ -273,11 +291,8 @@ class _patch_dict:
273291
start: Any
274292
stop: Any
275293

276-
if sys.version_info >= (3, 8):
277-
_Mock: TypeAlias = MagicMock | AsyncMock
278-
else:
279-
_Mock: TypeAlias = MagicMock
280-
294+
# This class does not exist at runtime, it's a hack to add methods to the
295+
# patch() function.
281296
class _patcher:
282297
TEST_PREFIX: str
283298
dict: type[_patch_dict]
@@ -307,7 +322,7 @@ class _patcher:
307322
autospec: Any | None = ...,
308323
new_callable: Any | None = ...,
309324
**kwargs: Any,
310-
) -> _patch[_Mock]: ...
325+
) -> _patch_default_new: ...
311326
@overload
312327
@staticmethod
313328
def object( # type: ignore[misc]

test_cases/stdlib/check_unittest.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from decimal import Decimal
66
from fractions import Fraction
77
from typing_extensions import assert_type
8-
from unittest.mock import Mock, patch
8+
from unittest.mock import MagicMock, Mock, patch
99

1010
case = unittest.TestCase()
1111

@@ -94,13 +94,20 @@ def __gt__(self, other: Bacon) -> bool:
9494
###
9595

9696

97-
@patch("sys.exit", new=Mock())
98-
def f(i: int) -> str:
97+
@patch("sys.exit")
98+
def f_default_new(i: int, mock: MagicMock) -> str:
99+
return "asdf"
100+
101+
102+
@patch("sys.exit", new=42)
103+
def f_explicit_new(i: int) -> str:
99104
return "asdf"
100105

101106

102-
assert_type(f(1), str)
103-
f("a") # type: ignore
107+
assert_type(f_default_new(1), str)
108+
f_default_new("a") # Not an error due to ParamSpec limitations
109+
assert_type(f_explicit_new(1), str)
110+
f_explicit_new("a") # type: ignore[arg-type]
104111

105112

106113
@patch("sys.exit", new=Mock())

0 commit comments

Comments
 (0)