Skip to content

Commit ff8b788

Browse files
committed
Type annotate ParameterSet
1 parent 43fa1ee commit ff8b788

File tree

4 files changed

+85
-20
lines changed

4 files changed

+85
-20
lines changed

src/_pytest/compat.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
python version compatibility code
33
"""
4+
import enum
45
import functools
56
import inspect
67
import os
@@ -33,13 +34,20 @@
3334

3435
if TYPE_CHECKING:
3536
from typing import Type
37+
from typing_extensions import Final
3638

3739

3840
_T = TypeVar("_T")
3941
_S = TypeVar("_S")
4042

4143

42-
NOTSET = object()
44+
# fmt: off
45+
# Singleton type for NOTSET, as described in:
46+
# https://www.python.org/dev/peps/pep-0484/#support-for-singleton-types-in-unions
47+
class NotSetType(enum.Enum):
48+
token = 0
49+
NOTSET = NotSetType.token # type: Final # noqa: E305
50+
# fmt: on
4351

4452
MODULE_NOT_FOUND_ERROR = (
4553
"ModuleNotFoundError" if sys.version_info[:2] >= (3, 6) else "ImportError"

src/_pytest/mark/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
""" generic mechanism for marking and selecting python functions. """
2+
import typing
23
import warnings
34
from typing import AbstractSet
45
from typing import Optional
6+
from typing import Union
57

68
import attr
79

@@ -31,7 +33,11 @@
3133
old_mark_config_key = StoreKey[Optional[Config]]()
3234

3335

34-
def param(*values, **kw):
36+
def param(
37+
*values: object,
38+
marks: "Union[MarkDecorator, typing.Collection[Union[MarkDecorator, Mark]]]" = (),
39+
id: Optional[str] = None
40+
) -> ParameterSet:
3541
"""Specify a parameter in `pytest.mark.parametrize`_ calls or
3642
:ref:`parametrized fixtures <fixture-parametrize-marks>`.
3743
@@ -48,7 +54,7 @@ def test_eval(test_input, expected):
4854
:keyword marks: a single mark or a list of marks to be applied to this parameter set.
4955
:keyword str id: the id to attribute to this parameter set.
5056
"""
51-
return ParameterSet.param(*values, **kw)
57+
return ParameterSet.param(*values, marks=marks, id=id)
5258

5359

5460
def pytest_addoption(parser):

src/_pytest/mark/structures.py

Lines changed: 67 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1+
import collections.abc
12
import inspect
3+
import typing
24
import warnings
3-
from collections import namedtuple
4-
from collections.abc import MutableMapping
55
from typing import Any
66
from typing import Iterable
77
from typing import List
88
from typing import Mapping
9+
from typing import NamedTuple
910
from typing import Optional
1011
from typing import Sequence
1112
from typing import Set
@@ -17,20 +18,29 @@
1718
from .._code import getfslineno
1819
from ..compat import ascii_escaped
1920
from ..compat import NOTSET
21+
from ..compat import NotSetType
22+
from ..compat import TYPE_CHECKING
23+
from _pytest.config import Config
2024
from _pytest.outcomes import fail
2125
from _pytest.warning_types import PytestUnknownMarkWarning
2226

27+
if TYPE_CHECKING:
28+
from _pytest.python import FunctionDefinition
29+
30+
2331
EMPTY_PARAMETERSET_OPTION = "empty_parameter_set_mark"
2432

2533

26-
def istestfunc(func):
34+
def istestfunc(func) -> bool:
2735
return (
2836
hasattr(func, "__call__")
2937
and getattr(func, "__name__", "<lambda>") != "<lambda>"
3038
)
3139

3240

33-
def get_empty_parameterset_mark(config, argnames, func):
41+
def get_empty_parameterset_mark(
42+
config: Config, argnames: Sequence[str], func
43+
) -> "MarkDecorator":
3444
from ..nodes import Collector
3545

3646
requested_mark = config.getini(EMPTY_PARAMETERSET_OPTION)
@@ -53,16 +63,33 @@ def get_empty_parameterset_mark(config, argnames, func):
5363
fs,
5464
lineno,
5565
)
56-
return mark(reason=reason)
57-
58-
59-
class ParameterSet(namedtuple("ParameterSet", "values, marks, id")):
66+
# Type ignored because MarkDecorator.__call__() is a bit tough to
67+
# annotate ATM.
68+
return mark(reason=reason) # type: ignore[no-any-return] # noqa: F723
69+
70+
71+
class ParameterSet(
72+
NamedTuple(
73+
"ParameterSet",
74+
[
75+
("values", Sequence[Union[object, NotSetType]]),
76+
("marks", "typing.Collection[Union[MarkDecorator, Mark]]"),
77+
("id", Optional[str]),
78+
],
79+
)
80+
):
6081
@classmethod
61-
def param(cls, *values, marks=(), id=None):
82+
def param(
83+
cls,
84+
*values: object,
85+
marks: "Union[MarkDecorator, typing.Collection[Union[MarkDecorator, Mark]]]" = (),
86+
id: Optional[str] = None
87+
) -> "ParameterSet":
6288
if isinstance(marks, MarkDecorator):
6389
marks = (marks,)
6490
else:
65-
assert isinstance(marks, (tuple, list, set))
91+
# TODO(py36): Change to collections.abc.Collection.
92+
assert isinstance(marks, (collections.abc.Sequence, set))
6693

6794
if id is not None:
6895
if not isinstance(id, str):
@@ -73,7 +100,11 @@ def param(cls, *values, marks=(), id=None):
73100
return cls(values, marks, id)
74101

75102
@classmethod
76-
def extract_from(cls, parameterset, force_tuple=False):
103+
def extract_from(
104+
cls,
105+
parameterset: Union["ParameterSet", Sequence[object], object],
106+
force_tuple: bool = False,
107+
) -> "ParameterSet":
77108
"""
78109
:param parameterset:
79110
a legacy style parameterset that may or may not be a tuple,
@@ -89,10 +120,20 @@ def extract_from(cls, parameterset, force_tuple=False):
89120
if force_tuple:
90121
return cls.param(parameterset)
91122
else:
92-
return cls(parameterset, marks=[], id=None)
123+
# TODO: Refactor to fix this type-ignore. Currently the following
124+
# type-checks but crashes:
125+
#
126+
# @pytest.mark.parametrize(('x', 'y'), [1, 2])
127+
# def test_foo(x, y): pass
128+
return cls(parameterset, marks=[], id=None) # type: ignore[arg-type] # noqa: F821
93129

94130
@staticmethod
95-
def _parse_parametrize_args(argnames, argvalues, *args, **kwargs):
131+
def _parse_parametrize_args(
132+
argnames: Union[str, List[str], Tuple[str, ...]],
133+
argvalues: Iterable[Union["ParameterSet", Sequence[object], object]],
134+
*args,
135+
**kwargs
136+
) -> Tuple[Union[List[str], Tuple[str, ...]], bool]:
96137
if not isinstance(argnames, (tuple, list)):
97138
argnames = [x.strip() for x in argnames.split(",") if x.strip()]
98139
force_tuple = len(argnames) == 1
@@ -101,13 +142,23 @@ def _parse_parametrize_args(argnames, argvalues, *args, **kwargs):
101142
return argnames, force_tuple
102143

103144
@staticmethod
104-
def _parse_parametrize_parameters(argvalues, force_tuple):
145+
def _parse_parametrize_parameters(
146+
argvalues: Iterable[Union["ParameterSet", Sequence[object], object]],
147+
force_tuple: bool,
148+
) -> List["ParameterSet"]:
105149
return [
106150
ParameterSet.extract_from(x, force_tuple=force_tuple) for x in argvalues
107151
]
108152

109153
@classmethod
110-
def _for_parametrize(cls, argnames, argvalues, func, config, function_definition):
154+
def _for_parametrize(
155+
cls,
156+
argnames: Union[str, List[str], Tuple[str, ...]],
157+
argvalues: Iterable[Union["ParameterSet", Sequence[object], object]],
158+
func,
159+
config: Config,
160+
function_definition: "FunctionDefinition",
161+
) -> Tuple[Union[List[str], Tuple[str, ...]], List["ParameterSet"]]:
111162
argnames, force_tuple = cls._parse_parametrize_args(argnames, argvalues)
112163
parameters = cls._parse_parametrize_parameters(argvalues, force_tuple)
113164
del argvalues
@@ -370,7 +421,7 @@ def __getattr__(self, name: str) -> MarkDecorator:
370421
MARK_GEN = MarkGenerator()
371422

372423

373-
class NodeKeywords(MutableMapping):
424+
class NodeKeywords(collections.abc.MutableMapping):
374425
def __init__(self, node):
375426
self.node = node
376427
self.parent = node.parent

testing/test_doctest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1051,7 +1051,7 @@ def test_number_precision(self, testdir, config_mode):
10511051
("1e3", "999"),
10521052
# The current implementation doesn't understand that numbers inside
10531053
# strings shouldn't be treated as numbers:
1054-
pytest.param("'3.1416'", "'3.14'", marks=pytest.mark.xfail),
1054+
pytest.param("'3.1416'", "'3.14'", marks=pytest.mark.xfail), # type: ignore
10551055
],
10561056
)
10571057
def test_number_non_matches(self, testdir, expression, output):

0 commit comments

Comments
 (0)