Skip to content

Commit cb75dd9

Browse files
committed
feat: diagram generation for compound, parallel, and history states
- Render initial pseudo-state (black dot → initial child) inside all compound subgraphs, not just root level - Add history state rendering as UML circles labeled "H" / "H*" - Annotate parallel state subgraph labels with ☷ indicator - Support multi-target transitions (one edge per target) - Extract _add_transitions() helper to reduce _graph_states() complexity - Remove stale TODO comment - Fix NestedStateFactory to propagate kwargs (e.g. parallel=True) from State.Parallel/State.Compound base classes to subclass-created States - Export HistoryState from statemachine.__init__ - Register event label "None" bug as release blocker in PLAN
1 parent 381a2c7 commit cb75dd9

4 files changed

Lines changed: 237 additions & 36 deletions

File tree

statemachine/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .event import Event
2+
from .state import HistoryState
23
from .state import State
34
from .statemachine import StateChart
45
from .statemachine import StateMachine
@@ -7,4 +8,4 @@
78
__email__ = "fgmacedo@gmail.com"
89
__version__ = "2.5.0"
910

10-
__all__ = ["StateChart", "StateMachine", "State", "Event"]
11+
__all__ = ["StateChart", "StateMachine", "State", "HistoryState", "Event"]

statemachine/contrib/diagram.py

Lines changed: 62 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,11 @@ def _get_subgraph(self, state):
4747
style = ", solid"
4848
if state.parent and state.parent.parallel:
4949
style = ", dashed"
50+
label = state.name
51+
if state.parallel:
52+
label = f"<<b>{state.name}</b> &#9783;>"
5053
subgraph = pydot.Subgraph(
51-
label=f"{state.name}",
54+
label=label,
5255
graph_name=f"cluster_{state.id}",
5356
style=f"rounded{style}",
5457
cluster="true",
@@ -128,6 +131,21 @@ def _state_id(state):
128131
else:
129132
return state.id
130133

134+
def _history_node(self, state):
135+
label = "H*" if state.deep else "H"
136+
return pydot.Node(
137+
self._state_id(state),
138+
label=label,
139+
shape="circle",
140+
style="filled",
141+
fillcolor="white",
142+
fontname=self.font_name,
143+
fontsize="8pt",
144+
fixedsize="true",
145+
width=0.3,
146+
height=0.3,
147+
)
148+
131149
def _state_as_node(self, state):
132150
actions = self._state_actions(state)
133151

@@ -150,41 +168,51 @@ def _state_as_node(self, state):
150168
node.set_fillcolor("white")
151169
return node
152170

153-
def _transition_as_edge(self, transition):
154-
cond = ", ".join([str(cond) for cond in transition.cond])
171+
def _transition_as_edges(self, transition):
172+
targets = transition.targets if transition.targets else [None]
173+
cond = ", ".join([str(c) for c in transition.cond])
155174
if cond:
156175
cond = f"\n[{cond}]"
157176

158-
extra_params = {}
159-
has_substates = transition.source.states or (
160-
transition.target and transition.target.states
161-
)
162-
if transition.source.states:
163-
extra_params["ltail"] = f"cluster_{transition.source.id}"
164-
if transition.target and transition.target.states:
165-
extra_params["lhead"] = f"cluster_{transition.target.id}"
166-
167-
targetless = transition.target is None
168-
return pydot.Edge(
169-
self._state_id(transition.source),
170-
self._state_id(transition.target)
171-
if not targetless
172-
else self._state_id(transition.source),
173-
label=f"{transition.event}{cond}",
174-
color="blue",
175-
fontname=self.font_name,
176-
fontsize=self.transition_font_size,
177-
minlen=2 if has_substates else 1,
178-
**extra_params,
179-
)
177+
edges = []
178+
for i, target in enumerate(targets):
179+
extra_params = {}
180+
has_substates = transition.source.states or (target and target.states)
181+
if transition.source.states:
182+
extra_params["ltail"] = f"cluster_{transition.source.id}"
183+
if target and target.states:
184+
extra_params["lhead"] = f"cluster_{target.id}"
185+
186+
targetless = target is None
187+
label = f"{transition.event}{cond}" if i == 0 else ""
188+
dst = self._state_id(target) if not targetless else self._state_id(transition.source)
189+
edges.append(
190+
pydot.Edge(
191+
self._state_id(transition.source),
192+
dst,
193+
label=label,
194+
color="blue",
195+
fontname=self.font_name,
196+
fontsize=self.transition_font_size,
197+
minlen=2 if has_substates else 1,
198+
**extra_params,
199+
)
200+
)
201+
return edges
180202

181203
def get_graph(self):
182204
graph = self._get_graph(self.machine)
183205
self._graph_states(self.machine, graph, is_root=True)
184206
return graph
185207

208+
def _add_transitions(self, graph, state):
209+
for transition in state.transitions:
210+
if transition.internal:
211+
continue
212+
for edge in self._transition_as_edges(transition):
213+
graph.add_edge(edge)
214+
186215
def _graph_states(self, state, graph, is_root=False):
187-
# TODO: handle parallel states in diagram
188216
initial_node = self._initial_node(state)
189217
initial_subgraph = pydot.Subgraph(
190218
graph_name=f"{initial_node.get_name()}_initial",
@@ -202,9 +230,10 @@ def _graph_states(self, state, graph, is_root=False):
202230
graph.add_subgraph(initial_subgraph)
203231
graph.add_subgraph(atomic_states_subgraph)
204232

205-
if is_root:
206-
initial = next(s for s in state.states if s.initial)
207-
graph.add_edge(self._initial_edge(initial_node, initial))
233+
if state.states and not getattr(state, "parallel", False):
234+
initial = next((s for s in state.states if s.initial), None)
235+
if initial:
236+
graph.add_edge(self._initial_edge(initial_node, initial))
208237

209238
for substate in state.states:
210239
if substate.states:
@@ -213,11 +242,11 @@ def _graph_states(self, state, graph, is_root=False):
213242
graph.add_subgraph(subgraph)
214243
else:
215244
atomic_states_subgraph.add_node(self._state_as_node(substate))
245+
self._add_transitions(graph, substate)
216246

217-
for transition in substate.transitions:
218-
if transition.internal:
219-
continue
220-
graph.add_edge(self._transition_as_edge(transition))
247+
for history_state in getattr(state, "history", []):
248+
atomic_states_subgraph.add_node(self._history_node(history_state))
249+
self._add_transitions(graph, history_state)
221250

222251
def __call__(self):
223252
return self.get_graph()

statemachine/state.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,15 @@ def __new__( # type: ignore [misc]
5555
cls, classname, bases, attrs, name="", **kwargs
5656
) -> "State":
5757
if not bases:
58-
return super().__new__(cls, classname, bases, attrs) # type: ignore [return-value]
58+
new_cls = super().__new__(cls, classname, bases, attrs) # type: ignore [return-value]
59+
new_cls._factory_kwargs = kwargs # type: ignore [attr-defined]
60+
return new_cls # type: ignore [return-value]
61+
62+
# Inherit factory kwargs from base classes (e.g., parallel=True from State.Parallel)
63+
inherited_kwargs: dict = {}
64+
for base in bases:
65+
inherited_kwargs.update(getattr(base, "_factory_kwargs", {}))
66+
inherited_kwargs.update(kwargs)
5967

6068
states = []
6169
callbacks = {}
@@ -68,7 +76,7 @@ def __new__( # type: ignore [misc]
6876
elif callable(value):
6977
callbacks[key] = value
7078

71-
return State(name=name, states=states, _callbacks=callbacks, **kwargs)
79+
return State(name=name, states=states, _callbacks=callbacks, **inherited_kwargs)
7280

7381
@classmethod
7482
def to(cls, *args: "State", **kwargs) -> "_ToState": # pragma: no cover

tests/test_contrib_diagram.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
from statemachine.contrib.diagram import DotGraphMachine
66
from statemachine.contrib.diagram import main
77
from statemachine.contrib.diagram import quickchart_write_svg
8+
from statemachine.transition import Transition
89

10+
from statemachine import HistoryState
911
from statemachine import State
1012
from statemachine import StateChart
1113

@@ -193,3 +195,164 @@ def test_initial_edge_with_compound_state_has_lhead():
193195
edge = graph_maker._initial_edge(initial_node, compound)
194196
attrs = edge.obj_dict["attributes"]
195197
assert attrs.get("lhead") == f"cluster_{compound.id}"
198+
199+
200+
def test_initial_edge_inside_compound_subgraph():
201+
"""Compound substate has an initial edge from dot to initial child."""
202+
203+
class SM(StateChart):
204+
class parent(State.Compound, name="Parent"):
205+
child1 = State(initial=True)
206+
child2 = State(final=True)
207+
208+
go = child1.to(child2)
209+
210+
start = State(initial=True)
211+
end = State(final=True)
212+
213+
enter = start.to(parent)
214+
finish = parent.to(end)
215+
216+
graph = DotGraphMachine(SM)
217+
dot = graph().to_string()
218+
# The compound subgraph should contain an initial point node and an edge to child1
219+
assert "parent_anchor" in dot
220+
assert "child1" in dot
221+
# Verify the initial edge exists (from parent's initial node to child1)
222+
assert "parent_anchor -> child1" in dot
223+
224+
225+
def test_history_state_shallow_diagram():
226+
"""DOT output contains an 'H' circle node for shallow history state."""
227+
h = HistoryState(name="H", deep=False)
228+
h._set_id("h_shallow")
229+
230+
graph_maker = DotGraphMachine.__new__(DotGraphMachine)
231+
graph_maker.font_name = "Arial"
232+
node = graph_maker._history_node(h)
233+
attrs = node.obj_dict["attributes"]
234+
assert attrs["label"] in ("H", '"H"')
235+
assert attrs["shape"] == "circle"
236+
237+
238+
def test_history_state_deep_diagram():
239+
"""DOT output contains an 'H*' circle node for deep history state."""
240+
h = HistoryState(name="H*", deep=True)
241+
h._set_id("h_deep")
242+
243+
graph_maker = DotGraphMachine.__new__(DotGraphMachine)
244+
graph_maker.font_name = "Arial"
245+
node = graph_maker._history_node(h)
246+
# Verify the node renders correctly in DOT output
247+
dot_str = node.to_string()
248+
assert "H*" in dot_str
249+
assert "circle" in dot_str
250+
251+
252+
def test_history_state_default_transition():
253+
"""History state's default transition appears as an edge in the diagram."""
254+
child1 = State("child1", initial=True)
255+
child1._set_id("child1")
256+
child2 = State("child2")
257+
child2._set_id("child2")
258+
259+
h = HistoryState(name="H", deep=False)
260+
h._set_id("hist")
261+
# Add a default transition from history to child1
262+
t = Transition(source=h, target=child1, initial=True)
263+
h.transitions.add_transitions(t)
264+
265+
parent = State("parent", states=[child1, child2], history=[h])
266+
parent._set_id("parent")
267+
268+
graph_maker = DotGraphMachine.__new__(DotGraphMachine)
269+
graph_maker.font_name = "Arial"
270+
graph_maker.transition_font_size = "9pt"
271+
272+
edges = graph_maker._transition_as_edges(t)
273+
assert len(edges) == 1
274+
edge = edges[0]
275+
assert edge.obj_dict["points"] == ("hist", "child1")
276+
277+
278+
def test_parallel_state_label_indicator():
279+
"""Parallel subgraph label includes a visual indicator."""
280+
281+
class SM(StateChart):
282+
validate_disconnected_states: bool = False
283+
284+
class p(State.Parallel, name="p"):
285+
class r1(State.Compound, name="r1"):
286+
a = State(initial=True)
287+
288+
class r2(State.Compound, name="r2"):
289+
b = State(initial=True)
290+
291+
start = State(initial=True)
292+
begin = start.to(p)
293+
294+
graph = DotGraphMachine(SM)
295+
dot = graph().to_string()
296+
# The parallel state label should contain an HTML-like label with the indicator
297+
assert "&#9783;" in dot
298+
299+
300+
def test_multi_target_transition_diagram():
301+
"""Edges are created for all targets of a multi-target transition."""
302+
source = State("source", initial=True)
303+
source._set_id("source")
304+
target1 = State("target1")
305+
target1._set_id("target1")
306+
target2 = State("target2")
307+
target2._set_id("target2")
308+
309+
t = Transition(source=source, target=[target1, target2])
310+
t._events.add("go")
311+
312+
graph_maker = DotGraphMachine.__new__(DotGraphMachine)
313+
graph_maker.font_name = "Arial"
314+
graph_maker.transition_font_size = "9pt"
315+
316+
edges = graph_maker._transition_as_edges(t)
317+
assert len(edges) == 2
318+
assert edges[0].obj_dict["points"] == ("source", "target1")
319+
assert edges[1].obj_dict["points"] == ("source", "target2")
320+
# Only the first edge gets a label
321+
assert edges[0].obj_dict["attributes"]["label"] == "go"
322+
assert edges[1].obj_dict["attributes"]["label"] == ""
323+
324+
325+
def test_compound_and_parallel_mixed():
326+
"""Full diagram with compound and parallel states renders without error."""
327+
328+
class SM(StateChart):
329+
validate_disconnected_states: bool = False
330+
331+
class top(State.Compound, name="Top"):
332+
class par(State.Parallel, name="Par"):
333+
class region1(State.Compound, name="Region1"):
334+
r1_a = State(initial=True)
335+
r1_b = State(final=True)
336+
r1_go = r1_a.to(r1_b)
337+
338+
class region2(State.Compound, name="Region2"):
339+
r2_a = State(initial=True)
340+
r2_b = State(final=True)
341+
r2_go = r2_a.to(r2_b)
342+
343+
entry = State(initial=True)
344+
start_par = entry.to(par)
345+
346+
begin = State(initial=True)
347+
enter_top = begin.to(top)
348+
349+
graph = DotGraphMachine(SM)
350+
dot = graph().to_string()
351+
assert "cluster_top" in dot
352+
assert "cluster_par" in dot
353+
assert "cluster_region1" in dot
354+
assert "cluster_region2" in dot
355+
# Parallel indicator
356+
assert "&#9783;" in dot
357+
# Verify initial edges exist for compound states (top and regions)
358+
assert "top_anchor -> entry" in dot

0 commit comments

Comments
 (0)