@@ -18,62 +18,79 @@ class DotGraphMachine:
1818 font_name = "Arial"
1919 """Graph font face name"""
2020
21- state_font_size = "10 "
22- """State font size in points """
21+ state_font_size = "10pt "
22+ """State font size"""
2323
2424 state_active_penwidth = 2
2525 """Active state external line width"""
2626
2727 state_active_fillcolor = "turquoise"
2828
29- transition_font_size = "9 "
30- """Transition font size in points """
29+ transition_font_size = "9pt "
30+ """Transition font size"""
3131
32- def __init__ (self , machine : StateMachine ):
32+ def __init__ (self , machine ):
3333 self .machine = machine
3434
35- def _get_graph (self ):
36- machine = self .machine
35+ def _get_graph (self , machine ):
3736 return pydot .Dot (
38- "list" ,
37+ machine . name ,
3938 graph_type = "digraph" ,
4039 label = machine .name ,
4140 fontname = self .font_name ,
4241 fontsize = self .state_font_size ,
4342 rankdir = self .graph_rankdir ,
43+ compound = "true" ,
4444 )
4545
46- def _initial_node (self ):
46+ def _get_subgraph (self , state ):
47+ style = ", solid"
48+ if state .parent and state .parent .parallel :
49+ style = ", dashed"
50+ subgraph = pydot .Subgraph (
51+ label = f"{ state .name } " ,
52+ graph_name = f"cluster_{ state .id } " ,
53+ style = f"rounded{ style } " ,
54+ cluster = "true" ,
55+ )
56+ return subgraph
57+
58+ def _initial_node (self , state ):
4759 node = pydot .Node (
48- "i" ,
49- shape = "circle" ,
60+ self ._state_id (state ),
61+ label = "" ,
62+ shape = "point" ,
5063 style = "filled" ,
51- fontsize = "1 " ,
64+ fontsize = "1pt " ,
5265 fixedsize = "true" ,
5366 width = 0.2 ,
5467 height = 0.2 ,
5568 )
5669 node .set_fillcolor ("black" )
5770 return node
5871
59- def _initial_edge (self ):
72+ def _initial_edge (self , initial_node , state ):
73+ extra_params = {}
74+ if state .states :
75+ extra_params ["lhead" ] = f"cluster_{ state .id } "
6076 return pydot .Edge (
61- "i" ,
62- self .machine . initial_state . id ,
77+ initial_node . get_name () ,
78+ self ._state_id ( state ) ,
6379 label = "" ,
6480 color = "blue" ,
6581 fontname = self .font_name ,
6682 fontsize = self .transition_font_size ,
83+ ** extra_params ,
6784 )
6885
6986 def _actions_getter (self ):
7087 if isinstance (self .machine , StateMachine ):
7188
72- def getter (grouper ) -> str :
89+ def getter (grouper ):
7390 return self .machine ._callbacks .str (grouper .key )
7491 else :
7592
76- def getter (grouper ) -> str :
93+ def getter (grouper ):
7794 all_names = set (dir (self .machine ))
7895 return ", " .join (
7996 str (c ) for c in grouper if not c .is_convention or c .func in all_names
@@ -104,11 +121,18 @@ def _state_actions(self, state):
104121
105122 return actions
106123
124+ @staticmethod
125+ def _state_id (state ):
126+ if state .states :
127+ return f"{ state .id } _anchor"
128+ else :
129+ return state .id
130+
107131 def _state_as_node (self , state ):
108132 actions = self ._state_actions (state )
109133
110134 node = pydot .Node (
111- state . id ,
135+ self . _state_id ( state ) ,
112136 label = f"{ state .name } { actions } " ,
113137 shape = "rectangle" ,
114138 style = "rounded, filled" ,
@@ -127,29 +151,64 @@ def _transition_as_edge(self, transition):
127151 cond = ", " .join ([str (cond ) for cond in transition .cond ])
128152 if cond :
129153 cond = f"\n [{ cond } ]"
154+
155+ extra_params = {}
156+ has_substates = transition .source .states or transition .target .states
157+ if transition .source .states :
158+ extra_params ["ltail" ] = f"cluster_{ transition .source .id } "
159+ if transition .target .states :
160+ extra_params ["lhead" ] = f"cluster_{ transition .target .id } "
161+
130162 return pydot .Edge (
131- transition .source . id ,
132- transition .target . id ,
163+ self . _state_id ( transition .source ) ,
164+ self . _state_id ( transition .target ) ,
133165 label = f"{ transition .event } { cond } " ,
134166 color = "blue" ,
135167 fontname = self .font_name ,
136168 fontsize = self .transition_font_size ,
169+ minlen = 2 if has_substates else 1 ,
170+ ** extra_params ,
137171 )
138172
139173 def get_graph (self ):
140- graph = self ._get_graph ()
141- graph . add_node (self ._initial_node () )
142- graph . add_edge ( self . _initial_edge ())
174+ graph = self ._get_graph (self . machine )
175+ self . _graph_states (self .machine , graph )
176+ return graph
143177
144- for state in self .machine .states :
145- graph .add_node (self ._state_as_node (state ))
146- for transition in state .transitions :
178+ def _graph_states (self , state , graph ):
179+ initial_node = self ._initial_node (state )
180+ initial_subgraph = pydot .Subgraph (
181+ graph_name = f"{ initial_node .get_name ()} _initial" ,
182+ label = "" ,
183+ peripheries = 0 ,
184+ margin = 0 ,
185+ )
186+ atomic_states_subgraph = pydot .Subgraph (
187+ graph_name = f"cluster_{ initial_node .get_name ()} _atomic" ,
188+ label = "" ,
189+ peripheries = 0 ,
190+ cluster = "true" ,
191+ )
192+ initial_subgraph .add_node (initial_node )
193+ graph .add_subgraph (initial_subgraph )
194+ graph .add_subgraph (atomic_states_subgraph )
195+
196+ initial = next (s for s in state .states if s .initial )
197+ graph .add_edge (self ._initial_edge (initial_node , initial ))
198+
199+ for substate in state .states :
200+ if substate .states :
201+ subgraph = self ._get_subgraph (substate )
202+ self ._graph_states (substate , subgraph )
203+ graph .add_subgraph (subgraph )
204+ else :
205+ atomic_states_subgraph .add_node (self ._state_as_node (substate ))
206+
207+ for transition in substate .transitions :
147208 if transition .internal :
148209 continue
149210 graph .add_edge (self ._transition_as_edge (transition ))
150211
151- return graph
152-
153212 def __call__ (self ):
154213 return self .get_graph ()
155214
@@ -165,7 +224,11 @@ def quickchart_write_svg(sm: StateMachine, path: str):
165224 >>> from tests.examples.order_control_machine import OrderControl
166225 >>> sm = OrderControl()
167226 >>> print(sm._graph().to_string())
168- digraph list {
227+ digraph OrderControl {
228+ label=OrderControl;
229+ fontname=Arial;
230+ fontsize="10pt";
231+ rankdir=LR;
169232 ...
170233
171234 To give you an example, we included this method that will serialize the dot, request the graph
0 commit comments