Skip to content

Commit fb64588

Browse files
committed
add state machine
state_machine test state_machine_update
1 parent 6477929 commit fb64588

File tree

1 file changed

+328
-0
lines changed

1 file changed

+328
-0
lines changed
Lines changed: 328 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,328 @@
1+
from collections.abc import Callable
2+
3+
4+
class State:
5+
def __init__(self, name, on_enter=None, on_exit=None):
6+
self.name = name
7+
self.on_enter = on_enter
8+
self.on_exit = on_exit
9+
10+
def enter(self):
11+
print(f"entering <{self.name}>")
12+
if self.on_enter:
13+
self.on_enter()
14+
15+
def exit(self):
16+
print(f"exiting <{self.name}>")
17+
if self.on_exit:
18+
self.on_exit()
19+
20+
21+
class StateMachine(State):
22+
def __init__(self, model: object, name: str, on_enter=None, on_exit=None):
23+
State.__init__(self, name, on_enter, on_exit)
24+
self.states = {}
25+
self.events = {}
26+
self.transition_table = {}
27+
self._model = model
28+
self._state: StateMachine = None
29+
30+
def add_transition(
31+
self,
32+
src_state: str | State,
33+
event: str,
34+
dst_state: str | State,
35+
guard: str | Callable = None,
36+
action: str | Callable = None,
37+
) -> None:
38+
"""Add a transition to the state machine.
39+
40+
Args:
41+
src_state: Source state name or State object
42+
event: Event name or Event object
43+
dst_state: Destination state name or State object
44+
guard: Guard function name or callable
45+
action: Action function name or callable
46+
"""
47+
# Convert string parameters to objects if necessary
48+
self.register_state(src_state)
49+
self.register_event(event)
50+
self.register_state(dst_state)
51+
52+
def get_state_obj(state):
53+
return state if isinstance(state, State) else self.get_state(state)
54+
55+
def get_callable(func):
56+
return func if callable(func) else getattr(self._model, func, None)
57+
58+
src_state_obj = get_state_obj(src_state)
59+
dst_state_obj = get_state_obj(dst_state)
60+
61+
guard_func = get_callable(guard) if guard else None
62+
action_func = get_callable(action) if action else None
63+
self.transition_table[(src_state_obj.name, event)] = (
64+
dst_state_obj,
65+
guard_func,
66+
action_func,
67+
)
68+
69+
def state_transition(self, src_state: State, event: str):
70+
if (src_state.name, event) not in self.transition_table:
71+
raise ValueError(
72+
f"|{self.name}| invalid transition: <{src_state.name}> : [{event}]"
73+
)
74+
75+
dst_state, guard, action = self.transition_table[(src_state.name, event)]
76+
77+
def call_guard(guard):
78+
if callable(guard):
79+
return guard()
80+
else:
81+
return True
82+
83+
def call_action(action):
84+
if callable(action):
85+
action()
86+
87+
if call_guard(guard):
88+
call_action(action)
89+
if src_state.name != dst_state.name:
90+
print(
91+
f"|{self.name}| transitioning from <{src_state.name}> to <{dst_state.name}>"
92+
)
93+
src_state.exit()
94+
self._state = dst_state
95+
dst_state.enter()
96+
else:
97+
print(
98+
f"|{self.name}| skipping transition from <{src_state.name}> to <{dst_state.name}> because guard failed"
99+
)
100+
101+
def register_state(self, state: str | State, on_enter=None, on_exit=None):
102+
"""Register a state in the state machine.
103+
104+
Args:
105+
state (str | State): The state to register. Can be either a string (state name)
106+
or a State object.
107+
on_enter (Callable, optional): Callback function to be executed when entering the state.
108+
If state is a string and on_enter is None, it will look for
109+
a method named 'on_enter_<state>' in the model.
110+
on_exit (Callable, optional): Callback function to be executed when exiting the state.
111+
If state is a string and on_exit is None, it will look for
112+
a method named 'on_exit_<state>' in the model.
113+
114+
Raises:
115+
ValueError: If a state with the same name is already registered with a different type.
116+
"""
117+
if isinstance(state, str):
118+
if on_enter is None:
119+
on_enter = getattr(self._model, "on_enter_" + state, None)
120+
if on_exit is None:
121+
on_exit = getattr(self._model, "on_exit_" + state, None)
122+
self.states[state] = State(state, on_enter, on_exit)
123+
return
124+
125+
name = state.name
126+
if name in self.states and type(self.states[name]) is not type(state):
127+
raise ValueError(
128+
f'State "{name}" {type(state).__name__} already registered as {type(self.states[name]).__name__}'
129+
)
130+
131+
self.states[name] = state
132+
133+
def register_event(self, event: str):
134+
self.events[event] = event
135+
136+
def get_state(self, name):
137+
return self.states[name]
138+
139+
def get_event(self, name):
140+
return self.events[name]
141+
142+
def has_event(self, event: str):
143+
return event in self.events
144+
145+
def set_current_state(self, state: State | str):
146+
if isinstance(state, str):
147+
self._state = self.get_state(state)
148+
else:
149+
self._state = state
150+
151+
def get_current_state(self):
152+
return self._state
153+
154+
def process(self, event: str) -> None:
155+
"""Process an event in the state machine.
156+
157+
Args:
158+
event: Event name or Event object
159+
"""
160+
if self._state is None:
161+
raise ValueError("State machine is not initialized")
162+
163+
if self.has_event(event):
164+
self.state_transition(self._state, event)
165+
else:
166+
raise ValueError(f"Invalid event: {event}")
167+
168+
169+
class EventBus:
170+
def __init__(self):
171+
self.subscribers = {}
172+
173+
def subscribe(self, event: str, callback: Callable):
174+
if event not in self.subscribers:
175+
self.subscribers[event] = []
176+
self.subscribers[event].append(callback)
177+
178+
def publish(self, event: str):
179+
if event in self.subscribers:
180+
for callback in self.subscribers[event]:
181+
callback()
182+
else:
183+
raise ValueError(f"Invalid event: {event}")
184+
185+
186+
class SlamModel:
187+
def __init__(self, event_bus: EventBus, mapping_success: bool = False):
188+
self.event_bus = event_bus
189+
self.mapping_success = mapping_success
190+
191+
def on_enter_localization(self):
192+
self.event_bus.publish("top_localization_ready_event")
193+
194+
def on_enter_mapping(self):
195+
self.mapping_success = True
196+
197+
def is_mapping_success(self):
198+
return self.mapping_success
199+
200+
201+
class PlanningModel:
202+
def __init__(self, event_bus: EventBus):
203+
self.event_bus = event_bus
204+
205+
def on_exit_working(self):
206+
self.event_bus.publish("top_stop_working_event")
207+
208+
209+
class TopModel:
210+
def __init__(self, event_bus: EventBus):
211+
self.event_bus = event_bus
212+
213+
def on_enter_pre_working(self):
214+
self.event_bus.publish("slam_start_localization_event")
215+
216+
def on_enter_working(self):
217+
self.event_bus.publish("planning_start_working_event")
218+
219+
def on_enter_mapping(self):
220+
self.event_bus.publish("planning_start_remote_control_control_event")
221+
self.event_bus.publish("slam_start_mapping_event")
222+
223+
def on_exit_mapping(self):
224+
self.event_bus.publish("planning_stop_remote_control_control_event")
225+
self.event_bus.publish("slam_stop_mapping_event")
226+
227+
228+
def main():
229+
event_bus = EventBus()
230+
231+
slam_model = SlamModel(event_bus)
232+
planning_model = PlanningModel(event_bus)
233+
top_model = TopModel(event_bus)
234+
235+
slam_machine = StateMachine(slam_model, "slam_machine")
236+
planning_machine = StateMachine(planning_model, "planning_machine")
237+
top_machine = StateMachine(top_model, "top_machine")
238+
239+
# fmt: off
240+
slam_machine.add_transition("idle", "start_localization_event", "localization", "is_mapping_success")
241+
slam_machine.add_transition("localization", "stop_localization_event", "idle")
242+
slam_machine.add_transition("idle", "start_mapping_event", "mapping")
243+
slam_machine.add_transition("mapping", "stop_mapping_event", "idle")
244+
245+
planning_machine.add_transition("idle", "start_working_event", "working")
246+
planning_machine.add_transition("idle", "stop_working_event", "idle")
247+
planning_machine.add_transition("working", "stop_working_event", "idle")
248+
planning_machine.add_transition("idle", "start_remote_control_control_event", "remote_control_control")
249+
planning_machine.add_transition("remote_control_control", "stop_remote_control_control_event", "idle")
250+
251+
top_machine.add_transition("idle", "start_working_event", "pre_working")
252+
top_machine.add_transition("pre_working", "localization_ready_event", "working")
253+
top_machine.add_transition("pre_working", "stop_working_event", "idle")
254+
top_machine.add_transition("working", "stop_working_event", "idle")
255+
top_machine.add_transition("idle", "start_mapping_event", "mapping")
256+
top_machine.add_transition("mapping", "stop_mapping_event", "idle")
257+
# fmt: on
258+
event_bus.subscribe(
259+
"slam_start_localization_event",
260+
lambda: slam_machine.process("start_localization_event"),
261+
)
262+
event_bus.subscribe(
263+
"slam_start_mapping_event",
264+
lambda: slam_machine.process("start_mapping_event"),
265+
)
266+
event_bus.subscribe(
267+
"slam_stop_mapping_event",
268+
lambda: slam_machine.process("stop_mapping_event"),
269+
)
270+
event_bus.subscribe(
271+
"planning_start_working_event",
272+
lambda: planning_machine.process("start_working_event"),
273+
)
274+
event_bus.subscribe(
275+
"planning_start_remote_control_control_event",
276+
lambda: planning_machine.process("start_remote_control_control_event"),
277+
)
278+
event_bus.subscribe(
279+
"planning_stop_remote_control_control_event",
280+
lambda: planning_machine.process("stop_remote_control_control_event"),
281+
)
282+
event_bus.subscribe(
283+
"top_localization_ready_event",
284+
lambda: top_machine.process("localization_ready_event"),
285+
)
286+
event_bus.subscribe(
287+
"top_mapping_ready_event",
288+
lambda: top_machine.process("mapping_ready_event"),
289+
)
290+
event_bus.subscribe(
291+
"top_stop_working_event",
292+
lambda: top_machine.process("stop_working_event"),
293+
)
294+
295+
def working_task():
296+
slam_machine.set_current_state("idle")
297+
planning_machine.set_current_state("idle")
298+
top_machine.set_current_state("idle")
299+
# User sends start working event
300+
top_machine.process("start_working_event")
301+
302+
# Planning Model finish the task, and send stop working event
303+
planning_machine.process("stop_working_event")
304+
305+
print("top_machine: ", top_machine.get_current_state().name)
306+
print("planning_machine: ", planning_machine.get_current_state().name)
307+
print("slam_machine: ", slam_machine.get_current_state().name)
308+
309+
working_task()
310+
311+
def mapping_task():
312+
slam_machine.set_current_state("idle")
313+
planning_machine.set_current_state("idle")
314+
top_machine.set_current_state("idle")
315+
# User sends start mapping event
316+
top_machine.process("start_mapping_event")
317+
# User sends stop mapping event
318+
top_machine.process("stop_mapping_event")
319+
320+
print("top_machine: ", top_machine.get_current_state().name)
321+
print("planning_machine: ", planning_machine.get_current_state().name)
322+
print("slam_machine: ", slam_machine.get_current_state().name)
323+
324+
mapping_task()
325+
326+
327+
if __name__ == "__main__":
328+
main()

0 commit comments

Comments
 (0)