Skip to content

Commit 5e333b1

Browse files
committed
Step 5
1 parent 10b9b01 commit 5e333b1

File tree

2 files changed

+16
-13
lines changed

2 files changed

+16
-13
lines changed

.mypy.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ disallow_any_generics = True
2020
# These next few are various gradations of forcing use of type annotations
2121
disallow_untyped_calls = True
2222
disallow_incomplete_defs = True
23-
; disallow_untyped_defs = True
23+
disallow_untyped_defs = True
2424

2525
# This one isn't too hard to get passing, but return on investment is lower
2626
no_implicit_reexport = True

django_fsm/__init__.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
IntegerField = models.IntegerField[int, int]
5050
ForeignKey = models.ForeignKey[Any, Any]
5151

52+
_StateValue = str | int
5253
_Instance = models.Model # TODO: use real type
5354
_ToDo = Any # TODO: use real type
5455
else:
@@ -83,10 +84,10 @@ class ConcurrentTransition(Exception):
8384
class Transition:
8485
def __init__(
8586
self,
86-
method: Callable[..., str | int | None],
87-
source: str | int | Sequence[str | int] | State,
88-
target: str | int,
89-
on_error: str | int | None,
87+
method: Callable[..., _StateValue | Any],
88+
source: _StateValue | Sequence[_StateValue] | State,
89+
target: _StateValue,
90+
on_error: _StateValue | None,
9091
conditions: list[Callable[[_Instance], bool]],
9192
permission: str | Callable[[_Instance, UserWithPermissions], bool] | None,
9293
custom: dict[str, _StrOrPromise],
@@ -414,7 +415,7 @@ def _collect_transitions(self, *args: Any, **kwargs: Any) -> None:
414415
if not issubclass(sender, self.base_cls):
415416
return
416417

417-
def is_field_transition_method(attr):
418+
def is_field_transition_method(attr: _ToDo) -> bool:
418419
return (
419420
(inspect.ismethod(attr) or inspect.isfunction(attr))
420421
and hasattr(attr, "_django_fsm")
@@ -528,7 +529,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
528529
def state_fields(self) -> Iterable[Any]:
529530
return filter(lambda field: isinstance(field, FSMFieldMixin), self._meta.fields)
530531

531-
def _do_update(self, base_qs, using, pk_val, values, update_fields, forced_update):
532+
def _do_update(self, base_qs, using, pk_val, values, update_fields, forced_update): # type: ignore[no-untyped-def]
532533
# _do_update is called once for each model class in the inheritance hierarchy.
533534
# We can only filter the base_qs on state fields (can be more than one!) present in this particular model.
534535

@@ -572,21 +573,21 @@ def save(self, *args: Any, **kwargs: Any) -> None:
572573

573574
def transition(
574575
field: FSMFieldMixin,
575-
source: str | int | Sequence[str | int] | State = "*",
576+
source: str | int | Sequence[str | int] = "*",
576577
target: str | int | State | None = None,
577578
on_error: str | int | None = None,
578579
conditions: list[Callable[[Any], bool]] = [],
579580
permission: str | Callable[[models.Model, UserWithPermissions], bool] | None = None,
580581
custom: dict[str, _StrOrPromise] = {},
581-
) -> _ToDo:
582+
) -> Callable[[Any], Any]:
582583
"""
583584
Method decorator to mark allowed transitions.
584585
585586
Set target to None if current state needs to be validated and
586587
has not changed after the function call.
587588
"""
588589

589-
def inner_transition(func):
590+
def inner_transition(func: _ToDo) -> _ToDo:
590591
wrapper_installed, fsm_meta = True, getattr(func, "_django_fsm", None)
591592
if not fsm_meta:
592593
wrapper_installed = False
@@ -647,15 +648,15 @@ def has_transition_perm(bound_method: _ToDo, user: UserWithPermissions) -> bool:
647648

648649

649650
class State:
650-
def get_state(self, model, transition, result, args=[], kwargs={}):
651+
def get_state(self, model: _Model, transition: Transition, result: Any, args: Any = [], kwargs: Any = {}) -> _ToDo:
651652
raise NotImplementedError
652653

653654

654655
class RETURN_VALUE(State):
655656
def __init__(self, *allowed_states: Sequence[str | int]) -> None:
656657
self.allowed_states = allowed_states if allowed_states else None
657658

658-
def get_state(self, model, transition, result, args=[], kwargs={}):
659+
def get_state(self, model: _Model, transition: Transition, result: Any, args: Any = [], kwargs: Any = {}) -> _ToDo:
659660
if self.allowed_states is not None:
660661
if result not in self.allowed_states:
661662
raise InvalidResultState(f"{result} is not in list of allowed states\n{self.allowed_states}")
@@ -667,7 +668,9 @@ def __init__(self, func: Callable[..., str | int], states: Sequence[str | int] |
667668
self.func = func
668669
self.allowed_states = states
669670

670-
def get_state(self, model, transition, result, args=[], kwargs={}):
671+
def get_state(
672+
self, model: _Model, transition: Transition, result: _StateValue | Any, args: Any = [], kwargs: Any = {}
673+
) -> _ToDo:
671674
result_state = self.func(model, *args, **kwargs)
672675
if self.allowed_states is not None:
673676
if result_state not in self.allowed_states:

0 commit comments

Comments
 (0)