Skip to content

Commit 3d07b0a

Browse files
committed
test: Nested states with general syntax
1 parent 71f6277 commit 3d07b0a

5 files changed

Lines changed: 40 additions & 36 deletions

File tree

statemachine/factory.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,8 @@ def _check(cls):
6868
if not cls.states:
6969
raise InvalidDefinition(_("There are no states."))
7070

71-
# TODO: Validate no events if has nested states
72-
# if not cls._events:
73-
# raise InvalidDefinition(_("There are no events."))
71+
if not cls._events:
72+
raise InvalidDefinition(_("There are no events."))
7473

7574
cls._check_disconnected_state()
7675

@@ -117,8 +116,9 @@ def _add_unbounded_callback(cls, attr_name, func):
117116

118117
def add_state(cls, id, state):
119118
state._set_id(id)
120-
cls.states.append(state)
121-
cls.states_map[state.value] = state
119+
if not state.parent:
120+
cls.states.append(state)
121+
cls.states_map[state.value] = state
122122

123123
# also register all events associated directly with transitions
124124
for event in state.transitions.unique_events:

statemachine/state.py

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Any, Optional # noqa: F401, I001
1+
from typing import Any
2+
from typing import TypeAlias
23
from copy import deepcopy
34

45
from .callbacks import Callbacks
@@ -9,19 +10,22 @@
910

1011

1112
class NestedStateFactory(type):
12-
def __new__(cls, classname, bases, attrs, name=None, initial=False, parallel=False):
13+
def __new__( # type: ignore [misc]
14+
cls, classname, bases, attrs, name=None, initial=False, parallel=False
15+
) -> "State":
1316

1417
if not bases:
15-
return super().__new__(cls, classname, bases, attrs)
18+
return super().__new__(cls, classname, bases, attrs) # type: ignore [return-value]
1619

1720
substates = []
1821
for key, value in attrs.items():
19-
if not isinstance(value, State):
20-
continue
21-
value._set_id(key)
22-
substates.append(value)
22+
if isinstance(value, State):
23+
value._set_id(key)
24+
substates.append(value)
25+
if isinstance(value, TransitionList):
26+
value.add_event(key)
2327

24-
return State(name, initial=initial, parallel=parallel, substates=substates)
28+
return State(name=name, initial=initial, parallel=parallel, substates=substates)
2529

2630

2731
class NestedStateBuilder(metaclass=NestedStateFactory):
@@ -102,37 +106,37 @@ class State:
102106
103107
"""
104108

105-
Builder = NestedStateBuilder
109+
Builder: TypeAlias = NestedStateBuilder
106110

107111
def __init__(
108112
self,
109-
name,
110-
value=None,
111-
initial=False,
112-
final=False,
113-
parallel=False,
114-
substates=None,
115-
enter=None,
116-
exit=None,
113+
name: str = "",
114+
value: Any = None,
115+
initial: bool = False,
116+
final: bool = False,
117+
parallel: bool = False,
118+
substates: Any = None,
119+
enter: Any = None,
120+
exit: Any = None,
117121
):
118-
# type: (str, Optional[Any], bool, bool, bool, Optional[Any], Optional[Any], Optional[Any]) -> None # noqa
119122
self.name = name
120123
self.value = value
121124
self.parallel = parallel
122-
self.parent: "State" = None
123125
self.substates = substates or []
124-
self._id = None # type: Optional[str]
125-
self._storage = ""
126126
self._initial = initial
127-
self.transitions = TransitionList()
128127
self._final = final
128+
self._id: str = ""
129+
self._storage: str = ""
130+
self.parent: "State" = None
131+
self.transitions = TransitionList()
129132
self.enter = Callbacks().add(enter)
130133
self.exit = Callbacks().add(exit)
131134
self._init_substates()
132135

133136
def _init_substates(self):
134137
for substate in self.substates:
135138
substate.parent = self
139+
setattr(self, substate.id, substate)
136140

137141
def __eq__(self, other):
138142
return (
@@ -190,6 +194,8 @@ def _set_id(self, id):
190194
self._storage = f"_{id}"
191195
if self.value is None:
192196
self.value = id
197+
if not self.name:
198+
self.name = self._id.replace("_", " ").capitalize()
193199

194200
def _to_(self, *states, **kwargs):
195201
transitions = TransitionList(

statemachine/transition_list.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from collections import OrderedDict
2-
31
from .utils import ensure_iterable
42

53

@@ -67,10 +65,9 @@ def add_event(self, event):
6765

6866
@property
6967
def unique_events(self):
70-
# Compat Python2.7: Using OrderedDict to get a unique ordered list
71-
tmp_list = OrderedDict()
68+
tmp_ordered_unique_events_as_keys_on_dict = {}
7269
for transition in self.transitions:
7370
for event in transition.events:
74-
tmp_list[event] = True
71+
tmp_ordered_unique_events_as_keys_on_dict[event] = True
7572

76-
return list(tmp_list.keys())
73+
return list(tmp_ordered_unique_events_as_keys_on_dict.keys())

tests/examples/microwave_inheritance_machine.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class on(State.Builder, name="On"):
5050
cooking.to(idle, cond="open.is_active")
5151
cooking.to.itself(internal=True, on="increment_timer")
5252

53+
assert isinstance(on, State) # so mypy stop complaining
5354
turn_off = on.to(off)
5455
turn_on = off.to(on)
5556
on.to(off, cond="cook_time_is_over") # eventless transition

tests/test_compound.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ class engine(State.Builder, name="Engine", initial=True):
1111
off = State("Off", initial=True)
1212
on = State("On")
1313

14-
turn_off = on.to(off)
1514
turn_on = off.to(on)
15+
turn_off = on.to(off)
1616

1717
return TestMachine
1818

@@ -26,8 +26,8 @@ def test_capture_constructor_arguments(self, compound_engine_cls):
2626

2727
def test_list_children_states(self, compound_engine_cls):
2828
sm = compound_engine_cls()
29-
assert [s.id for s in sm.engine.children] == ["off", "on"]
29+
assert [s.id for s in sm.engine.substates] == ["off", "on"]
3030

3131
def test_list_events(self, compound_engine_cls):
3232
sm = compound_engine_cls()
33-
assert [e.name for e in sm.events] == ["turn_off", "turn_on"]
33+
assert [e.name for e in sm.events] == ["turn_on", "turn_off"]

0 commit comments

Comments
 (0)