Skip to content

Commit f54d9f5

Browse files
authored
feat: support coroutine functions as invoke targets (#610) (#611)
1 parent b7a46e5 commit f54d9f5

File tree

3 files changed

+214
-3
lines changed

3 files changed

+214
-3
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,6 @@ docs/sg_execution_times.rst
8282

8383
# Temporary files
8484
tmp/
85+
86+
# Local specs
87+
specs/

statemachine/invoke.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from concurrent.futures import ThreadPoolExecutor
1414
from dataclasses import dataclass
1515
from dataclasses import field
16+
from inspect import iscoroutinefunction
1617
from typing import TYPE_CHECKING
1718
from typing import Any
1819
from typing import Callable
@@ -134,6 +135,16 @@ def _needs_wrapping(item: Any) -> bool:
134135
return False
135136

136137

138+
def _has_async_run(handler: Any) -> bool:
139+
"""Check if a handler (or its wrapped inner) has an async ``run()`` method."""
140+
if iscoroutinefunction(getattr(handler, "run", None)):
141+
return True
142+
if isinstance(handler, _InvokeCallableWrapper):
143+
inner = handler._invoke_handler
144+
return iscoroutinefunction(getattr(inner, "run", None))
145+
return False
146+
147+
137148
@dataclass
138149
class InvokeContext:
139150
"""Context passed to invoke handlers."""
@@ -357,6 +368,15 @@ def _spawn_one_sync(self, callback: "CallbackWrapper", **kwargs):
357368
# Use meta.func to find the original (unwrapped) handler; the callback
358369
# system wraps everything in a signature_adapter closure.
359370
handler = self._resolve_handler(callback.meta.func)
371+
372+
if handler is not None and _has_async_run(handler):
373+
from .exceptions import InvalidDefinition
374+
375+
raise InvalidDefinition(
376+
"Cannot use IInvoke with async run() on the sync engine. "
377+
"Add an async callback or listener to enable the async engine."
378+
)
379+
360380
ctx = self._make_context(state, event_kwargs, handler=handler)
361381
invocation = Invocation(invokeid=ctx.invokeid, state_id=state.id, ctx=ctx)
362382

@@ -448,12 +468,22 @@ async def _run_async_handler(
448468
invocation: Invocation,
449469
):
450470
try:
451-
loop = asyncio.get_running_loop()
452-
if handler is not None:
453-
# Run handler.run(ctx) in a thread executor so blocking I/O
471+
if handler is not None and _has_async_run(handler):
472+
# Async IInvoke: call run() and await the coroutine directly
473+
# on the event loop (no executor needed).
474+
result = await handler.run(ctx)
475+
elif handler is not None:
476+
# Sync IInvoke: run in a thread executor so blocking I/O
454477
# doesn't freeze the event loop.
478+
loop = asyncio.get_running_loop()
455479
result = await loop.run_in_executor(None, handler.run, ctx)
480+
elif callback._iscoro:
481+
# Coroutine callback: await directly on the event loop.
482+
result = await callback(ctx=ctx, machine=ctx.machine, **ctx.kwargs)
456483
else:
484+
# Sync callback: run in a thread executor so blocking I/O
485+
# doesn't freeze the event loop.
486+
loop = asyncio.get_running_loop()
457487
result = await loop.run_in_executor(
458488
None, lambda: callback.call(ctx=ctx, machine=ctx.machine, **ctx.kwargs)
459489
)

tests/test_invoke.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,184 @@ def test_on_cancel_before_run(self):
829829
group.on_cancel()
830830

831831

832+
class TestCoroutineFunctionAsInvokeTarget:
833+
"""Coroutine functions should work as invoke targets on the async engine."""
834+
835+
async def test_coroutine_invoke_returns_awaited_result(self):
836+
"""An async function used as invoke target should be awaited and return its value."""
837+
from tests.conftest import SMRunner
838+
839+
async def async_loader():
840+
return 42
841+
842+
class SM(StateChart):
843+
loading = State(initial=True, invoke=async_loader)
844+
ready = State(final=True)
845+
done_invoke_loading = loading.to(ready)
846+
847+
def on_enter_ready(self, data=None, **kwargs):
848+
self.result = data
849+
850+
sm_runner = SMRunner(is_async=True)
851+
sm = await sm_runner.start(SM)
852+
await sm_runner.sleep(0.2)
853+
await sm_runner.processing_loop(sm)
854+
855+
assert "ready" in sm.configuration_values
856+
assert sm.result == 42
857+
858+
async def test_coroutine_invoke_error_sends_error_execution(self):
859+
"""An async invoke that raises should send error.execution."""
860+
from tests.conftest import SMRunner
861+
862+
async def failing_loader():
863+
raise ValueError("async boom")
864+
865+
class SM(StateChart):
866+
loading = State(initial=True, invoke=failing_loader)
867+
error_state = State(final=True)
868+
error_execution = loading.to(error_state)
869+
870+
def on_enter_error_state(self, error=None, **kwargs):
871+
self.caught_error = error
872+
873+
sm_runner = SMRunner(is_async=True)
874+
sm = await sm_runner.start(SM)
875+
await sm_runner.sleep(0.2)
876+
await sm_runner.processing_loop(sm)
877+
878+
assert "error_state" in sm.configuration_values
879+
assert isinstance(sm.caught_error, ValueError)
880+
assert str(sm.caught_error) == "async boom"
881+
882+
async def test_coroutine_invoke_cancelled_on_state_exit(self):
883+
"""An async invoke should be cancelled when the owning state is exited."""
884+
from tests.conftest import SMRunner
885+
886+
cancel_observed = []
887+
888+
async def slow_loader():
889+
import asyncio
890+
891+
try:
892+
await asyncio.sleep(10)
893+
except asyncio.CancelledError:
894+
cancel_observed.append(True)
895+
raise
896+
return "should not reach"
897+
898+
class SM(StateChart):
899+
loading = State(initial=True, invoke=slow_loader)
900+
stopped = State(final=True)
901+
cancel = loading.to(stopped)
902+
903+
sm_runner = SMRunner(is_async=True)
904+
sm = await sm_runner.start(SM)
905+
await sm_runner.sleep(0.05)
906+
await sm_runner.send(sm, "cancel")
907+
await sm_runner.sleep(0.05)
908+
909+
assert "stopped" in sm.configuration_values
910+
911+
912+
class TestAsyncIInvokeInstance:
913+
"""IInvoke instances with async def run() should be awaited on the async engine."""
914+
915+
async def test_async_iinvoke_instance(self):
916+
"""An IInvoke instance with async run() should be awaited."""
917+
from tests.conftest import SMRunner
918+
919+
class AsyncHandler:
920+
async def run(self, ctx):
921+
return "async_result"
922+
923+
handler = AsyncHandler()
924+
925+
class SM(StateChart):
926+
loading = State(initial=True, invoke=handler)
927+
ready = State(final=True)
928+
done_invoke_loading = loading.to(ready)
929+
930+
def on_enter_ready(self, data=None, **kwargs):
931+
self.result = data
932+
933+
sm_runner = SMRunner(is_async=True)
934+
sm = await sm_runner.start(SM)
935+
await sm_runner.sleep(0.2)
936+
await sm_runner.processing_loop(sm)
937+
938+
assert "ready" in sm.configuration_values
939+
assert sm.result == "async_result"
940+
941+
942+
class TestAsyncIInvokeClass:
943+
"""IInvoke classes with async def run() should be instantiated and awaited."""
944+
945+
async def test_async_iinvoke_class(self):
946+
"""An IInvoke class with async run() should be instantiated and its run() awaited."""
947+
from tests.conftest import SMRunner
948+
949+
class AsyncHandler:
950+
async def run(self, ctx):
951+
return "class_async_result"
952+
953+
class SM(StateChart):
954+
loading = State(initial=True, invoke=AsyncHandler)
955+
ready = State(final=True)
956+
done_invoke_loading = loading.to(ready)
957+
958+
def on_enter_ready(self, data=None, **kwargs):
959+
self.result = data
960+
961+
sm_runner = SMRunner(is_async=True)
962+
sm = await sm_runner.start(SM)
963+
await sm_runner.sleep(0.2)
964+
await sm_runner.processing_loop(sm)
965+
966+
assert "ready" in sm.configuration_values
967+
assert sm.result == "class_async_result"
968+
969+
970+
class TestAsyncIInvokeOnSyncEngine:
971+
"""IInvoke with async run() on the sync engine should raise InvalidDefinition."""
972+
973+
def test_async_iinvoke_instance_on_sync_engine_raises(self):
974+
"""An IInvoke instance with async run() should fail clearly on the sync engine."""
975+
import pytest
976+
from statemachine.exceptions import InvalidDefinition
977+
978+
class AsyncHandler:
979+
async def run(self, ctx):
980+
return "unreachable"
981+
982+
handler = AsyncHandler()
983+
984+
class SM(StateChart):
985+
loading = State(initial=True, invoke=handler)
986+
ready = State(final=True)
987+
done_invoke_loading = loading.to(ready)
988+
989+
with pytest.raises(InvalidDefinition):
990+
SM()
991+
992+
def test_async_iinvoke_class_on_sync_engine_raises(self):
993+
"""An IInvoke class with async run() should fail clearly on the sync engine."""
994+
import pytest
995+
from statemachine.exceptions import InvalidDefinition
996+
997+
class AsyncHandler:
998+
async def run(self, ctx):
999+
return "unreachable"
1000+
1001+
class SM(StateChart):
1002+
loading = State(initial=True, invoke=AsyncHandler)
1003+
ready = State(final=True)
1004+
done_invoke_loading = loading.to(ready)
1005+
1006+
with pytest.raises(InvalidDefinition):
1007+
SM()
1008+
1009+
8321010
class TestDoneInvokeEventFactory:
8331011
"""done_invoke_ prefix works with both TransitionList and Event."""
8341012

0 commit comments

Comments
 (0)