Skip to content

Commit 27622bc

Browse files
committed
Merge tag 'fix-concurrency' into develop
Fix shared state
2 parents 4cd9926 + 43c711d commit 27622bc

5 files changed

Lines changed: 131 additions & 17 deletions

File tree

statemachine/state.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Any, Optional # noqa: F401, I001
2+
from copy import deepcopy
23

34
from .callbacks import Callbacks
45
from .exceptions import StateMachineError
@@ -9,7 +10,7 @@
910

1011
class State:
1112
"""
12-
A State in a state machine describes a particular behaviour of the machine.
13+
A State in a state machine describes a particular behavior of the machine.
1314
When we say that a machine is “in” a state, it means that the machine behaves
1415
in the way that state describes.
1516
@@ -80,15 +81,27 @@ def __init__(
8081
self.name = name
8182
self.value = value
8283
self._id = None # type: Optional[str]
84+
self._storage = ""
8385
self._initial = initial
8486
self.transitions = TransitionList()
8587
self._final = final
8688
self.enter = Callbacks().add(enter)
8789
self.exit = Callbacks().add(exit)
8890

89-
def _setup(self, resolver):
91+
def __eq__(self, other):
92+
return (
93+
isinstance(other, State) and self.name == other.name and self.id == other.id
94+
)
95+
96+
def __hash__(self):
97+
return hash(repr(self))
98+
99+
def _setup(self, machine, resolver):
100+
self.machine = machine
90101
self.enter.setup(resolver)
91102
self.exit.setup(resolver)
103+
machine.__dict__[self._storage] = self
104+
return self
92105

93106
def _add_observer(self, *resolvers):
94107
for r in resolvers:
@@ -108,7 +121,8 @@ def __repr__(self):
108121
)
109122

110123
def __get__(self, machine, owner):
111-
self.machine = machine
124+
if machine and self._storage in machine.__dict__:
125+
return machine.__dict__[self._storage]
112126
return self
113127

114128
def __set__(self, instance, value):
@@ -118,12 +132,16 @@ def __set__(self, instance, value):
118132
)
119133
)
120134

135+
def clone(self):
136+
return deepcopy(self)
137+
121138
@property
122139
def id(self):
123140
return self._id
124141

125142
def _set_id(self, id):
126143
self._id = id
144+
self._storage = f"_{id}"
127145
if self.value is None:
128146
self.value = id
129147

statemachine/statemachine.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ def __init__(self, model=None, state_field="state", start_value=None):
2828
self.state_field = state_field
2929
self.start_value = start_value
3030

31-
initial_transition = Transition(None, None, event="__initial__")
31+
initial_transition = Transition(
32+
None, self._get_initial_state(), event="__initial__"
33+
)
3234
self._setup(initial_transition)
3335
self._activate_initial_state(initial_transition)
3436

@@ -39,21 +41,19 @@ def __repr__(self):
3941
f"current_state={current_state_id!r})"
4042
)
4143

42-
def _activate_initial_state(self, initial_transition):
43-
44+
def _get_initial_state(self):
4445
current_state_value = (
4546
self.start_value if self.start_value else self.initial_state.value
4647
)
47-
if self.current_state_value is None:
48-
49-
try:
50-
initial_state = self.states_map[current_state_value]
51-
except KeyError as err:
52-
raise InvalidStateValue(current_state_value) from err
48+
try:
49+
return self.states_map[current_state_value]
50+
except KeyError as err:
51+
raise InvalidStateValue(current_state_value) from err
5352

53+
def _activate_initial_state(self, initial_transition):
54+
if self.current_state_value is None:
5455
# send an one-time event `__initial__` to enter the current state.
5556
# current_state = self.current_state
56-
initial_transition.target = initial_state
5757
initial_transition.before.clear()
5858
initial_transition.on.clear()
5959
initial_transition.after.clear()
@@ -93,8 +93,18 @@ def _setup(self, initial_transition):
9393
model = ObjectConfig(self.model, skip_attrs={self.state_field})
9494
default_resolver = resolver_factory(machine, model)
9595

96-
initial_transition._setup(default_resolver)
97-
self._visit_states_and_transitions(lambda x: x._setup(default_resolver))
96+
# clone states and transitions to avoid sharing callbacks references between instances
97+
self.states_map = {
98+
state.value: state.clone()._setup(self, default_resolver)
99+
for state in self.states
100+
}
101+
self.states = list(self.states_map.values())
102+
103+
for state in self.states:
104+
for transition in state.transitions:
105+
transition._setup(self, default_resolver)
106+
107+
initial_transition._setup(self, default_resolver)
98108
self.add_observer(machine, model)
99109

100110
def add_observer(self, *observers):

statemachine/transition.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,13 @@ def __repr__(self):
7070
f"internal={self.internal!r})"
7171
)
7272

73-
def _setup(self, resolver):
73+
def _upd_state_refs(self, machine):
74+
if self.source:
75+
self.source = machine.__dict__[self.source._storage]
76+
self.target = machine.__dict__[self.target._storage]
77+
78+
def _setup(self, machine, resolver):
79+
self._upd_state_refs(machine)
7480
self.validators.setup(resolver)
7581
self.cond.setup(resolver)
7682
self.before.setup(resolver)

tests/examples/order_control_rich_model_machine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class OrderControl(StateMachine):
6262
control = OrderControl()
6363
except AttrNotFound as e:
6464
assert ( # noqa: PT017
65-
str(e) == "Did not found name 'payment_received' from model or statemachine"
65+
str(e) == "Did not found name 'wait_for_payment' from model or statemachine"
6666
)
6767

6868
# %%

tests/test_callbacks_isolation.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import pytest
2+
3+
from statemachine import State
4+
from statemachine import StateMachine
5+
6+
7+
@pytest.fixture()
8+
def simple_sm_cls():
9+
class TestStateMachine(StateMachine):
10+
# States
11+
initial = State("Initial", initial=True)
12+
final = State("Final", final=True, enter="do_enter_final")
13+
14+
finish = initial.to(final, cond="can_finish", on="do_finish")
15+
16+
def __init__(self, name):
17+
self.name = name
18+
self.can_finish = False
19+
self.finalized = False
20+
super(TestStateMachine, self).__init__()
21+
22+
def do_finish(self):
23+
return self.name, self.can_finish
24+
25+
def do_enter_final(self):
26+
self.finalized = True
27+
28+
return TestStateMachine
29+
30+
31+
class TestCallbacksIsolation:
32+
def test_should_conditions_be_isolated(self, simple_sm_cls):
33+
sm1 = simple_sm_cls("sm1")
34+
sm2 = simple_sm_cls("sm2")
35+
sm3 = simple_sm_cls("sm3")
36+
37+
sm1.can_finish = True
38+
assert sm1.initial.transitions[0].cond.call() == [True]
39+
assert sm2.initial.transitions[0].cond.call() == [False]
40+
assert sm3.initial.transitions[0].cond.call() == [False]
41+
42+
def test_should_actions_be_isolated(self, simple_sm_cls):
43+
sm1 = simple_sm_cls("sm1")
44+
sm2 = simple_sm_cls("sm2")
45+
46+
sm1.can_finish = True
47+
sm2.can_finish = True
48+
49+
sm1_initial = sm1.initial
50+
sm1_final = sm1.final
51+
52+
assert sm2.finish() == ("sm2", True)
53+
54+
assert not sm2.initial.is_active
55+
assert sm2.final.is_active
56+
assert sm2.finalized is True
57+
58+
assert sm1_initial.is_active
59+
assert not sm1_final.is_active
60+
assert sm1.finalized is False
61+
62+
assert sm1.initial.is_active
63+
assert not sm1.final.is_active
64+
65+
assert sm1.finish() == ("sm1", True)
66+
67+
assert sm1.finalized is True
68+
assert not sm1.initial.is_active
69+
assert sm1.final.is_active
70+
71+
def test_instance_states_and_transitions_are_isolated(self, simple_sm_cls):
72+
sm1 = simple_sm_cls("sm1")
73+
74+
assert sm1.initial == simple_sm_cls.initial
75+
assert sm1.initial is not simple_sm_cls.initial
76+
77+
assert repr(sm1.initial.transitions[0]) == repr(
78+
simple_sm_cls.initial.transitions[0]
79+
)
80+
assert sm1.initial.transitions[0] is not simple_sm_cls.initial.transitions[0]

0 commit comments

Comments
 (0)