Skip to content

Commit e16b2f0

Browse files
committed
refactor: enrich diagram IR to separate extraction from rendering
Move domain analysis logic from the renderer (dot.py) into the extractor (extract.py) so the renderer becomes a pure IR→pydot mapping: - Add ActionType enum replacing free strings for DiagramAction.type - Add compound_state_ids and bidirectional_compound_ids to DiagramGraph - Add DiagramTransition.is_initial flag for implicit initial transitions - Remove redundant DiagramTransition.target (use targets list) - Move _collect_compound_ids, _collect_compound_bidir_ids from renderer to extractor - Add _mark_initial_transitions and _resolve_initial_states in extractor - Remove _is_initial_candidate from renderer (use state.is_initial) - Remove implicit transition filtering logic from renderer (use transition.is_initial)
1 parent 127bb12 commit e16b2f0

4 files changed

Lines changed: 108 additions & 69 deletions

File tree

statemachine/contrib/diagram/extract.py

Lines changed: 85 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from typing import TYPE_CHECKING
22
from typing import List
3+
from typing import Set
34
from typing import Union
45

6+
from .model import ActionType
57
from .model import DiagramAction
68
from .model import DiagramGraph
79
from .model import DiagramState
@@ -54,16 +56,16 @@ def _extract_state_actions(state: "State", getter) -> List[DiagramAction]:
5456
exit_ = str(getter(state.exit))
5557

5658
if entry:
57-
actions.append(DiagramAction(type="entry", body=entry))
59+
actions.append(DiagramAction(type=ActionType.ENTRY, body=entry))
5860
if exit_:
59-
actions.append(DiagramAction(type="exit", body=exit_))
61+
actions.append(DiagramAction(type=ActionType.EXIT, body=exit_))
6062

6163
for transition in state.transitions:
6264
if transition.internal:
6365
on_text = str(getter(transition.on))
6466
if on_text:
6567
actions.append(
66-
DiagramAction(type="internal", body=f"{transition.event} / {on_text}")
68+
DiagramAction(type=ActionType.INTERNAL, body=f"{transition.event} / {on_text}")
6769
)
6870

6971
return actions
@@ -105,14 +107,12 @@ def _extract_transitions_from_state(state: "State", getter) -> List[DiagramTrans
105107
for transition in state.transitions:
106108
targets = transition.targets if transition.targets else []
107109
target_ids = [t.id for t in targets]
108-
primary_target = target_ids[0] if target_ids else None
109110

110111
cond_strs = [str(c) for c in transition.cond]
111112

112113
result.append(
113114
DiagramTransition(
114115
source=transition.source.id,
115-
target=primary_target,
116116
targets=target_ids,
117117
event=transition.event,
118118
guards=cond_strs,
@@ -136,6 +136,79 @@ def _extract_all_transitions(states, getter) -> List[DiagramTransition]:
136136
return result
137137

138138

139+
def _collect_compound_ids(states: List[DiagramState]) -> Set[str]:
140+
"""Collect IDs of states that have children (compound/parallel)."""
141+
result: Set[str] = set()
142+
for state in states:
143+
if state.children:
144+
result.add(state.id)
145+
result.update(_collect_compound_ids(state.children))
146+
return result
147+
148+
149+
def _collect_bidirectional_compound_ids(
150+
transitions: List[DiagramTransition],
151+
compound_ids: Set[str],
152+
) -> Set[str]:
153+
"""Find compound states that have both outgoing and incoming explicit edges."""
154+
outgoing: Set[str] = set()
155+
incoming: Set[str] = set()
156+
for t in transitions:
157+
if t.is_internal:
158+
continue
159+
# Skip implicit initial transitions
160+
if t.source in compound_ids and not t.event and t.targets:
161+
continue
162+
if t.source in compound_ids:
163+
outgoing.add(t.source)
164+
for target_id in t.targets:
165+
if target_id in compound_ids:
166+
incoming.add(target_id)
167+
return outgoing & incoming
168+
169+
170+
def _mark_initial_transitions(
171+
transitions: List[DiagramTransition],
172+
compound_ids: Set[str],
173+
) -> None:
174+
"""Mark implicit initial transitions (compound state → child, no event)."""
175+
for t in transitions:
176+
if t.source in compound_ids and not t.event and t.targets and not t.is_internal:
177+
t.is_initial = True
178+
179+
180+
def _resolve_initial_states(states: List[DiagramState]) -> None:
181+
"""Ensure exactly one state per level has is_initial=True.
182+
183+
Skips parallel areas and history states. Falls back to document order
184+
(first non-history, non-parallel-area state) when no explicit initial exists.
185+
Recurses into children.
186+
187+
Parallel areas (children of a parallel state) have their is_initial flag
188+
cleared: all regions are auto-activated, so no initial arrow is needed.
189+
"""
190+
# Clear is_initial on parallel areas — all children of a parallel state
191+
# are simultaneously active; initial arrows would be misleading.
192+
for s in states:
193+
if s.is_parallel_area:
194+
s.is_initial = False
195+
196+
candidates = [
197+
s
198+
for s in states
199+
if s.type not in (StateType.HISTORY_SHALLOW, StateType.HISTORY_DEEP)
200+
and not s.is_parallel_area
201+
]
202+
203+
has_explicit_initial = any(s.is_initial for s in candidates)
204+
if not has_explicit_initial and candidates:
205+
candidates[0].is_initial = True
206+
207+
for state in states:
208+
if state.children:
209+
_resolve_initial_states(state.children)
210+
211+
139212
def extract(machine_or_class: "MachineRef") -> DiagramGraph:
140213
"""Extract a DiagramGraph IR from a state machine instance or class.
141214
@@ -171,8 +244,15 @@ class itself thanks to the metaclass. Active-state highlighting is only
171244

172245
transitions = _extract_all_transitions(machine.states, getter)
173246

247+
compound_ids = _collect_compound_ids(states)
248+
bidir_ids = _collect_bidirectional_compound_ids(transitions, compound_ids)
249+
_mark_initial_transitions(transitions, compound_ids)
250+
_resolve_initial_states(states)
251+
174252
return DiagramGraph(
175253
name=machine.name,
176254
states=states,
177255
transitions=transitions,
256+
compound_state_ids=compound_ids,
257+
bidirectional_compound_ids=bidir_ids,
178258
)

statemachine/contrib/diagram/model.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from dataclasses import field
33
from enum import Enum
44
from typing import List
5-
from typing import Optional
5+
from typing import Set
66

77

88
class StateType(Enum):
@@ -19,9 +19,15 @@ class StateType(Enum):
1919
TERMINATE = "terminate"
2020

2121

22+
class ActionType(Enum):
23+
ENTRY = "entry"
24+
EXIT = "exit"
25+
INTERNAL = "internal"
26+
27+
2228
@dataclass
2329
class DiagramAction:
24-
type: str # "entry", "exit", "internal"
30+
type: ActionType
2531
body: str
2632

2733

@@ -40,16 +46,18 @@ class DiagramState:
4046
@dataclass
4147
class DiagramTransition:
4248
source: str
43-
target: Optional[str]
4449
targets: List[str] = field(default_factory=list)
4550
event: str = ""
4651
guards: List[str] = field(default_factory=list)
4752
actions: List[str] = field(default_factory=list)
4853
is_internal: bool = False
54+
is_initial: bool = False
4955

5056

5157
@dataclass
5258
class DiagramGraph:
5359
name: str
5460
states: List[DiagramState] = field(default_factory=list)
5561
transitions: List[DiagramTransition] = field(default_factory=list)
62+
compound_state_ids: Set[str] = field(default_factory=set)
63+
bidirectional_compound_ids: Set[str] = field(default_factory=set)

statemachine/contrib/diagram/renderers/dot.py

Lines changed: 10 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import pydot
99

10+
from ..model import ActionType
1011
from ..model import DiagramAction
1112
from ..model import DiagramGraph
1213
from ..model import DiagramState
@@ -50,39 +51,12 @@ def __init__(self, config: Optional[DotRendererConfig] = None):
5051

5152
def render(self, graph: DiagramGraph) -> pydot.Dot:
5253
"""Render a DiagramGraph to a pydot.Dot object."""
53-
self._collect_compound_ids(graph.states)
54-
self._compound_bidir_ids = self._collect_compound_bidir_ids(graph.transitions)
54+
self._compound_ids = graph.compound_state_ids
55+
self._compound_bidir_ids = graph.bidirectional_compound_ids
5556
dot = self._create_graph(graph.name)
5657
self._render_states(graph.states, graph.transitions, dot)
5758
return dot
5859

59-
def _collect_compound_bidir_ids(self, transitions: List[DiagramTransition]) -> Set[str]:
60-
"""Find compound states that have both outgoing and incoming explicit edges.
61-
62-
Returns the set of compound state IDs that participate in at least one
63-
bidirectional pair, so we can give them separate in/out anchor nodes.
64-
"""
65-
outgoing: Set[str] = set()
66-
incoming: Set[str] = set()
67-
for t in transitions:
68-
if t.is_internal:
69-
continue
70-
if t.source in self._compound_ids and not t.event and t.targets:
71-
continue
72-
if t.source in self._compound_ids:
73-
outgoing.add(t.source)
74-
for target_id in t.targets:
75-
if target_id in self._compound_ids:
76-
incoming.add(target_id)
77-
return outgoing & incoming
78-
79-
def _collect_compound_ids(self, states: List[DiagramState]) -> None:
80-
"""Pre-collect IDs of states that have children (compound/parallel)."""
81-
for state in states:
82-
if state.children:
83-
self._compound_ids.add(state.id)
84-
self._collect_compound_ids(state.children)
85-
8660
def _create_graph(self, name: str) -> pydot.Dot:
8761
cfg = self.config
8862
graph_attrs = {
@@ -151,7 +125,7 @@ def _render_states(
151125
extra_nodes: Optional[List[pydot.Node]] = None,
152126
) -> None:
153127
"""Render states and transitions into the parent graph."""
154-
initial_state = next((s for s in states if self._is_initial_candidate(s, states)), None)
128+
initial_state = next((s for s in states if s.is_initial), None)
155129

156130
# The atomic subgraph groups all non-compound states and the inner
157131
# initial dot (when inside a compound cluster) so Graphviz places them
@@ -255,22 +229,6 @@ def _render_initial_arrow(
255229
)
256230
return added_to_atomic
257231

258-
def _is_initial_candidate(self, state: DiagramState, siblings: List[DiagramState]) -> bool:
259-
"""Check if this state should get an initial arrow."""
260-
# History states don't get initial arrows at this level
261-
if state.type in (StateType.HISTORY_SHALLOW, StateType.HISTORY_DEEP):
262-
return False
263-
# All children of a parallel state are auto-initial; skip initial arrows
264-
if state.is_parallel_area:
265-
return False
266-
# Use the is_initial flag from the model; fall back to document order
267-
if state.is_initial:
268-
return True
269-
has_explicit_initial = any(s.is_initial for s in siblings)
270-
if has_explicit_initial:
271-
return False
272-
return state is siblings[0] if siblings else False
273-
274232
def _create_initial_node(self, node_id: str) -> pydot.Node:
275233
return pydot.Node(
276234
node_id,
@@ -293,7 +251,7 @@ def _create_atomic_node(self, state: DiagramState) -> pydot.Node:
293251
entry/exit actions embed an HTML TABLE (``border="0"``) inside the native
294252
shape to render UML-style compartments (name + separator + actions).
295253
"""
296-
actions = [a for a in state.actions if a.type != "internal" or a.body]
254+
actions = [a for a in state.actions if a.type != ActionType.INTERNAL or a.body]
297255
fillcolor = self.config.state_active_fillcolor if state.is_active else "white"
298256
penwidth = self.config.state_active_penwidth if state.is_active else 2
299257

@@ -363,9 +321,9 @@ def _build_html_table_label(
363321

364322
@staticmethod
365323
def _format_action(action: DiagramAction) -> str:
366-
if action.type == "internal":
324+
if action.type == ActionType.INTERNAL:
367325
return action.body
368-
return f"{action.type} / {action.body}"
326+
return f"{action.type.value} / {action.body}"
369327

370328
def _create_history_node(self, state: DiagramState) -> pydot.Node:
371329
label = "H*" if state.type == StateType.HISTORY_DEEP else "H"
@@ -443,7 +401,7 @@ def _build_compound_label(self, state: DiagramState) -> str:
443401
if state.type == StateType.PARALLEL:
444402
return f"<b>{name}</b> &#9783;"
445403

446-
actions = [a for a in state.actions if a.type != "internal" or a.body]
404+
actions = [a for a in state.actions if a.type != ActionType.INTERNAL or a.body]
447405
if not actions:
448406
return f"<b>{name}</b>"
449407

@@ -465,13 +423,8 @@ def _add_transitions_for_state(
465423
for transition in all_transitions:
466424
if transition.source != state.id or transition.is_internal:
467425
continue
468-
# Skip implicit initial transitions from a compound/parallel state to its
469-
# initial child — these are already represented by the black-dot initial node.
470-
if (
471-
transition.source in self._compound_ids
472-
and not transition.event
473-
and transition.targets
474-
):
426+
# Skip implicit initial transitions — represented by the black-dot initial node.
427+
if transition.is_initial:
475428
continue
476429
for edge in self._create_edges(transition):
477430
graph.add_edge(edge)

tests/test_contrib_diagram.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ def test_history_state_default_transition():
318318
"""History state's default transition appears as an edge in the diagram."""
319319
from statemachine.contrib.diagram.model import DiagramTransition
320320

321-
transition = DiagramTransition(source="hist", target="child1", targets=["child1"], event="")
321+
transition = DiagramTransition(source="hist", targets=["child1"], event="")
322322
renderer = DotRenderer()
323323
renderer._compound_ids = set()
324324
edges = renderer._create_edges(transition)
@@ -361,9 +361,7 @@ def test_multi_target_transition_diagram():
361361
"""Edges are created for all targets of a multi-target transition."""
362362
from statemachine.contrib.diagram.model import DiagramTransition
363363

364-
transition = DiagramTransition(
365-
source="source", target="target1", targets=["target1", "target2"], event="go"
366-
)
364+
transition = DiagramTransition(source="source", targets=["target1", "target2"], event="go")
367365
renderer = DotRenderer()
368366
renderer._compound_ids = set()
369367
edges = renderer._create_edges(transition)

0 commit comments

Comments
 (0)