Skip to content

Commit e25a88d

Browse files
committed
feat: Add syntax for compound and parallel states (only parser)
1 parent 9c62c8d commit e25a88d

16 files changed

Lines changed: 421 additions & 138 deletions

docs/diagram.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ Graphviz. For example, on Debian-based systems (such as Ubuntu), you can use the
4242
>>> dot = graph()
4343

4444
>>> dot.to_string() # doctest: +ELLIPSIS
45-
'digraph list {...
45+
'digraph OrderControl {...
4646

4747
```
4848

1.94 KB
Loading
1.22 KB
Loading
2.51 KB
Loading
438 Bytes
Loading

statemachine/contrib/diagram.py

Lines changed: 92 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -18,62 +18,79 @@ class DotGraphMachine:
1818
font_name = "Arial"
1919
"""Graph font face name"""
2020

21-
state_font_size = "10"
22-
"""State font size in points"""
21+
state_font_size = "10pt"
22+
"""State font size"""
2323

2424
state_active_penwidth = 2
2525
"""Active state external line width"""
2626

2727
state_active_fillcolor = "turquoise"
2828

29-
transition_font_size = "9"
30-
"""Transition font size in points"""
29+
transition_font_size = "9pt"
30+
"""Transition font size"""
3131

32-
def __init__(self, machine: StateMachine):
32+
def __init__(self, machine):
3333
self.machine = machine
3434

35-
def _get_graph(self):
36-
machine = self.machine
35+
def _get_graph(self, machine):
3736
return pydot.Dot(
38-
"list",
37+
machine.name,
3938
graph_type="digraph",
4039
label=machine.name,
4140
fontname=self.font_name,
4241
fontsize=self.state_font_size,
4342
rankdir=self.graph_rankdir,
43+
compound="true",
4444
)
4545

46-
def _initial_node(self):
46+
def _get_subgraph(self, state):
47+
style = ", solid"
48+
if state.parent and state.parent.parallel:
49+
style = ", dashed"
50+
subgraph = pydot.Subgraph(
51+
label=f"{state.name}",
52+
graph_name=f"cluster_{state.id}",
53+
style=f"rounded{style}",
54+
cluster="true",
55+
)
56+
return subgraph
57+
58+
def _initial_node(self, state):
4759
node = pydot.Node(
48-
"i",
49-
shape="circle",
60+
self._state_id(state),
61+
label="",
62+
shape="point",
5063
style="filled",
51-
fontsize="1",
64+
fontsize="1pt",
5265
fixedsize="true",
5366
width=0.2,
5467
height=0.2,
5568
)
5669
node.set_fillcolor("black")
5770
return node
5871

59-
def _initial_edge(self):
72+
def _initial_edge(self, initial_node, state):
73+
extra_params = {}
74+
if state.states:
75+
extra_params["lhead"] = f"cluster_{state.id}"
6076
return pydot.Edge(
61-
"i",
62-
self.machine.initial_state.id,
77+
initial_node.get_name(),
78+
self._state_id(state),
6379
label="",
6480
color="blue",
6581
fontname=self.font_name,
6682
fontsize=self.transition_font_size,
83+
**extra_params,
6784
)
6885

6986
def _actions_getter(self):
7087
if isinstance(self.machine, StateMachine):
7188

72-
def getter(grouper) -> str:
89+
def getter(grouper):
7390
return self.machine._callbacks.str(grouper.key)
7491
else:
7592

76-
def getter(grouper) -> str:
93+
def getter(grouper):
7794
all_names = set(dir(self.machine))
7895
return ", ".join(
7996
str(c) for c in grouper if not c.is_convention or c.func in all_names
@@ -104,11 +121,18 @@ def _state_actions(self, state):
104121

105122
return actions
106123

124+
@staticmethod
125+
def _state_id(state):
126+
if state.states:
127+
return f"{state.id}_anchor"
128+
else:
129+
return state.id
130+
107131
def _state_as_node(self, state):
108132
actions = self._state_actions(state)
109133

110134
node = pydot.Node(
111-
state.id,
135+
self._state_id(state),
112136
label=f"{state.name}{actions}",
113137
shape="rectangle",
114138
style="rounded, filled",
@@ -127,29 +151,64 @@ def _transition_as_edge(self, transition):
127151
cond = ", ".join([str(cond) for cond in transition.cond])
128152
if cond:
129153
cond = f"\n[{cond}]"
154+
155+
extra_params = {}
156+
has_substates = transition.source.states or transition.target.states
157+
if transition.source.states:
158+
extra_params["ltail"] = f"cluster_{transition.source.id}"
159+
if transition.target.states:
160+
extra_params["lhead"] = f"cluster_{transition.target.id}"
161+
130162
return pydot.Edge(
131-
transition.source.id,
132-
transition.target.id,
163+
self._state_id(transition.source),
164+
self._state_id(transition.target),
133165
label=f"{transition.event}{cond}",
134166
color="blue",
135167
fontname=self.font_name,
136168
fontsize=self.transition_font_size,
169+
minlen=2 if has_substates else 1,
170+
**extra_params,
137171
)
138172

139173
def get_graph(self):
140-
graph = self._get_graph()
141-
graph.add_node(self._initial_node())
142-
graph.add_edge(self._initial_edge())
174+
graph = self._get_graph(self.machine)
175+
self._graph_states(self.machine, graph)
176+
return graph
143177

144-
for state in self.machine.states:
145-
graph.add_node(self._state_as_node(state))
146-
for transition in state.transitions:
178+
def _graph_states(self, state, graph):
179+
initial_node = self._initial_node(state)
180+
initial_subgraph = pydot.Subgraph(
181+
graph_name=f"{initial_node.get_name()}_initial",
182+
label="",
183+
peripheries=0,
184+
margin=0,
185+
)
186+
atomic_states_subgraph = pydot.Subgraph(
187+
graph_name=f"cluster_{initial_node.get_name()}_atomic",
188+
label="",
189+
peripheries=0,
190+
cluster="true",
191+
)
192+
initial_subgraph.add_node(initial_node)
193+
graph.add_subgraph(initial_subgraph)
194+
graph.add_subgraph(atomic_states_subgraph)
195+
196+
initial = next(s for s in state.states if s.initial)
197+
graph.add_edge(self._initial_edge(initial_node, initial))
198+
199+
for substate in state.states:
200+
if substate.states:
201+
subgraph = self._get_subgraph(substate)
202+
self._graph_states(substate, subgraph)
203+
graph.add_subgraph(subgraph)
204+
else:
205+
atomic_states_subgraph.add_node(self._state_as_node(substate))
206+
207+
for transition in substate.transitions:
147208
if transition.internal:
148209
continue
149210
graph.add_edge(self._transition_as_edge(transition))
150211

151-
return graph
152-
153212
def __call__(self):
154213
return self.get_graph()
155214

@@ -165,7 +224,11 @@ def quickchart_write_svg(sm: StateMachine, path: str):
165224
>>> from tests.examples.order_control_machine import OrderControl
166225
>>> sm = OrderControl()
167226
>>> print(sm._graph().to_string())
168-
digraph list {
227+
digraph OrderControl {
228+
label=OrderControl;
229+
fontname=Arial;
230+
fontsize="10pt";
231+
rankdir=LR;
169232
...
170233
171234
To give you an example, we included this method that will serialize the dot, request the graph

statemachine/factory.py

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from . import registry
99
from .event import Event
1010
from .exceptions import InvalidDefinition
11+
from .graph import iterate_states
1112
from .graph import iterate_states_and_transitions
1213
from .graph import visit_connected_states
1314
from .i18n import _
@@ -30,6 +31,7 @@ def __init__(
3031
super().__init__(name, bases, attrs)
3132
registry.register(cls)
3233
cls.name = cls.__name__
34+
cls.id = cls.name.lower()
3335
cls.states: States = States()
3436
cls.states_map: Dict[Any, State] = {}
3537
"""Map of ``state.value`` to the corresponding :ref:`state`."""
@@ -42,12 +44,30 @@ def __init__(
4244

4345
cls.add_inherited(bases)
4446
cls.add_from_attributes(attrs)
47+
cls._unpack_builders_callbacks()
4548
cls._update_event_references()
4649

47-
try:
48-
cls.initial_state: State = next(s for s in cls.states if s.initial)
49-
except StopIteration:
50-
cls.initial_state = None # Abstract SM still don't have states
50+
if not cls.states:
51+
return
52+
53+
cls._initials_by_document_order(cls.states)
54+
55+
initials = [s for s in cls.states if s.initial]
56+
parallels = [s.id for s in cls.states if s.parallel]
57+
root_only_has_parallels = len(cls.states) == len(parallels)
58+
59+
if len(initials) != 1 and not root_only_has_parallels:
60+
raise InvalidDefinition(
61+
_(
62+
"There should be one and only one initial state. "
63+
"Your currently have these: {0}"
64+
).format(", ".join(s.id for s in initials))
65+
)
66+
67+
if initials:
68+
cls.initial_state = initials[0]
69+
else:
70+
cls.initial_state = None # TODO: Check if still enter here for abstract SM
5171

5272
cls.final_states: List[State] = [state for state in cls.states if state.final]
5373

@@ -59,6 +79,26 @@ def __init__(
5979

6080
def __getattr__(self, attribute: str) -> Any: ...
6181

82+
def _initials_by_document_order(cls, states):
83+
"""Set initial state by document order if no explicit initial state is set"""
84+
has_initial = False
85+
for s in states:
86+
cls._initials_by_document_order(s.states)
87+
if s.initial:
88+
has_initial = True
89+
break
90+
if not has_initial and states:
91+
states[0]._initial = True
92+
93+
def _unpack_builders_callbacks(cls):
94+
callbacks = {}
95+
for state in iterate_states(cls.states):
96+
if state._callbacks:
97+
callbacks.update(state._callbacks)
98+
del state._callbacks
99+
for key, value in callbacks.items():
100+
setattr(cls, key, value)
101+
62102
def _check(cls):
63103
has_states = bool(cls.states)
64104
cls._abstract = not has_states
@@ -82,8 +122,9 @@ def _check_initial_state(cls):
82122
"You currently have these: {!r}"
83123
).format([s.id for s in initials])
84124
)
85-
if not initials[0].transitions.transitions:
86-
raise InvalidDefinition(_("There are no transitions."))
125+
# TODO: Check if this is still needed
126+
# if not initials[0].transitions.transitions:
127+
# raise InvalidDefinition(_("There are no transitions."))
87128

88129
def _check_final_states(cls):
89130
final_state_with_invalid_transitions = [
@@ -205,15 +246,19 @@ def _add_unbounded_callback(cls, attr_name, func):
205246

206247
def add_state(cls, id, state: State):
207248
state._set_id(id)
208-
cls.states.append(state)
209-
cls.states_map[state.value] = state
210-
if not hasattr(cls, id):
211-
setattr(cls, id, state)
249+
if not state.parent:
250+
cls.states.append(state)
251+
cls.states_map[state.value] = state
252+
if not hasattr(cls, id):
253+
setattr(cls, id, state)
212254

213255
# also register all events associated directly with transitions
214256
for event in state.transitions.unique_events:
215257
cls.add_event(event)
216258

259+
for substate in state.states:
260+
cls.add_state(substate.id, substate)
261+
217262
def add_event(
218263
cls,
219264
event: Event,

statemachine/graph.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,20 @@ def visit_connected_states(state):
1111
continue
1212
already_visited.add(state)
1313
yield state
14+
visit.extend(s for s in state.states if s.initial)
1415
visit.extend(t.target for t in state.transitions)
1516

1617

1718
def iterate_states_and_transitions(states):
1819
for state in states:
1920
yield state
2021
yield from state.transitions
22+
if state.states:
23+
yield from iterate_states_and_transitions(state.states)
24+
25+
26+
def iterate_states(states):
27+
for state in states:
28+
yield state
29+
if state.states:
30+
yield from iterate_states(state.states)

0 commit comments

Comments
 (0)