Skip to content

Commit 11c6689

Browse files
authored
feat: Allow using events on callbacks (#355)
1 parent 9a64f4b commit 11c6689

9 files changed

Lines changed: 192 additions & 72 deletions

File tree

docs/actions.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,12 @@ Use the `enter` or `exit` params available on the `State` constructor.
133133

134134
```
135135

136+
```{hint}
137+
It's also possible to use an event name as action.
138+
139+
**Be careful to not introduce recursion errors** that will raise `RecursionError` exception.
140+
```
141+
136142
### Bind state actions using decorator syntax
137143

138144

@@ -213,6 +219,12 @@ model, using the patterns:
213219

214220
```
215221

222+
```{hint}
223+
It's also possible to use an event name as action to chain transitions.
224+
225+
**Be careful to not introduce recursion errors**, like `loop = initial.to.itself(after="loop")`, that will raise `RecursionError` exception.
226+
```
227+
216228
### Bind event actions using decorator syntax
217229

218230
The action will be registered for every {ref}`transition` associated with the event.

docs/releases/2.0.0.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ See {ref}`internal transition` for more details.
7575
guards using decorators is now possible.
7676
- [#331](https://github.com/fgmacedo/python-statemachine/pull/331): Added a way to generate diagrams using [QuickChart.io](https://quickchart.io) instead of GraphViz. See {ref}`diagrams` for more details.
7777
- [#353](https://github.com/fgmacedo/python-statemachine/pull/353): Support for abstract state machine classes, so you can subclass `StateMachine` to add behavior on your own base class. Abstract `StateMachine` cannot be instantiated.
78+
- [#355](https://github.com/fgmacedo/python-statemachine/pull/355): Now is possible to trigger an event as an action by registering the event name as the callback param.
7879

7980
## Bugfixes in 2.0
8081

statemachine/dispatcher.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def _get_func_by_attr(attr, *configs):
3737
return func, config.obj
3838

3939

40-
def ensure_callable(attr, *objects):
40+
def ensure_callable(attr, *objects): # noqa: C901
4141
"""Ensure that `attr` is a callable, if not, tries to retrieve one from any of the given
4242
`objects`.
4343
@@ -66,6 +66,15 @@ def wrapper(*args, **kwargs):
6666

6767
return wrapper
6868

69+
if getattr(func, "_is_sm_event", False):
70+
"Events already have the 'machine' parameter defined."
71+
72+
def wrapper(*args, **kwargs):
73+
kwargs.pop("machine")
74+
return func(*args, **kwargs)
75+
76+
return wrapper
77+
6978
return SignatureAdapter.wrap(func)
7079

7180

statemachine/event.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .event_data import EventData
2+
from .event_data import TriggerData
23
from .exceptions import TransitionNotAllowed
34

45

@@ -13,38 +14,36 @@ def __call__(self, machine, *args, **kwargs):
1314
return self.trigger(machine, *args, **kwargs)
1415

1516
def trigger(self, machine, *args, **kwargs):
16-
event_data = EventData(machine, self.name, *args, **kwargs)
17-
1817
def trigger_wrapper():
1918
"""Wrapper that captures event_data as closure."""
20-
return self._trigger(event_data)
19+
trigger_data = TriggerData(
20+
machine=machine,
21+
event=self.name,
22+
args=args,
23+
kwargs=kwargs,
24+
)
25+
return self._trigger(trigger_data)
2126

2227
return machine._process(trigger_wrapper)
2328

24-
def _trigger(self, event_data):
25-
event_data.source = event_data.machine.current_state
26-
event_data.state = event_data.machine.current_state
27-
event_data.model = event_data.machine.model
28-
29-
try:
30-
self._process(event_data)
31-
except Exception as error:
32-
event_data.error = error
33-
# TODO: Log errors
34-
# TODO: Allow exception handlers
35-
raise
29+
def _trigger(self, trigger_data: TriggerData):
30+
event_data = self._process(trigger_data)
3631
return event_data.result
3732

38-
def _process(self, event_data):
39-
for transition in event_data.source.transitions:
40-
if not transition.match(event_data.event):
33+
def _process(self, trigger_data: TriggerData):
34+
state = trigger_data.machine.current_state
35+
for transition in state.transitions:
36+
if not transition.match(trigger_data.event):
4137
continue
42-
event_data._set_transition(transition)
38+
39+
event_data = EventData(trigger_data=trigger_data, transition=transition)
4340
if transition.execute(event_data):
4441
event_data.executed = True
4542
break
4643
else:
47-
raise TransitionNotAllowed(event_data.event, event_data.state)
44+
raise TransitionNotAllowed(trigger_data.event, state)
45+
46+
return event_data
4847

4948

5049
def trigger_event_factory(event):
@@ -56,5 +55,6 @@ def trigger_event(self, *args, **kwargs):
5655

5756
trigger_event.name = event
5857
trigger_event.identifier = event
58+
trigger_event._is_sm_event = True
5959

6060
return trigger_event

statemachine/event_data.py

Lines changed: 62 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,77 @@
1+
from dataclasses import dataclass
2+
from dataclasses import field
13
from typing import TYPE_CHECKING
4+
from typing import Any
25

36
if TYPE_CHECKING:
7+
from .state import State
48
from .statemachine import StateMachine
59
from .transition import Transition
610

711

12+
@dataclass
13+
class TriggerData:
14+
machine: "StateMachine"
15+
event: str
16+
"""The Event that was triggered."""
17+
18+
model: Any = field(init=False)
19+
"""A reference to the underlying model that holds the current State."""
20+
21+
args: tuple = field(default_factory=tuple)
22+
"""All positional arguments provided on the Event."""
23+
24+
kwargs: dict = field(default_factory=dict)
25+
"""All keyword arguments provided on the Event."""
26+
27+
def __post_init__(self):
28+
self.model = self.machine.model
29+
30+
31+
@dataclass
832
class EventData:
9-
def __init__(self, machine: "StateMachine", event: str, *args, **kwargs):
10-
self.machine = machine
11-
self.event = event
12-
self.source = kwargs.get("source", None)
13-
self.state = kwargs.get("state", None)
14-
self.model = kwargs.get("model", None)
15-
self.executed = False
16-
self.transition: Transition | None = None
17-
self.target = None
18-
self._set_transition(kwargs.get("transition", None))
19-
20-
# runtime and error
21-
self.args = args
22-
self.kwargs = kwargs
23-
self.error = None
24-
self.result = None
25-
26-
def __repr__(self):
27-
return f"{type(self).__name__}({self.__dict__!r})"
28-
29-
def _set_transition(self, transition: "Transition"):
30-
self.transition = transition
31-
self.target = getattr(transition, "target", None)
33+
trigger_data: TriggerData
34+
transition: "Transition"
35+
"""The Transition instance that was activated by the Event."""
36+
37+
state: "State" = field(init=False)
38+
"""The current State of the state machine."""
39+
40+
source: "State" = field(init=False)
41+
"""The State the state machine was in when the Event started."""
42+
43+
target: "State" = field(init=False)
44+
"""The destination State of the transition."""
45+
46+
result: "Any | None" = None
47+
executed: bool = False
48+
49+
def __post_init__(self):
50+
self.state = self.transition.source
51+
self.source = self.transition.source
52+
self.target = self.transition.target
53+
54+
@property
55+
def machine(self):
56+
return self.trigger_data.machine
57+
58+
@property
59+
def event(self):
60+
return self.trigger_data.event
61+
62+
@property
63+
def args(self):
64+
return self.trigger_data.args
3265

3366
@property
3467
def extended_kwargs(self):
35-
kwargs = self.kwargs.copy()
68+
kwargs = self.trigger_data.kwargs.copy()
3669
kwargs["event_data"] = self
37-
kwargs["event"] = self.event
38-
kwargs["source"] = self.source
39-
kwargs["state"] = self.state
40-
kwargs["model"] = self.model
70+
kwargs["machine"] = self.trigger_data.machine
71+
kwargs["event"] = self.trigger_data.event
72+
kwargs["model"] = self.trigger_data.model
4173
kwargs["transition"] = self.transition
74+
kwargs["state"] = self.state
75+
kwargs["source"] = self.source
4276
kwargs["target"] = self.target
4377
return kwargs

statemachine/statemachine.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .dispatcher import ObjectConfig
44
from .dispatcher import resolver_factory
55
from .event import Event
6+
from .event_data import TriggerData
67
from .event_data import EventData
78
from .exceptions import InvalidStateValue
89
from .exceptions import InvalidDefinition
@@ -62,30 +63,29 @@ def _activate_initial_state(self, initial_transition):
6263
initial_transition.before.clear()
6364
initial_transition.on.clear()
6465
initial_transition.after.clear()
66+
6567
event_data = EventData(
66-
self,
67-
initial_transition.event,
68+
trigger_data=TriggerData(
69+
machine=self,
70+
event=initial_transition.event,
71+
),
6872
transition=initial_transition,
6973
)
7074
self._activate(event_data)
7175

7276
def _get_protected_attrs(self):
73-
return (
74-
{
75-
"_abstract",
76-
"model",
77-
"state_field",
78-
"start_value",
79-
"initial_state",
80-
"final_states",
81-
"states",
82-
"_events",
83-
"states_map",
84-
"send",
85-
}
86-
| {s.id for s in self.states}
87-
| set(self._events.keys())
88-
)
77+
return {
78+
"_abstract",
79+
"model",
80+
"state_field",
81+
"start_value",
82+
"initial_state",
83+
"final_states",
84+
"states",
85+
"_events",
86+
"states_map",
87+
"send",
88+
} | {s.id for s in self.states}
8989

9090
def _visit_states_and_transitions(self, visitor):
9191
for state in self.states:
@@ -165,7 +165,6 @@ def _process(self, trigger):
165165

166166
def _activate(self, event_data: EventData):
167167
transition = event_data.transition
168-
assert transition is not None
169168
source = event_data.state
170169
target = transition.target
171170

statemachine/transition.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from functools import partial
2+
from typing import TYPE_CHECKING
23

34
from .callbacks import Callbacks
45
from .callbacks import ConditionWrapper
5-
from .event_data import EventData
66
from .events import Events
77
from .exceptions import InvalidDefinition
88

9+
if TYPE_CHECKING:
10+
from .event_data import EventData
11+
912

1013
class Transition:
1114
"""A transition holds reference to the source and target state.
@@ -119,7 +122,7 @@ def events(self):
119122
def add_event(self, value):
120123
self._events.add(value)
121124

122-
def execute(self, event_data: EventData):
125+
def execute(self, event_data: "EventData"):
123126
self.validators.call(*event_data.args, **event_data.extended_kwargs)
124127
if not self._eval_cond(event_data):
125128
return False

tests/examples/order_control_rich_model_machine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def __init__(self):
1919
def payments_enough(self, amount):
2020
return sum(self.payments) + amount >= self.order_total
2121

22-
def add_to_order(self, amount):
22+
def before_add_to_order(self, amount):
2323
self.order_total += amount
2424
return self.order_total
2525

@@ -40,7 +40,7 @@ class OrderControl(StateMachine):
4040
shipping = State()
4141
completed = State(final=True)
4242

43-
add_to_order = waiting_for_payment.to(waiting_for_payment, before="add_to_order")
43+
add_to_order = waiting_for_payment.to(waiting_for_payment)
4444
receive_payment = waiting_for_payment.to(
4545
processing, cond="payments_enough"
4646
) | waiting_for_payment.to(waiting_for_payment, unless="payments_enough")

0 commit comments

Comments
 (0)