Skip to content

Commit f76a160

Browse files
authored
fix: support States.from_enum() inside compound and parallel states (#606) (#607)
1 parent ec225e0 commit f76a160

File tree

3 files changed

+100
-1
lines changed

3 files changed

+100
-1
lines changed

statemachine/state.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,18 @@ def __new__( # type: ignore [misc]
7171
inherited_kwargs.update(getattr(base, "_factory_kwargs", {}))
7272
inherited_kwargs.update(kwargs)
7373

74+
# Lazy import to avoid circular dependency (states.py imports state.py)
75+
from .states import States
76+
7477
states = []
7578
history = []
7679
callbacks = {}
7780
for key, value in attrs.items():
78-
if isinstance(value, HistoryState):
81+
if isinstance(value, States):
82+
for state_id, state in value.items():
83+
state._set_id(state_id)
84+
states.append(state)
85+
elif isinstance(value, HistoryState):
7986
value._set_id(key)
8087
history.append(value)
8188
elif isinstance(value, State):

tests/test_statechart_compound.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@
77
Theme: Fellowship journey through Middle-earth.
88
"""
99

10+
from enum import Enum
11+
from enum import auto
12+
1013
import pytest
14+
from statemachine.states import States
1115

1216
from statemachine import State
1317
from statemachine import StateChart
@@ -213,3 +217,38 @@ class shire(State.Compound, name="The Shire"):
213217

214218
sm = NamedCompound()
215219
assert sm.shire.name == "The Shire"
220+
221+
async def test_from_enum_inside_compound(self, sm_runner):
222+
"""States.from_enum() works inside compound states (#606)."""
223+
224+
class OuterStates(Enum):
225+
FOO = auto()
226+
BAR = auto()
227+
228+
class InnerStates(Enum):
229+
FIZZ = auto()
230+
BUZZ = auto()
231+
232+
class SC(StateChart):
233+
baz = States.from_enum(OuterStates, initial=OuterStates.FOO, final=OuterStates.BAR)
234+
235+
class inner(State.Compound):
236+
qux = States.from_enum(
237+
InnerStates, initial=InnerStates.FIZZ, final=InnerStates.BUZZ
238+
)
239+
fizz_to_buzz = qux.FIZZ.to(qux.BUZZ)
240+
241+
baz_foo_to_inner = baz.FOO.to(inner)
242+
inner_to_baz_bar = inner.to(baz.BAR)
243+
244+
sm = await sm_runner.start(SC)
245+
assert {OuterStates.FOO} == set(sm.configuration_values)
246+
247+
await sm_runner.send(sm, "baz_foo_to_inner")
248+
assert {"inner", InnerStates.FIZZ} == set(sm.configuration_values)
249+
250+
await sm_runner.send(sm, "fizz_to_buzz")
251+
assert {"inner", InnerStates.BUZZ} == set(sm.configuration_values)
252+
253+
await sm_runner.send(sm, "inner_to_baz_bar")
254+
assert {OuterStates.BAR} == set(sm.configuration_values)

tests/test_statechart_parallel.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,14 @@
77
Theme: War of the Ring — multiple simultaneous fronts.
88
"""
99

10+
from enum import Enum
11+
from enum import auto
12+
1013
import pytest
14+
from statemachine.states import States
1115

16+
from statemachine import State
17+
from statemachine import StateChart
1218
from tests.machines.parallel.session import Session
1319
from tests.machines.parallel.session_with_done_state import SessionWithDoneState
1420
from tests.machines.parallel.two_towers import TwoTowers
@@ -139,6 +145,53 @@ async def test_top_level_parallel_done_state_fires_before_termination(self, sm_r
139145
assert {"finished"} == set(sm.configuration_values)
140146
assert sm.is_terminated is True
141147

148+
async def test_from_enum_inside_parallel(self, sm_runner):
149+
"""States.from_enum() works inside parallel states (#606)."""
150+
151+
class RegionA(Enum):
152+
IDLE = auto()
153+
ACTIVE = auto()
154+
155+
class RegionB(Enum):
156+
OFF = auto()
157+
ON = auto()
158+
159+
class SC(StateChart):
160+
start = State(initial=True)
161+
done = State(final=True)
162+
163+
class work(State.Parallel):
164+
class region_a(State.Compound):
165+
a = States.from_enum(RegionA, initial=RegionA.IDLE, final=RegionA.ACTIVE)
166+
go_a = a.IDLE.to(a.ACTIVE)
167+
168+
class region_b(State.Compound):
169+
b = States.from_enum(RegionB, initial=RegionB.OFF, final=RegionB.ON)
170+
go_b = b.OFF.to(b.ON)
171+
172+
begin = start.to(work)
173+
finish = work.to(done)
174+
175+
sm = await sm_runner.start(SC)
176+
assert {"start"} == set(sm.configuration_values)
177+
178+
await sm_runner.send(sm, "begin")
179+
vals = set(sm.configuration_values)
180+
assert "work" in vals
181+
assert RegionA.IDLE in vals
182+
assert RegionB.OFF in vals
183+
184+
await sm_runner.send(sm, "go_a")
185+
vals = set(sm.configuration_values)
186+
assert RegionA.ACTIVE in vals
187+
assert RegionB.OFF in vals # region_b unchanged
188+
189+
await sm_runner.send(sm, "go_b")
190+
# Both regions final -> done.state.work fires
191+
assert {RegionA.ACTIVE, RegionB.ON} <= set(sm.configuration_values) or {"done"} == set(
192+
sm.configuration_values
193+
)
194+
142195
async def test_top_level_parallel_not_terminated_when_one_region_pending(self, sm_runner):
143196
"""Machine keeps running when only one region reaches final."""
144197
sm = await sm_runner.start(Session)

0 commit comments

Comments
 (0)