Skip to content

Commit 8c24a42

Browse files
committed
feat: support coroutine functions as invoke targets (#610)
Coroutine functions passed as invoke targets were silently broken: the coroutine was never awaited and the coroutine object was passed as data in the done.invoke event. Now the async engine detects coroutine callbacks and IInvoke handlers with async run() and awaits them directly on the event loop instead of routing through run_in_executor. Also adds InvalidDefinition when async IInvoke handlers are used with the sync engine, and ignores the local specs/ directory. Signed-off-by: Fernando Macedo <fgmacedo@gmail.com>
1 parent b7a46e5 commit 8c24a42

File tree

3 files changed

+214
-3
lines changed

3 files changed

+214
-3
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,6 @@ docs/sg_execution_times.rst
8282

8383
# Temporary files
8484
tmp/
85+
86+
# Local specs
87+
specs/

statemachine/invoke.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from concurrent.futures import ThreadPoolExecutor
1414
from dataclasses import dataclass
1515
from dataclasses import field
16+
from inspect import iscoroutinefunction
1617
from typing import TYPE_CHECKING
1718
from typing import Any
1819
from typing import Callable
@@ -134,6 +135,16 @@ def _needs_wrapping(item: Any) -> bool:
134135
return False
135136

136137

138+
def _has_async_run(handler: Any) -> bool:
139+
"""Check if a handler (or its wrapped inner) has an async ``run()`` method."""
140+
if iscoroutinefunction(getattr(handler, "run", None)):
141+
return True
142+
if isinstance(handler, _InvokeCallableWrapper):
143+
inner = handler._invoke_handler
144+
return iscoroutinefunction(getattr(inner, "run", None))
145+
return False
146+
147+
137148
@dataclass
138149
class InvokeContext:
139150
"""Context passed to invoke handlers."""
@@ -357,6 +368,15 @@ def _spawn_one_sync(self, callback: "CallbackWrapper", **kwargs):
357368
# Use meta.func to find the original (unwrapped) handler; the callback
358369
# system wraps everything in a signature_adapter closure.
359370
handler = self._resolve_handler(callback.meta.func)
371+
372+
if handler is not None and _has_async_run(handler):
373+
from .exceptions import InvalidDefinition
374+
375+
raise InvalidDefinition(
376+
"Cannot use IInvoke with async run() on the sync engine. "
377+
"Add an async callback or listener to enable the async engine."
378+
)
379+
360380
ctx = self._make_context(state, event_kwargs, handler=handler)
361381
invocation = Invocation(invokeid=ctx.invokeid, state_id=state.id, ctx=ctx)
362382

@@ -448,12 +468,22 @@ async def _run_async_handler(
448468
invocation: Invocation,
449469
):
450470
try:
451-
loop = asyncio.get_running_loop()
452-
if handler is not None:
453-
# Run handler.run(ctx) in a thread executor so blocking I/O
471+
if handler is not None and _has_async_run(handler):
472+
# Async IInvoke: call run() and await the coroutine directly
473+
# on the event loop (no executor needed).
474+
result = await handler.run(ctx)
475+
elif handler is not None:
476+
# Sync IInvoke: run in a thread executor so blocking I/O
454477
# doesn't freeze the event loop.
478+
loop = asyncio.get_running_loop()
455479
result = await loop.run_in_executor(None, handler.run, ctx)
480+
elif callback._iscoro:
481+
# Coroutine callback: await directly on the event loop.
482+
result = await callback(ctx=ctx, machine=ctx.machine, **ctx.kwargs)
456483
else:
484+
# Sync callback: run in a thread executor so blocking I/O
485+
# doesn't freeze the event loop.
486+
loop = asyncio.get_running_loop()
457487
result = await loop.run_in_executor(
458488
None, lambda: callback.call(ctx=ctx, machine=ctx.machine, **ctx.kwargs)
459489
)

tests/test_invoke.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,184 @@ def test_on_cancel_before_run(self):
829829
group.on_cancel()
830830

831831

832+
class TestCoroutineFunctionAsInvokeTarget:
833+
"""Coroutine functions should work as invoke targets on the async engine."""
834+
835+
async def test_coroutine_invoke_returns_awaited_result(self):
836+
"""An async function used as invoke target should be awaited and return its value."""
837+
from tests.conftest import SMRunner
838+
839+
async def async_loader():
840+
return 42
841+
842+
class SM(StateChart):
843+
loading = State(initial=True, invoke=async_loader)
844+
ready = State(final=True)
845+
done_invoke_loading = loading.to(ready)
846+
847+
def on_enter_ready(self, data=None, **kwargs):
848+
self.result = data
849+
850+
sm_runner = SMRunner(is_async=True)
851+
sm = await sm_runner.start(SM)
852+
await sm_runner.sleep(0.2)
853+
await sm_runner.processing_loop(sm)
854+
855+
assert "ready" in sm.configuration_values
856+
assert sm.result == 42
857+
858+
async def test_coroutine_invoke_error_sends_error_execution(self):
859+
"""An async invoke that raises should send error.execution."""
860+
from tests.conftest import SMRunner
861+
862+
async def failing_loader():
863+
raise ValueError("async boom")
864+
865+
class SM(StateChart):
866+
loading = State(initial=True, invoke=failing_loader)
867+
error_state = State(final=True)
868+
error_execution = loading.to(error_state)
869+
870+
def on_enter_error_state(self, error=None, **kwargs):
871+
self.caught_error = error
872+
873+
sm_runner = SMRunner(is_async=True)
874+
sm = await sm_runner.start(SM)
875+
await sm_runner.sleep(0.2)
876+
await sm_runner.processing_loop(sm)
877+
878+
assert "error_state" in sm.configuration_values
879+
assert isinstance(sm.caught_error, ValueError)
880+
assert str(sm.caught_error) == "async boom"
881+
882+
async def test_coroutine_invoke_cancelled_on_state_exit(self):
883+
"""An async invoke should be cancelled when the owning state is exited."""
884+
from tests.conftest import SMRunner
885+
886+
cancel_observed = []
887+
888+
async def slow_loader():
889+
import asyncio
890+
891+
try:
892+
await asyncio.sleep(10)
893+
except asyncio.CancelledError:
894+
cancel_observed.append(True)
895+
raise
896+
return "should not reach"
897+
898+
class SM(StateChart):
899+
loading = State(initial=True, invoke=slow_loader)
900+
stopped = State(final=True)
901+
cancel = loading.to(stopped)
902+
903+
sm_runner = SMRunner(is_async=True)
904+
sm = await sm_runner.start(SM)
905+
await sm_runner.sleep(0.05)
906+
await sm_runner.send(sm, "cancel")
907+
await sm_runner.sleep(0.05)
908+
909+
assert "stopped" in sm.configuration_values
910+
911+
912+
class TestAsyncIInvokeInstance:
913+
"""IInvoke instances with async def run() should be awaited on the async engine."""
914+
915+
async def test_async_iinvoke_instance(self):
916+
"""An IInvoke instance with async run() should be awaited."""
917+
from tests.conftest import SMRunner
918+
919+
class AsyncHandler:
920+
async def run(self, ctx):
921+
return "async_result"
922+
923+
handler = AsyncHandler()
924+
925+
class SM(StateChart):
926+
loading = State(initial=True, invoke=handler)
927+
ready = State(final=True)
928+
done_invoke_loading = loading.to(ready)
929+
930+
def on_enter_ready(self, data=None, **kwargs):
931+
self.result = data
932+
933+
sm_runner = SMRunner(is_async=True)
934+
sm = await sm_runner.start(SM)
935+
await sm_runner.sleep(0.2)
936+
await sm_runner.processing_loop(sm)
937+
938+
assert "ready" in sm.configuration_values
939+
assert sm.result == "async_result"
940+
941+
942+
class TestAsyncIInvokeClass:
943+
"""IInvoke classes with async def run() should be instantiated and awaited."""
944+
945+
async def test_async_iinvoke_class(self):
946+
"""An IInvoke class with async run() should be instantiated and its run() awaited."""
947+
from tests.conftest import SMRunner
948+
949+
class AsyncHandler:
950+
async def run(self, ctx):
951+
return "class_async_result"
952+
953+
class SM(StateChart):
954+
loading = State(initial=True, invoke=AsyncHandler)
955+
ready = State(final=True)
956+
done_invoke_loading = loading.to(ready)
957+
958+
def on_enter_ready(self, data=None, **kwargs):
959+
self.result = data
960+
961+
sm_runner = SMRunner(is_async=True)
962+
sm = await sm_runner.start(SM)
963+
await sm_runner.sleep(0.2)
964+
await sm_runner.processing_loop(sm)
965+
966+
assert "ready" in sm.configuration_values
967+
assert sm.result == "class_async_result"
968+
969+
970+
class TestAsyncIInvokeOnSyncEngine:
971+
"""IInvoke with async run() on the sync engine should raise InvalidDefinition."""
972+
973+
def test_async_iinvoke_instance_on_sync_engine_raises(self):
974+
"""An IInvoke instance with async run() should fail clearly on the sync engine."""
975+
import pytest
976+
from statemachine.exceptions import InvalidDefinition
977+
978+
class AsyncHandler:
979+
async def run(self, ctx):
980+
return "unreachable"
981+
982+
handler = AsyncHandler()
983+
984+
class SM(StateChart):
985+
loading = State(initial=True, invoke=handler)
986+
ready = State(final=True)
987+
done_invoke_loading = loading.to(ready)
988+
989+
with pytest.raises(InvalidDefinition):
990+
SM()
991+
992+
def test_async_iinvoke_class_on_sync_engine_raises(self):
993+
"""An IInvoke class with async run() should fail clearly on the sync engine."""
994+
import pytest
995+
from statemachine.exceptions import InvalidDefinition
996+
997+
class AsyncHandler:
998+
async def run(self, ctx):
999+
return "unreachable"
1000+
1001+
class SM(StateChart):
1002+
loading = State(initial=True, invoke=AsyncHandler)
1003+
ready = State(final=True)
1004+
done_invoke_loading = loading.to(ready)
1005+
1006+
with pytest.raises(InvalidDefinition):
1007+
SM()
1008+
1009+
8321010
class TestDoneInvokeEventFactory:
8331011
"""done_invoke_ prefix works with both TransitionList and Event."""
8341012

0 commit comments

Comments
 (0)