@@ -28,7 +28,9 @@ def __init__(self, model=None, state_field="state", start_value=None):
2828 self .state_field = state_field
2929 self .start_value = start_value
3030
31- initial_transition = Transition (None , None , event = "__initial__" )
31+ initial_transition = Transition (
32+ None , self ._get_initial_state (), event = "__initial__"
33+ )
3234 self ._setup (initial_transition )
3335 self ._activate_initial_state (initial_transition )
3436
@@ -39,21 +41,19 @@ def __repr__(self):
3941 f"current_state={ current_state_id !r} )"
4042 )
4143
42- def _activate_initial_state (self , initial_transition ):
43-
44+ def _get_initial_state (self ):
4445 current_state_value = (
4546 self .start_value if self .start_value else self .initial_state .value
4647 )
47- if self .current_state_value is None :
48-
49- try :
50- initial_state = self .states_map [current_state_value ]
51- except KeyError as err :
52- raise InvalidStateValue (current_state_value ) from err
48+ try :
49+ return self .states_map [current_state_value ]
50+ except KeyError as err :
51+ raise InvalidStateValue (current_state_value ) from err
5352
53+ def _activate_initial_state (self , initial_transition ):
54+ if self .current_state_value is None :
5455 # send an one-time event `__initial__` to enter the current state.
5556 # current_state = self.current_state
56- initial_transition .target = initial_state
5757 initial_transition .before .clear ()
5858 initial_transition .on .clear ()
5959 initial_transition .after .clear ()
@@ -93,8 +93,18 @@ def _setup(self, initial_transition):
9393 model = ObjectConfig (self .model , skip_attrs = {self .state_field })
9494 default_resolver = resolver_factory (machine , model )
9595
96- initial_transition ._setup (default_resolver )
97- self ._visit_states_and_transitions (lambda x : x ._setup (default_resolver ))
96+ # clone states and transitions to avoid sharing callbacks references between instances
97+ self .states_map = {
98+ state .value : state .clone ()._setup (self , default_resolver )
99+ for state in self .states
100+ }
101+ self .states = list (self .states_map .values ())
102+
103+ for state in self .states :
104+ for transition in state .transitions :
105+ transition ._setup (self , default_resolver )
106+
107+ initial_transition ._setup (self , default_resolver )
98108 self .add_observer (machine , model )
99109
100110 def add_observer (self , * observers ):
0 commit comments