49
49
IntegerField = models .IntegerField [int , int ]
50
50
ForeignKey = models .ForeignKey [Any , Any ]
51
51
52
+ _StateValue = str | int
52
53
_Instance = models .Model # TODO: use real type
53
54
_ToDo = Any # TODO: use real type
54
55
else :
@@ -83,10 +84,10 @@ class ConcurrentTransition(Exception):
83
84
class Transition :
84
85
def __init__ (
85
86
self ,
86
- method : Callable [..., str | int | None ],
87
- source : str | int | Sequence [str | int ] | State ,
88
- target : str | int ,
89
- on_error : str | int | None ,
87
+ method : Callable [..., _StateValue | Any ],
88
+ source : _StateValue | Sequence [_StateValue ] | State ,
89
+ target : _StateValue ,
90
+ on_error : _StateValue | None ,
90
91
conditions : list [Callable [[_Instance ], bool ]],
91
92
permission : str | Callable [[_Instance , UserWithPermissions ], bool ] | None ,
92
93
custom : dict [str , _StrOrPromise ],
@@ -414,7 +415,7 @@ def _collect_transitions(self, *args: Any, **kwargs: Any) -> None:
414
415
if not issubclass (sender , self .base_cls ):
415
416
return
416
417
417
- def is_field_transition_method (attr ) :
418
+ def is_field_transition_method (attr : _ToDo ) -> bool :
418
419
return (
419
420
(inspect .ismethod (attr ) or inspect .isfunction (attr ))
420
421
and hasattr (attr , "_django_fsm" )
@@ -528,7 +529,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
528
529
def state_fields (self ) -> Iterable [Any ]:
529
530
return filter (lambda field : isinstance (field , FSMFieldMixin ), self ._meta .fields )
530
531
531
- def _do_update (self , base_qs , using , pk_val , values , update_fields , forced_update ):
532
+ def _do_update (self , base_qs , using , pk_val , values , update_fields , forced_update ): # type: ignore[no-untyped-def]
532
533
# _do_update is called once for each model class in the inheritance hierarchy.
533
534
# We can only filter the base_qs on state fields (can be more than one!) present in this particular model.
534
535
@@ -572,21 +573,21 @@ def save(self, *args: Any, **kwargs: Any) -> None:
572
573
573
574
def transition (
574
575
field : FSMFieldMixin ,
575
- source : str | int | Sequence [str | int ] | State = "*" ,
576
+ source : str | int | Sequence [str | int ] = "*" ,
576
577
target : str | int | State | None = None ,
577
578
on_error : str | int | None = None ,
578
579
conditions : list [Callable [[Any ], bool ]] = [],
579
580
permission : str | Callable [[models .Model , UserWithPermissions ], bool ] | None = None ,
580
581
custom : dict [str , _StrOrPromise ] = {},
581
- ) -> _ToDo :
582
+ ) -> Callable [[ Any ], Any ] :
582
583
"""
583
584
Method decorator to mark allowed transitions.
584
585
585
586
Set target to None if current state needs to be validated and
586
587
has not changed after the function call.
587
588
"""
588
589
589
- def inner_transition (func ) :
590
+ def inner_transition (func : _ToDo ) -> _ToDo :
590
591
wrapper_installed , fsm_meta = True , getattr (func , "_django_fsm" , None )
591
592
if not fsm_meta :
592
593
wrapper_installed = False
@@ -647,15 +648,15 @@ def has_transition_perm(bound_method: _ToDo, user: UserWithPermissions) -> bool:
647
648
648
649
649
650
class State :
650
- def get_state (self , model , transition , result , args = [], kwargs = {}):
651
+ def get_state (self , model : _Model , transition : Transition , result : Any , args : Any = [], kwargs : Any = {}) -> _ToDo :
651
652
raise NotImplementedError
652
653
653
654
654
655
class RETURN_VALUE (State ):
655
656
def __init__ (self , * allowed_states : Sequence [str | int ]) -> None :
656
657
self .allowed_states = allowed_states if allowed_states else None
657
658
658
- def get_state (self , model , transition , result , args = [], kwargs = {}):
659
+ def get_state (self , model : _Model , transition : Transition , result : Any , args : Any = [], kwargs : Any = {}) -> _ToDo :
659
660
if self .allowed_states is not None :
660
661
if result not in self .allowed_states :
661
662
raise InvalidResultState (f"{ result } is not in list of allowed states\n { self .allowed_states } " )
@@ -667,7 +668,9 @@ def __init__(self, func: Callable[..., str | int], states: Sequence[str | int] |
667
668
self .func = func
668
669
self .allowed_states = states
669
670
670
- def get_state (self , model , transition , result , args = [], kwargs = {}):
671
+ def get_state (
672
+ self , model : _Model , transition : Transition , result : _StateValue | Any , args : Any = [], kwargs : Any = {}
673
+ ) -> _ToDo :
671
674
result_state = self .func (model , * args , ** kwargs )
672
675
if self .allowed_states is not None :
673
676
if result_state not in self .allowed_states :
0 commit comments