Skip to content

Commit 6f2b617

Browse files
authored
feat: propagate constructor kwargs to initial state callbacks (#572)
Forward **kwargs from StateChart.__init__() through the engine's initial event (TriggerData), making them available to on_enter_<initial_state>, invoke handlers, and other initial-entry callbacks via dependency injection. This enables self-contained machines to receive context at creation time, e.g. `MyMachine(url="...", config=config)`.
1 parent f1cbfbb commit 6f2b617

7 files changed

Lines changed: 140 additions & 9 deletions

File tree

docs/invoke.md

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,37 @@ True
302302

303303
```
304304

305-
For initial states (entered automatically, not via an event), `kwargs` is empty.
305+
For initial states, any extra keyword arguments passed to the `StateChart` constructor
306+
are forwarded as event data. This makes self-contained machines that start processing
307+
immediately especially useful:
308+
309+
```py
310+
>>> config_file = Path(tempfile.mktemp(suffix=".json"))
311+
>>> _ = config_file.write_text('{"theme": "dark"}')
312+
313+
>>> class AppLoader(StateChart):
314+
... loading = State(initial=True)
315+
... ready = State(final=True)
316+
... done_invoke_loading = loading.to(ready)
317+
...
318+
... def on_invoke_loading(self, config_path=None, **kwargs):
319+
... """config_path comes from the constructor: AppLoader(config_path=...)."""
320+
... return json.loads(Path(config_path).read_text())
321+
...
322+
... def on_enter_ready(self, data=None, **kwargs):
323+
... self.config = data
324+
325+
>>> sm = AppLoader(config_path=str(config_file))
326+
>>> time.sleep(0.2)
327+
328+
>>> "ready" in sm.configuration_values
329+
True
330+
>>> sm.config
331+
{'theme': 'dark'}
332+
333+
>>> config_file.unlink()
334+
335+
```
306336

307337
## Error handling
308338

docs/releases/3.0.0.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,24 @@ and wait for all results:
6565

6666
```
6767

68+
Constructor keyword arguments are forwarded to initial state callbacks, so self-contained
69+
machines can receive context at creation time:
70+
71+
```py
72+
>>> class Greeter(StateChart):
73+
... idle = State(initial=True)
74+
... done = State(final=True)
75+
... idle.to(done)
76+
...
77+
... def on_enter_idle(self, name=None, **kwargs):
78+
... self.greeting = f"Hello, {name}!"
79+
80+
>>> sm = Greeter(name="Alice")
81+
>>> sm.greeting
82+
'Hello, Alice!'
83+
84+
```
85+
6886
See {ref}`invoke` for full documentation.
6987

7088
### Compound states

statemachine/engines/async_.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,12 +309,15 @@ async def _run_microstep(self, enabled_transitions, trigger_data): # pragma: no
309309
except Exception as e:
310310
self._handle_error(e, trigger_data)
311311

312-
async def activate_initial_state(self):
312+
async def activate_initial_state(self, **kwargs):
313313
"""Activate the initial state.
314314
315315
In async code, the user must call this method explicitly (or it will be lazily
316316
activated on the first event). There's no built-in way to call async code from
317317
``StateMachine.__init__``.
318+
319+
Any ``**kwargs`` are forwarded to initial state entry callbacks via dependency
320+
injection, just like event kwargs on ``send()``.
318321
"""
319322
return await self.processing_loop()
320323

statemachine/engines/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,11 +181,11 @@ def _send_error_execution(self, error: Exception, trigger_data: TriggerData):
181181
return
182182
self.sm.send(_ERROR_EXECUTION, error=error, internal=True)
183183

184-
def start(self):
184+
def start(self, **kwargs):
185185
if self.sm.current_state_value is not None:
186186
return
187187

188-
BoundEvent("__initial__", _sm=self.sm).put()
188+
BoundEvent("__initial__", _sm=self.sm).put(**kwargs)
189189

190190
def _initial_transitions(self, trigger_data):
191191
empty_state = State()

statemachine/engines/sync.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@ def _run_microstep(self, enabled_transitions, trigger_data):
3131
except Exception as e: # pragma: no cover
3232
self._handle_error(e, trigger_data)
3333

34-
def start(self):
34+
def start(self, **kwargs):
3535
if self.sm.current_state_value is not None:
3636
return
3737

38-
self.activate_initial_state()
38+
self.activate_initial_state(**kwargs)
3939

40-
def activate_initial_state(self):
40+
def activate_initial_state(self, **kwargs):
4141
"""
4242
Activate the initial state.
4343
@@ -48,7 +48,9 @@ def activate_initial_state(self):
4848
may depend on async code from the StateMachine.__init__ method.
4949
"""
5050
if self.sm.current_state_value is None:
51-
trigger_data = BoundEvent("__initial__", _sm=self.sm).build_trigger(machine=self.sm)
51+
trigger_data = BoundEvent("__initial__", _sm=self.sm).build_trigger(
52+
machine=self.sm, **kwargs
53+
)
5254
transitions = self._initial_transitions(trigger_data)
5355
self._processing.acquire(blocking=False)
5456
try:

statemachine/statemachine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def __init__(
165165
# for async code, the user should manually call `await sm.activate_initial_state()`
166166
# after state machine creation.
167167
self._engine = self._get_engine()
168-
self._engine.start()
168+
self._engine.start(**kwargs)
169169

170170
def _get_engine(self):
171171
if self._callbacks.has_async_callbacks:

tests/test_statemachine.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,3 +729,81 @@ class SM(StateChart):
729729
warnings.simplefilter("ignore", DeprecationWarning)
730730
with pytest.raises(exceptions.InvalidStateValue):
731731
_ = sm.current_state
732+
733+
734+
class TestInitKwargsPropagation:
735+
"""Constructor kwargs are forwarded to initial state entry callbacks."""
736+
737+
async def test_kwargs_available_in_on_enter_initial(self, sm_runner):
738+
class SM(StateChart):
739+
idle = State(initial=True)
740+
done = State(final=True)
741+
go = idle.to(done)
742+
743+
def on_enter_idle(self, greeting=None, **kwargs):
744+
self.greeting = greeting
745+
746+
sm = await sm_runner.start(SM, greeting="hello")
747+
assert sm.greeting == "hello"
748+
749+
async def test_kwargs_flow_through_eventless_transitions(self, sm_runner):
750+
class Pipeline(StateChart):
751+
start = State(initial=True)
752+
processing = State()
753+
done = State(final=True)
754+
755+
start.to(processing)
756+
processing.to(done)
757+
758+
def on_enter_start(self, task_id=None, **kwargs):
759+
self.task_id = task_id
760+
761+
sm = await sm_runner.start(Pipeline, task_id="abc-123")
762+
assert sm.task_id == "abc-123"
763+
assert "done" in sm.configuration_values
764+
765+
async def test_no_kwargs_still_works(self, sm_runner):
766+
class SM(StateChart):
767+
idle = State(initial=True)
768+
done = State(final=True)
769+
go = idle.to(done)
770+
771+
def on_enter_idle(self, **kwargs):
772+
self.entered = True
773+
774+
sm = await sm_runner.start(SM)
775+
assert sm.entered is True
776+
777+
async def test_multiple_kwargs(self, sm_runner):
778+
class SM(StateChart):
779+
idle = State(initial=True)
780+
done = State(final=True)
781+
go = idle.to(done)
782+
783+
def on_enter_idle(self, host=None, port=None, **kwargs):
784+
self.host = host
785+
self.port = port
786+
787+
sm = await sm_runner.start(SM, host="localhost", port=5432)
788+
assert sm.host == "localhost"
789+
assert sm.port == 5432
790+
791+
async def test_kwargs_in_invoke_handler(self, sm_runner):
792+
"""Init kwargs flow to invoke handlers via dependency injection."""
793+
794+
class SM(StateChart):
795+
loading = State(initial=True)
796+
ready = State(final=True)
797+
done_invoke_loading = loading.to(ready)
798+
799+
def on_invoke_loading(self, url=None, **kwargs):
800+
return f"fetched:{url}"
801+
802+
def on_enter_ready(self, data=None, **kwargs):
803+
self.result = data
804+
805+
sm = await sm_runner.start(SM, url="https://example.com")
806+
await sm_runner.sleep(0.2)
807+
await sm_runner.processing_loop(sm)
808+
assert "ready" in sm.configuration_values
809+
assert sm.result == "fetched:https://example.com"

0 commit comments

Comments
 (0)