Skip to content

Ruggedize cmd2's runtime type annotation validation #1442

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
## 2.6.1 (TBD)

- Bug Fixes
- Fix bug that prevented `cmd2` from working with `from __future__ import annotations`

## 2.6.0 (May 31, 2025)

- Breaking Change
Expand Down
55 changes: 31 additions & 24 deletions cmd2/cmd2.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@
from .utils import (
Settable,
get_defining_class,
get_types,
strip_doc_annotations,
suggest_similar,
)
Expand Down Expand Up @@ -5544,10 +5545,10 @@ def _validate_callable_param_count(cls, func: Callable[..., Any], count: int) ->
def _validate_prepostloop_callable(cls, func: Callable[[], None]) -> None:
"""Check parameter and return types for preloop and postloop hooks."""
cls._validate_callable_param_count(func, 0)
# make sure there is no return notation
signature = inspect.signature(func)
if signature.return_annotation is not None:
raise TypeError(f"{func.__name__} must declare return a return type of 'None'")
# make sure there is no return annotation or the return is specified as None
_, ret_ann = get_types(func)
if ret_ann is not None:
raise TypeError(f"{func.__name__} must have a return type of 'None', got: {ret_ann}")

def register_preloop_hook(self, func: Callable[[], None]) -> None:
"""Register a function to be called at the beginning of the command loop."""
Expand All @@ -5563,11 +5564,13 @@ def register_postloop_hook(self, func: Callable[[], None]) -> None:
def _validate_postparsing_callable(cls, func: Callable[[plugin.PostparsingData], plugin.PostparsingData]) -> None:
"""Check parameter and return types for postparsing hooks."""
cls._validate_callable_param_count(cast(Callable[..., Any], func), 1)
signature = inspect.signature(func)
_, param = next(iter(signature.parameters.items()))
if param.annotation != plugin.PostparsingData:
type_hints, ret_ann = get_types(func)
if not type_hints:
raise TypeError(f"{func.__name__} parameter is missing a type hint, expected: 'cmd2.plugin.PostparsingData'")
par_ann = next(iter(type_hints.values()))
if par_ann != plugin.PostparsingData:
raise TypeError(f"{func.__name__} must have one parameter declared with type 'cmd2.plugin.PostparsingData'")
if signature.return_annotation != plugin.PostparsingData:
if ret_ann != plugin.PostparsingData:
raise TypeError(f"{func.__name__} must declare return a return type of 'cmd2.plugin.PostparsingData'")

def register_postparsing_hook(self, func: Callable[[plugin.PostparsingData], plugin.PostparsingData]) -> None:
Expand All @@ -5582,21 +5585,21 @@ def _validate_prepostcmd_hook(
cls, func: Callable[[CommandDataType], CommandDataType], data_type: type[CommandDataType]
) -> None:
"""Check parameter and return types for pre and post command hooks."""
signature = inspect.signature(func)
# validate that the callable has the right number of parameters
cls._validate_callable_param_count(cast(Callable[..., Any], func), 1)

type_hints, ret_ann = get_types(func)
if not type_hints:
raise TypeError(f"{func.__name__} parameter is missing a type hint, expected: {data_type}")
param_name, par_ann = next(iter(type_hints.items()))
# validate the parameter has the right annotation
paramname = next(iter(signature.parameters.keys()))
param = signature.parameters[paramname]
if param.annotation != data_type:
raise TypeError(f'argument 1 of {func.__name__} has incompatible type {param.annotation}, expected {data_type}')
if par_ann != data_type:
raise TypeError(f'argument 1 of {func.__name__} has incompatible type {par_ann}, expected {data_type}')
# validate the return value has the right annotation
if signature.return_annotation == signature.empty:
if ret_ann is None:
raise TypeError(f'{func.__name__} does not have a declared return type, expected {data_type}')
if signature.return_annotation != data_type:
raise TypeError(
f'{func.__name__} has incompatible return type {signature.return_annotation}, expected {data_type}'
)
if ret_ann != data_type:
raise TypeError(f'{func.__name__} has incompatible return type {ret_ann}, expected {data_type}')

def register_precmd_hook(self, func: Callable[[plugin.PrecommandData], plugin.PrecommandData]) -> None:
"""Register a hook to be called before the command function."""
Expand All @@ -5614,12 +5617,16 @@ def _validate_cmdfinalization_callable(
) -> None:
"""Check parameter and return types for command finalization hooks."""
cls._validate_callable_param_count(func, 1)
signature = inspect.signature(func)
_, param = next(iter(signature.parameters.items()))
if param.annotation != plugin.CommandFinalizationData:
raise TypeError(f"{func.__name__} must have one parameter declared with type {plugin.CommandFinalizationData}")
if signature.return_annotation != plugin.CommandFinalizationData:
raise TypeError("{func.__name__} must declare return a return type of {plugin.CommandFinalizationData}")
type_hints, ret_ann = get_types(func)
if not type_hints:
raise TypeError(f"{func.__name__} parameter is missing a type hint, expected: {plugin.CommandFinalizationData}")
_, par_ann = next(iter(type_hints.items()))
if par_ann != plugin.CommandFinalizationData:
raise TypeError(
f"{func.__name__} must have one parameter declared with type {plugin.CommandFinalizationData}, got: {par_ann}"
)
if ret_ann != plugin.CommandFinalizationData:
raise TypeError(f"{func.__name__} must declare return a return type of {plugin.CommandFinalizationData}")

def register_cmdfinalization_hook(
self, func: Callable[[plugin.CommandFinalizationData], plugin.CommandFinalizationData]
Expand Down
53 changes: 30 additions & 23 deletions cmd2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,12 @@
import threading
import unicodedata
from collections.abc import Callable, Iterable
from difflib import (
SequenceMatcher,
)
from enum import (
Enum,
)
from typing import (
TYPE_CHECKING,
Any,
Optional,
TextIO,
TypeVar,
Union,
cast,
)

from . import (
constants,
)
from .argparse_custom import (
ChoicesProviderFunc,
CompleterFunc,
)
from difflib import SequenceMatcher
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional, TextIO, TypeVar, Union, cast, get_type_hints

from . import constants
from .argparse_custom import ChoicesProviderFunc, CompleterFunc

if TYPE_CHECKING: # pragma: no cover
import cmd2 # noqa: F401
Expand Down Expand Up @@ -1261,3 +1244,27 @@ def suggest_similar(
best_simil = simil
proposed_command = each
return proposed_command


def get_types(func_or_method: Callable[..., Any]) -> tuple[dict[str, Any], Any]:
"""Use typing.get_type_hints() to extract type hints for parameters and return value.

This exists because the inspect module doesn't have a safe way of doing this that works
both with and without importing annotations from __future__ until Python 3.10.

TODO: Once cmd2 only supports Python 3.10+, change to use inspect.get_annotations(eval_str=True)

:param func_or_method: Function or method to return the type hints for
:return tuple with first element being dictionary mapping param names to type hints
and second element being return type hint, unspecified, returns None
"""
try:
type_hints = get_type_hints(func_or_method) # Get dictionary of type hints
except TypeError as exc:
raise ValueError("Argument passed to get_types should be a function or method") from exc
ret_ann = type_hints.pop('return', None) # Pop off the return annotation if it exists
if inspect.ismethod(func_or_method):
type_hints.pop('self', None) # Pop off `self` hint for methods
if ret_ann is type(None):
ret_ann = None # Simplify logic to just return None instead of NoneType
return type_hints, ret_ann
22 changes: 22 additions & 0 deletions tests/test_future_annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from __future__ import annotations

import cmd2

from .conftest import normalize, run_cmd


def test_hooks_work_with_future_annotations() -> None:
class HookApp(cmd2.Cmd):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.register_cmdfinalization_hook(self.hook)

def hook(self: cmd2.Cmd, data: cmd2.plugin.CommandFinalizationData) -> cmd2.plugin.CommandFinalizationData:
if self.in_script():
self.poutput("WE ARE IN SCRIPT")
return data

hook_app = HookApp()
out, err = run_cmd(hook_app, '')
expected = normalize('')
assert out == expected
41 changes: 41 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,3 +892,44 @@ def custom_similarity_function(s1, s2) -> float:

suggested_command = cu.suggest_similar("test", ["test"], similarity_function_to_use=custom_similarity_function)
assert suggested_command is None


def test_get_types_invalid_input() -> None:
x = 1
with pytest.raises(ValueError, match="Argument passed to get_types should be a function or method"):
cu.get_types(x)


def test_get_types_empty() -> None:
def a(b):
print(b)

param_ann, ret_ann = cu.get_types(a)
assert ret_ann is None
assert param_ann == {}


def test_get_types_non_empty() -> None:
def foo(x: int) -> str:
return f"{x * x}"

param_ann, ret_ann = cu.get_types(foo)
assert ret_ann is str
param_name, param_value = next(iter(param_ann.items()))
assert param_name == 'x'
assert param_value is int


def test_get_types_method() -> None:
class Foo:
def bar(self, x: bool) -> None:
print(x)

f = Foo()

param_ann, ret_ann = cu.get_types(f.bar)
assert ret_ann is None
assert len(param_ann) == 1
param_name, param_value = next(iter(param_ann.items()))
assert param_name == 'x'
assert param_value is bool
Loading