Skip to content

Commit 9eaea96

Browse files
authored
Ruggedize cmd2's runtime type annotation validation (#1442)
* Change cmd2's runtime type annotation validation to be based on typing.get_type_hints Previously it was based on inspect.signature. The problem is that to Python 3.10, the inspect module doesn't have a safe way of evaluating type annotations that works equivalently both in the presence or absence of "from __future__ import annotations". Hence, any attempt at using that in an app would break cmd2. This change adds a get_types() helper function to the cmd2.utils module which uses typing.get_type_hints() to do the introspection in a safer way. * Fix for Python 3.9 since types.NoneType doesn't exist until 3.10 * Added unit tests for new utils.get_types function * Add a note to the CHANGELOG * Added a new test file to cover the case representing the whole reason for this change I first verified that this test fails on the current master branch * Switched CHANGELOG type to Bug Fix
1 parent 8d1c82c commit 9eaea96

File tree

5 files changed

+129
-47
lines changed

5 files changed

+129
-47
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
## 2.6.1 (TBD)
2+
3+
- Bug Fixes
4+
- Fix bug that prevented `cmd2` from working with `from __future__ import annotations`
5+
16
## 2.6.0 (May 31, 2025)
27

38
- Breaking Change

cmd2/cmd2.py

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@
144144
from .utils import (
145145
Settable,
146146
get_defining_class,
147+
get_types,
147148
strip_doc_annotations,
148149
suggest_similar,
149150
)
@@ -5544,10 +5545,10 @@ def _validate_callable_param_count(cls, func: Callable[..., Any], count: int) ->
55445545
def _validate_prepostloop_callable(cls, func: Callable[[], None]) -> None:
55455546
"""Check parameter and return types for preloop and postloop hooks."""
55465547
cls._validate_callable_param_count(func, 0)
5547-
# make sure there is no return notation
5548-
signature = inspect.signature(func)
5549-
if signature.return_annotation is not None:
5550-
raise TypeError(f"{func.__name__} must declare return a return type of 'None'")
5548+
# make sure there is no return annotation or the return is specified as None
5549+
_, ret_ann = get_types(func)
5550+
if ret_ann is not None:
5551+
raise TypeError(f"{func.__name__} must have a return type of 'None', got: {ret_ann}")
55515552

55525553
def register_preloop_hook(self, func: Callable[[], None]) -> None:
55535554
"""Register a function to be called at the beginning of the command loop."""
@@ -5563,11 +5564,13 @@ def register_postloop_hook(self, func: Callable[[], None]) -> None:
55635564
def _validate_postparsing_callable(cls, func: Callable[[plugin.PostparsingData], plugin.PostparsingData]) -> None:
55645565
"""Check parameter and return types for postparsing hooks."""
55655566
cls._validate_callable_param_count(cast(Callable[..., Any], func), 1)
5566-
signature = inspect.signature(func)
5567-
_, param = next(iter(signature.parameters.items()))
5568-
if param.annotation != plugin.PostparsingData:
5567+
type_hints, ret_ann = get_types(func)
5568+
if not type_hints:
5569+
raise TypeError(f"{func.__name__} parameter is missing a type hint, expected: 'cmd2.plugin.PostparsingData'")
5570+
par_ann = next(iter(type_hints.values()))
5571+
if par_ann != plugin.PostparsingData:
55695572
raise TypeError(f"{func.__name__} must have one parameter declared with type 'cmd2.plugin.PostparsingData'")
5570-
if signature.return_annotation != plugin.PostparsingData:
5573+
if ret_ann != plugin.PostparsingData:
55715574
raise TypeError(f"{func.__name__} must declare return a return type of 'cmd2.plugin.PostparsingData'")
55725575

55735576
def register_postparsing_hook(self, func: Callable[[plugin.PostparsingData], plugin.PostparsingData]) -> None:
@@ -5582,21 +5585,21 @@ def _validate_prepostcmd_hook(
55825585
cls, func: Callable[[CommandDataType], CommandDataType], data_type: type[CommandDataType]
55835586
) -> None:
55845587
"""Check parameter and return types for pre and post command hooks."""
5585-
signature = inspect.signature(func)
55865588
# validate that the callable has the right number of parameters
55875589
cls._validate_callable_param_count(cast(Callable[..., Any], func), 1)
5590+
5591+
type_hints, ret_ann = get_types(func)
5592+
if not type_hints:
5593+
raise TypeError(f"{func.__name__} parameter is missing a type hint, expected: {data_type}")
5594+
param_name, par_ann = next(iter(type_hints.items()))
55885595
# validate the parameter has the right annotation
5589-
paramname = next(iter(signature.parameters.keys()))
5590-
param = signature.parameters[paramname]
5591-
if param.annotation != data_type:
5592-
raise TypeError(f'argument 1 of {func.__name__} has incompatible type {param.annotation}, expected {data_type}')
5596+
if par_ann != data_type:
5597+
raise TypeError(f'argument 1 of {func.__name__} has incompatible type {par_ann}, expected {data_type}')
55935598
# validate the return value has the right annotation
5594-
if signature.return_annotation == signature.empty:
5599+
if ret_ann is None:
55955600
raise TypeError(f'{func.__name__} does not have a declared return type, expected {data_type}')
5596-
if signature.return_annotation != data_type:
5597-
raise TypeError(
5598-
f'{func.__name__} has incompatible return type {signature.return_annotation}, expected {data_type}'
5599-
)
5601+
if ret_ann != data_type:
5602+
raise TypeError(f'{func.__name__} has incompatible return type {ret_ann}, expected {data_type}')
56005603

56015604
def register_precmd_hook(self, func: Callable[[plugin.PrecommandData], plugin.PrecommandData]) -> None:
56025605
"""Register a hook to be called before the command function."""
@@ -5614,12 +5617,16 @@ def _validate_cmdfinalization_callable(
56145617
) -> None:
56155618
"""Check parameter and return types for command finalization hooks."""
56165619
cls._validate_callable_param_count(func, 1)
5617-
signature = inspect.signature(func)
5618-
_, param = next(iter(signature.parameters.items()))
5619-
if param.annotation != plugin.CommandFinalizationData:
5620-
raise TypeError(f"{func.__name__} must have one parameter declared with type {plugin.CommandFinalizationData}")
5621-
if signature.return_annotation != plugin.CommandFinalizationData:
5622-
raise TypeError("{func.__name__} must declare return a return type of {plugin.CommandFinalizationData}")
5620+
type_hints, ret_ann = get_types(func)
5621+
if not type_hints:
5622+
raise TypeError(f"{func.__name__} parameter is missing a type hint, expected: {plugin.CommandFinalizationData}")
5623+
_, par_ann = next(iter(type_hints.items()))
5624+
if par_ann != plugin.CommandFinalizationData:
5625+
raise TypeError(
5626+
f"{func.__name__} must have one parameter declared with type {plugin.CommandFinalizationData}, got: {par_ann}"
5627+
)
5628+
if ret_ann != plugin.CommandFinalizationData:
5629+
raise TypeError(f"{func.__name__} must declare return a return type of {plugin.CommandFinalizationData}")
56235630

56245631
def register_cmdfinalization_hook(
56255632
self, func: Callable[[plugin.CommandFinalizationData], plugin.CommandFinalizationData]

cmd2/utils.py

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,29 +14,12 @@
1414
import threading
1515
import unicodedata
1616
from collections.abc import Callable, Iterable
17-
from difflib import (
18-
SequenceMatcher,
19-
)
20-
from enum import (
21-
Enum,
22-
)
23-
from typing import (
24-
TYPE_CHECKING,
25-
Any,
26-
Optional,
27-
TextIO,
28-
TypeVar,
29-
Union,
30-
cast,
31-
)
32-
33-
from . import (
34-
constants,
35-
)
36-
from .argparse_custom import (
37-
ChoicesProviderFunc,
38-
CompleterFunc,
39-
)
17+
from difflib import SequenceMatcher
18+
from enum import Enum
19+
from typing import TYPE_CHECKING, Any, Optional, TextIO, TypeVar, Union, cast, get_type_hints
20+
21+
from . import constants
22+
from .argparse_custom import ChoicesProviderFunc, CompleterFunc
4023

4124
if TYPE_CHECKING: # pragma: no cover
4225
import cmd2 # noqa: F401
@@ -1261,3 +1244,27 @@ def suggest_similar(
12611244
best_simil = simil
12621245
proposed_command = each
12631246
return proposed_command
1247+
1248+
1249+
def get_types(func_or_method: Callable[..., Any]) -> tuple[dict[str, Any], Any]:
1250+
"""Use typing.get_type_hints() to extract type hints for parameters and return value.
1251+
1252+
This exists because the inspect module doesn't have a safe way of doing this that works
1253+
both with and without importing annotations from __future__ until Python 3.10.
1254+
1255+
TODO: Once cmd2 only supports Python 3.10+, change to use inspect.get_annotations(eval_str=True)
1256+
1257+
:param func_or_method: Function or method to return the type hints for
1258+
:return tuple with first element being dictionary mapping param names to type hints
1259+
and second element being return type hint, unspecified, returns None
1260+
"""
1261+
try:
1262+
type_hints = get_type_hints(func_or_method) # Get dictionary of type hints
1263+
except TypeError as exc:
1264+
raise ValueError("Argument passed to get_types should be a function or method") from exc
1265+
ret_ann = type_hints.pop('return', None) # Pop off the return annotation if it exists
1266+
if inspect.ismethod(func_or_method):
1267+
type_hints.pop('self', None) # Pop off `self` hint for methods
1268+
if ret_ann is type(None):
1269+
ret_ann = None # Simplify logic to just return None instead of NoneType
1270+
return type_hints, ret_ann

tests/test_future_annotations.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from __future__ import annotations
2+
3+
import cmd2
4+
5+
from .conftest import normalize, run_cmd
6+
7+
8+
def test_hooks_work_with_future_annotations() -> None:
9+
class HookApp(cmd2.Cmd):
10+
def __init__(self, *args, **kwargs) -> None:
11+
super().__init__(*args, **kwargs)
12+
self.register_cmdfinalization_hook(self.hook)
13+
14+
def hook(self: cmd2.Cmd, data: cmd2.plugin.CommandFinalizationData) -> cmd2.plugin.CommandFinalizationData:
15+
if self.in_script():
16+
self.poutput("WE ARE IN SCRIPT")
17+
return data
18+
19+
hook_app = HookApp()
20+
out, err = run_cmd(hook_app, '')
21+
expected = normalize('')
22+
assert out == expected

tests/test_utils.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -892,3 +892,44 @@ def custom_similarity_function(s1, s2) -> float:
892892

893893
suggested_command = cu.suggest_similar("test", ["test"], similarity_function_to_use=custom_similarity_function)
894894
assert suggested_command is None
895+
896+
897+
def test_get_types_invalid_input() -> None:
898+
x = 1
899+
with pytest.raises(ValueError, match="Argument passed to get_types should be a function or method"):
900+
cu.get_types(x)
901+
902+
903+
def test_get_types_empty() -> None:
904+
def a(b):
905+
print(b)
906+
907+
param_ann, ret_ann = cu.get_types(a)
908+
assert ret_ann is None
909+
assert param_ann == {}
910+
911+
912+
def test_get_types_non_empty() -> None:
913+
def foo(x: int) -> str:
914+
return f"{x * x}"
915+
916+
param_ann, ret_ann = cu.get_types(foo)
917+
assert ret_ann is str
918+
param_name, param_value = next(iter(param_ann.items()))
919+
assert param_name == 'x'
920+
assert param_value is int
921+
922+
923+
def test_get_types_method() -> None:
924+
class Foo:
925+
def bar(self, x: bool) -> None:
926+
print(x)
927+
928+
f = Foo()
929+
930+
param_ann, ret_ann = cu.get_types(f.bar)
931+
assert ret_ann is None
932+
assert len(param_ann) == 1
933+
param_name, param_value = next(iter(param_ann.items()))
934+
assert param_name == 'x'
935+
assert param_value is bool

0 commit comments

Comments
 (0)