Skip to content

Commit 43d61fa

Browse files
committed
Step 3
1 parent 8a98c56 commit 43d61fa

File tree

2 files changed

+55
-35
lines changed

2 files changed

+55
-35
lines changed

.mypy.ini

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ check_untyped_defs = True
1515
# get passing if you use a lot of untyped libraries
1616
disallow_subclassing_any = True
1717
disallow_untyped_decorators = True
18-
; disallow_any_generics = True
18+
disallow_any_generics = True
1919

2020
# These next few are various gradations of forcing use of type annotations
21-
; disallow_untyped_calls = True
21+
disallow_untyped_calls = True
2222
; disallow_incomplete_defs = True
2323
; disallow_untyped_defs = True
2424

django_fsm/__init__.py

Lines changed: 53 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,30 @@
3535

3636
if TYPE_CHECKING:
3737
from collections.abc import Callable
38+
from collections.abc import Generator
3839
from collections.abc import Sequence
3940
from typing import Any
4041

41-
from django.contrib.auth.models import AbstractBaseUser
42+
from django.contrib.auth.models import PermissionsMixin as UserWithPermissions
4243
from django.utils.functional import _StrOrPromise
4344

4445
_Model = models.Model
46+
_Field = models.Field[Any, Any]
47+
CharField = models.CharField[str, str]
48+
IntegerField = models.IntegerField[int, int]
49+
ForeignKey = models.ForeignKey[Any, Any]
4550
else:
4651
_Model = object
52+
_Field = object
53+
CharField = models.CharField
54+
IntegerField = models.IntegerField
55+
ForeignKey = models.ForeignKey
4756

4857

4958
class TransitionNotAllowed(Exception):
5059
"""Raised when a transition is not allowed"""
5160

52-
def __init__(self, *args, **kwargs) -> None:
61+
def __init__(self, *args: Any, **kwargs: Any) -> None:
5362
self.object = kwargs.pop("object", None)
5463
self.method = kwargs.pop("method", None)
5564
super().__init__(*args, **kwargs)
@@ -70,12 +79,12 @@ class ConcurrentTransition(Exception):
7079
class Transition:
7180
def __init__(
7281
self,
73-
method: Callable,
82+
method: Callable[..., Any],
7483
source: str | int | Sequence[str | int] | State,
7584
target: str | int | State | None,
7685
on_error: str | int | None,
7786
conditions: list[Callable[[Any], bool]],
78-
permission: str | Callable[[models.Model, AbstractBaseUser], bool] | None,
87+
permission: str | Callable[[models.Model, UserWithPermissions], bool] | None,
7988
custom: dict[str, _StrOrPromise],
8089
) -> None:
8190
self.method = method
@@ -90,7 +99,7 @@ def __init__(
9099
def name(self) -> str:
91100
return self.method.__name__
92101

93-
def has_perm(self, instance, user) -> bool:
102+
def has_perm(self, instance, user: UserWithPermissions) -> bool:
94103
if not self.permission:
95104
return True
96105
if callable(self.permission):
@@ -113,7 +122,7 @@ def __eq__(self, other):
113122
return False
114123

115124

116-
def get_available_FIELD_transitions(instance, field):
125+
def get_available_FIELD_transitions(instance, field: FSMFieldMixin) -> Generator[Transition, None, None]:
117126
"""
118127
List of transitions available in current model state
119128
with all conditions met
@@ -127,14 +136,16 @@ def get_available_FIELD_transitions(instance, field):
127136
yield meta.get_transition(curr_state)
128137

129138

130-
def get_all_FIELD_transitions(instance, field):
139+
def get_all_FIELD_transitions(instance, field: FSMFieldMixin) -> Generator[Transition, None, None]:
131140
"""
132141
List of all transitions available in current model state
133142
"""
134143
return field.get_all_transitions(instance.__class__)
135144

136145

137-
def get_available_user_FIELD_transitions(instance, user, field):
146+
def get_available_user_FIELD_transitions(
147+
instance, user: UserWithPermissions, field: FSMFieldMixin
148+
) -> Generator[Transition, None, None]:
138149
"""
139150
List of transitions available in current model state
140151
with all conditions met and user have rights on it
@@ -153,15 +164,24 @@ def __init__(self, field, method) -> None:
153164
self.field = field
154165
self.transitions: dict[str, Any] = {} # source -> Transition
155166

156-
def get_transition(self, source):
167+
def get_transition(self, source: str):
157168
transition = self.transitions.get(source, None)
158169
if transition is None:
159170
transition = self.transitions.get("*", None)
160171
if transition is None:
161172
transition = self.transitions.get("+", None)
162173
return transition
163174

164-
def add_transition(self, method, source, target, on_error=None, conditions=[], permission=None, custom={}) -> None:
175+
def add_transition(
176+
self,
177+
method: Callable[..., Any],
178+
source: str,
179+
target: str | int,
180+
on_error: str | int | None = None,
181+
conditions: list[Callable[[Any], bool]] = [],
182+
permission: str | Callable[[models.Model, UserWithPermissions], bool] | None = None,
183+
custom: dict[str, _StrOrPromise] = {},
184+
) -> None:
165185
if source in self.transitions:
166186
raise AssertionError(f"Duplicate transition for {source} state")
167187

@@ -204,7 +224,7 @@ def conditions_met(self, instance, state) -> bool:
204224

205225
return all(condition(instance) for condition in transition.conditions)
206226

207-
def has_transition_perm(self, instance, state, user) -> bool:
227+
def has_transition_perm(self, instance, state, user: UserWithPermissions) -> bool:
208228
transition = self.get_transition(state)
209229

210230
if not transition:
@@ -247,10 +267,10 @@ def __set__(self, instance, value) -> None:
247267
self.field.set_state(instance, value)
248268

249269

250-
class FSMFieldMixin(Field):
270+
class FSMFieldMixin(_Field):
251271
descriptor_class = FSMFieldDescriptor
252272

253-
def __init__(self, *args, **kwargs) -> None:
273+
def __init__(self, *args: Any, **kwargs: Any) -> None:
254274
self.protected = kwargs.pop("protected", False)
255275
self.transitions: dict[Any, dict[str, Any]] = {} # cls -> (transitions name -> method)
256276
self.state_proxy = {} # state -> ProxyClsRef
@@ -275,15 +295,15 @@ def deconstruct(self):
275295
kwargs["protected"] = self.protected
276296
return name, path, args, kwargs
277297

278-
def get_state(self, instance):
298+
def get_state(self, instance) -> Any:
279299
# The state field may be deferred. We delegate the logic of figuring this out
280300
# and loading the deferred field on-demand to Django's built-in DeferredAttribute class.
281301
return DeferredAttribute(self).__get__(instance) # type: ignore[attr-defined]
282302

283-
def set_state(self, instance, state):
303+
def set_state(self, instance, state: str) -> None:
284304
instance.__dict__[self.name] = state
285305

286-
def set_proxy(self, instance, state):
306+
def set_proxy(self, instance, state: str) -> None:
287307
"""
288308
Change class
289309
"""
@@ -304,7 +324,7 @@ def set_proxy(self, instance, state):
304324

305325
instance.__class__ = model
306326

307-
def change_state(self, instance, method, *args, **kwargs):
327+
def change_state(self, instance, method, *args: Any, **kwargs: Any):
308328
meta = method._django_fsm
309329
method_name = method.__name__
310330
current_state = self.get_state(instance)
@@ -357,7 +377,7 @@ def change_state(self, instance, method, *args, **kwargs):
357377

358378
return result
359379

360-
def get_all_transitions(self, instance_cls):
380+
def get_all_transitions(self, instance_cls) -> Generator[Transition, None, None]:
361381
"""
362382
Returns [(source, target, name, method)] for all field transitions
363383
"""
@@ -384,7 +404,7 @@ def contribute_to_class(self, cls, name, private_only=False, **kwargs):
384404

385405
class_prepared.connect(self._collect_transitions)
386406

387-
def _collect_transitions(self, *args, **kwargs):
407+
def _collect_transitions(self, *args: Any, **kwargs: Any):
388408
sender = kwargs["sender"]
389409

390410
if not issubclass(sender, self.base_cls):
@@ -413,25 +433,25 @@ def is_field_transition_method(attr):
413433
self.transitions[sender] = sender_transitions
414434

415435

416-
class FSMField(FSMFieldMixin, models.CharField):
436+
class FSMField(FSMFieldMixin, CharField):
417437
"""
418438
State Machine support for Django model as CharField
419439
"""
420440

421-
def __init__(self, *args, **kwargs) -> None:
441+
def __init__(self, *args: Any, **kwargs: Any) -> None:
422442
kwargs.setdefault("max_length", 50)
423443
super().__init__(*args, **kwargs)
424444

425445

426-
class FSMIntegerField(FSMFieldMixin, models.IntegerField):
446+
class FSMIntegerField(FSMFieldMixin, IntegerField):
427447
"""
428448
Same as FSMField, but stores the state value in an IntegerField.
429449
"""
430450

431451
pass
432452

433453

434-
class FSMKeyField(FSMFieldMixin, models.ForeignKey):
454+
class FSMKeyField(FSMFieldMixin, ForeignKey):
435455
"""
436456
State Machine support for Django model
437457
"""
@@ -496,7 +516,7 @@ class ConcurrentTransitionMixin(_Model):
496516
state, thus practically negating their effect.
497517
"""
498518

499-
def __init__(self, *args, **kwargs) -> None:
519+
def __init__(self, *args: Any, **kwargs: Any) -> None:
500520
super().__init__(*args, **kwargs)
501521
self._update_initial_state()
502522

@@ -534,14 +554,14 @@ def _do_update(self, base_qs, using, pk_val, values, update_fields, forced_updat
534554

535555
return updated
536556

537-
def _update_initial_state(self):
557+
def _update_initial_state(self) -> None:
538558
self.__initial_states = {field.attname: field.value_from_object(self) for field in self.state_fields}
539559

540-
def refresh_from_db(self, *args, **kwargs):
560+
def refresh_from_db(self, *args: Any, **kwargs: Any) -> None:
541561
super().refresh_from_db(*args, **kwargs)
542562
self._update_initial_state()
543563

544-
def save(self, *args, **kwargs):
564+
def save(self, *args: Any, **kwargs: Any) -> None:
545565
super().save(*args, **kwargs)
546566
self._update_initial_state()
547567

@@ -552,7 +572,7 @@ def transition(
552572
target: str | int | State | None = None,
553573
on_error: str | int | None = None,
554574
conditions: list[Callable[[Any], bool]] = [],
555-
permission: str | Callable[[models.Model, AbstractBaseUser], bool] | None = None,
575+
permission: str | Callable[[models.Model, UserWithPermissions], bool] | None = None,
556576
custom: dict[str, _StrOrPromise] = {},
557577
):
558578
"""
@@ -576,7 +596,7 @@ def inner_transition(func):
576596
func._django_fsm.add_transition(func, source, target, on_error, conditions, permission, custom)
577597

578598
@wraps(func)
579-
def _change_state(instance, *args, **kwargs):
599+
def _change_state(instance, *args: Any, **kwargs: Any):
580600
return fsm_meta.field.change_state(instance, func, *args, **kwargs)
581601

582602
if not wrapper_installed:
@@ -587,7 +607,7 @@ def _change_state(instance, *args, **kwargs):
587607
return inner_transition
588608

589609

590-
def can_proceed(bound_method, check_conditions=True) -> bool:
610+
def can_proceed(bound_method, check_conditions: bool = True) -> bool:
591611
"""
592612
Returns True if model in state allows to call bound_method
593613
@@ -604,7 +624,7 @@ def can_proceed(bound_method, check_conditions=True) -> bool:
604624
return meta.has_transition(current_state) and (not check_conditions or meta.conditions_met(self, current_state))
605625

606626

607-
def has_transition_perm(bound_method, user) -> bool:
627+
def has_transition_perm(bound_method, user: UserWithPermissions) -> bool:
608628
"""
609629
Returns True if model in state allows to call bound_method and user have rights on it
610630
"""
@@ -628,7 +648,7 @@ def get_state(self, model, transition, result, args=[], kwargs={}):
628648

629649

630650
class RETURN_VALUE(State):
631-
def __init__(self, *allowed_states) -> None:
651+
def __init__(self, *allowed_states: Sequence[str | int]) -> None:
632652
self.allowed_states = allowed_states if allowed_states else None
633653

634654
def get_state(self, model, transition, result, args=[], kwargs={}):
@@ -639,7 +659,7 @@ def get_state(self, model, transition, result, args=[], kwargs={}):
639659

640660

641661
class GET_STATE(State):
642-
def __init__(self, func, states=None) -> None:
662+
def __init__(self, func: Callable[..., str | int], states: Sequence[str | int] | None = None) -> None:
643663
self.func = func
644664
self.allowed_states = states
645665

0 commit comments

Comments
 (0)