Skip to content

Commit 570687e

Browse files
authored
fix: await async predicates in condition expressions (#557)
* fix: re-enqueue initial event when deserializing async state machine (#544) When an async SM is pickled/deepcopied (e.g. via multiprocessing), the engine queue is not preserved. __setstate__ recreated the engine but never called start(), so the __initial__ event was never enqueued and activate_initial_state() would fail with InvalidStateValue. Closes #544 * fix: await async predicates in condition expressions (#535) The boolean expression combinators (custom_not, custom_and, custom_or, build_custom_operator) called predicates synchronously. When predicates were async, they returned unawaited coroutine objects which are always truthy, causing `not` to always return False, `and` to skip evaluation, and `or` to short-circuit incorrectly. Each combinator now checks `isawaitable()` on predicate results and returns a coroutine when needed, which CallbackWrapper.__call__ already knows how to await. Closes #535 * chore: sync pre-commit ruff rev with lockfile (v0.15.0) The pre-commit hook was using ruff v0.8.1 while the lockfile had v0.15.0, causing import sorting differences between local and CI. * fix: address SonarCloud code smells in tests - Add docstrings to empty async on_enter_state methods (S1186) - Use await asyncio.sleep(0) in async test hooks to satisfy S7503
1 parent 37e6c1a commit 570687e

8 files changed

Lines changed: 290 additions & 11 deletions

File tree

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ repos:
99
exclude: docs/auto_examples
1010
- repo: https://github.com/charliermarsh/ruff-pre-commit
1111
# Ruff version.
12-
rev: v0.8.1
12+
rev: v0.15.0
1313
hooks:
1414
# Run the linter.
1515
- id: ruff

statemachine/spec_parser.py

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import operator
33
import re
44
from functools import reduce
5+
from inspect import isawaitable
56
from typing import Callable
67

78
replacements = {"!": "not ", "^": " and ", "v": " or "}
@@ -33,8 +34,15 @@ def match_func(match):
3334

3435

3536
def custom_not(predicate: Callable) -> Callable:
36-
def decorated(*args, **kwargs) -> bool:
37-
return not predicate(*args, **kwargs)
37+
def decorated(*args, **kwargs):
38+
result = predicate(*args, **kwargs)
39+
if isawaitable(result):
40+
41+
async def _negate():
42+
return not await result
43+
44+
return _negate()
45+
return not result
3846

3947
decorated.__name__ = f"not({predicate.__name__})"
4048
unique_key = getattr(predicate, "unique_key", "")
@@ -43,17 +51,53 @@ def decorated(*args, **kwargs) -> bool:
4351

4452

4553
def custom_and(left: Callable, right: Callable) -> Callable:
46-
def decorated(*args, **kwargs) -> bool:
47-
return left(*args, **kwargs) and right(*args, **kwargs) # type: ignore[no-any-return]
54+
def decorated(*args, **kwargs):
55+
left_result = left(*args, **kwargs)
56+
if isawaitable(left_result):
57+
58+
async def _async_and():
59+
lr = await left_result
60+
if not lr:
61+
return lr
62+
rr = right(*args, **kwargs)
63+
if isawaitable(rr):
64+
return await rr
65+
return rr
66+
67+
return _async_and()
68+
if not left_result:
69+
return left_result
70+
right_result = right(*args, **kwargs)
71+
if isawaitable(right_result):
72+
return right_result
73+
return right_result
4874

4975
decorated.__name__ = f"({left.__name__} and {right.__name__})"
5076
decorated.unique_key = _unique_key(left, right, "and") # type: ignore[attr-defined]
5177
return decorated
5278

5379

5480
def custom_or(left: Callable, right: Callable) -> Callable:
55-
def decorated(*args, **kwargs) -> bool:
56-
return left(*args, **kwargs) or right(*args, **kwargs) # type: ignore[no-any-return]
81+
def decorated(*args, **kwargs):
82+
left_result = left(*args, **kwargs)
83+
if isawaitable(left_result):
84+
85+
async def _async_or():
86+
lr = await left_result
87+
if lr:
88+
return lr
89+
rr = right(*args, **kwargs)
90+
if isawaitable(rr):
91+
return await rr
92+
return rr
93+
94+
return _async_or()
95+
if left_result:
96+
return left_result
97+
right_result = right(*args, **kwargs)
98+
if isawaitable(right_result):
99+
return right_result
100+
return right_result
57101

58102
decorated.__name__ = f"({left.__name__} or {right.__name__})"
59103
decorated.unique_key = _unique_key(left, right, "or") # type: ignore[attr-defined]
@@ -73,8 +117,18 @@ def build_custom_operator(operator) -> Callable:
73117
operator_repr = comparison_repr[operator]
74118

75119
def custom_comparator(left: Callable, right: Callable) -> Callable:
76-
def decorated(*args, **kwargs) -> bool:
77-
return bool(operator(left(*args, **kwargs), right(*args, **kwargs)))
120+
def decorated(*args, **kwargs):
121+
left_result = left(*args, **kwargs)
122+
right_result = right(*args, **kwargs)
123+
if isawaitable(left_result) or isawaitable(right_result):
124+
125+
async def _async_compare():
126+
lr = (await left_result) if isawaitable(left_result) else left_result
127+
rr = (await right_result) if isawaitable(right_result) else right_result
128+
return bool(operator(lr, rr))
129+
130+
return _async_compare()
131+
return bool(operator(left_result, right_result))
78132

79133
decorated.__name__ = f"({left.__name__} {operator_repr} {right.__name__})"
80134
decorated.unique_key = _unique_key(left, right, operator_repr) # type: ignore[attr-defined]

statemachine/statemachine.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def __setstate__(self, state):
147147
self._register_callbacks([])
148148
self.add_listener(*listeners.keys())
149149
self._engine = self._get_engine(rtc)
150+
self._engine.start()
150151

151152
def _get_initial_state(self):
152153
initial_state_value = self.start_value if self.start_value else self.initial_state.value

tests/examples/user_machine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ class UserStatusMachine(StateMachine):
8888
def on_signup(self, token: str):
8989
if token == "":
9090
raise ValueError("Token is required")
91-
self.model.verified = True
91+
self.model.verified = True # type: ignore[union-attr]
9292

9393

9494
class UserExperienceMachine(StateMachine):

tests/test_async.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,86 @@ def test_async_state_from_sync_context(async_order_control_machine):
9696
assert sm.completed.is_active
9797

9898

99+
class AsyncConditionExpressionMachine(StateMachine):
100+
"""Regression test for issue #535: async conditions in boolean expressions."""
101+
102+
s1 = State(initial=True)
103+
104+
go_not = s1.to.itself(cond="not cond_false")
105+
go_and = s1.to.itself(cond="cond_true and cond_true")
106+
go_or_false_first = s1.to.itself(cond="cond_false or cond_true")
107+
go_or_true_first = s1.to.itself(cond="cond_true or cond_false")
108+
go_blocked = s1.to.itself(cond="not cond_true")
109+
go_and_blocked = s1.to.itself(cond="cond_true and cond_false")
110+
go_or_both_false = s1.to.itself(cond="cond_false or cond_false")
111+
112+
async def cond_true(self):
113+
return True
114+
115+
async def cond_false(self):
116+
return False
117+
118+
async def on_enter_state(self, target):
119+
"""Async callback to ensure the SM uses AsyncEngine."""
120+
121+
122+
async def test_async_condition_not(recwarn):
123+
"""Issue #535: 'not cond_false' should allow the transition."""
124+
sm = AsyncConditionExpressionMachine()
125+
await sm.activate_initial_state()
126+
await sm.go_not()
127+
assert sm.s1.is_active
128+
assert not any("coroutine" in str(w.message) for w in recwarn.list)
129+
130+
131+
async def test_async_condition_not_blocked():
132+
"""Issue #535: 'not cond_true' should block the transition."""
133+
sm = AsyncConditionExpressionMachine()
134+
await sm.activate_initial_state()
135+
with pytest.raises(sm.TransitionNotAllowed):
136+
await sm.go_blocked()
137+
138+
139+
async def test_async_condition_and():
140+
"""Issue #535: 'cond_true and cond_true' should allow the transition."""
141+
sm = AsyncConditionExpressionMachine()
142+
await sm.activate_initial_state()
143+
await sm.go_and()
144+
assert sm.s1.is_active
145+
146+
147+
async def test_async_condition_and_blocked():
148+
"""Issue #535: 'cond_true and cond_false' should block the transition."""
149+
sm = AsyncConditionExpressionMachine()
150+
await sm.activate_initial_state()
151+
with pytest.raises(sm.TransitionNotAllowed):
152+
await sm.go_and_blocked()
153+
154+
155+
async def test_async_condition_or_false_first():
156+
"""Issue #535: 'cond_false or cond_true' should allow the transition."""
157+
sm = AsyncConditionExpressionMachine()
158+
await sm.activate_initial_state()
159+
await sm.go_or_false_first()
160+
assert sm.s1.is_active
161+
162+
163+
async def test_async_condition_or_true_first():
164+
"""'cond_true or cond_false' should allow the transition."""
165+
sm = AsyncConditionExpressionMachine()
166+
await sm.activate_initial_state()
167+
await sm.go_or_true_first()
168+
assert sm.s1.is_active
169+
170+
171+
async def test_async_condition_or_both_false():
172+
"""'cond_false or cond_false' should block the transition."""
173+
sm = AsyncConditionExpressionMachine()
174+
await sm.activate_initial_state()
175+
with pytest.raises(sm.TransitionNotAllowed):
176+
await sm.go_or_both_false()
177+
178+
99179
async def test_async_state_should_be_initialized(async_order_control_machine):
100180
"""Test that the state machine is initialized before any event is triggered
101181

tests/test_copy.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import logging
23
import pickle
34
from copy import deepcopy
@@ -181,3 +182,51 @@ def test_copy_with_custom_init_and_vars(copy_method):
181182
assert sm2.custom == 1
182183
assert sm2.value == [1, 2, 3]
183184
assert sm2.current_state == MyStateMachine.started
185+
186+
187+
class AsyncTrafficLightMachine(StateMachine):
188+
green = State(initial=True)
189+
yellow = State()
190+
red = State()
191+
192+
cycle = green.to(yellow) | yellow.to(red) | red.to(green)
193+
194+
async def on_enter_state(self, target):
195+
"""Async callback to ensure the SM uses AsyncEngine."""
196+
197+
198+
def test_copy_async_statemachine_before_activation(copy_method):
199+
"""Regression test for issue #544: async SM fails after pickle/deepcopy.
200+
201+
When an async SM is copied before activation, the copy must still be
202+
activatable because ``__setstate__`` re-enqueues the ``__initial__`` event.
203+
"""
204+
sm = AsyncTrafficLightMachine()
205+
sm_copy = copy_method(sm)
206+
207+
async def verify():
208+
await sm_copy.activate_initial_state()
209+
assert sm_copy.current_state == AsyncTrafficLightMachine.green
210+
await sm_copy.cycle()
211+
assert sm_copy.current_state == AsyncTrafficLightMachine.yellow
212+
213+
asyncio.run(verify())
214+
215+
216+
def test_copy_async_statemachine_after_activation(copy_method):
217+
"""Copying an async SM that is already activated preserves its current state."""
218+
219+
async def setup_and_verify():
220+
sm = AsyncTrafficLightMachine()
221+
await sm.activate_initial_state()
222+
await sm.cycle()
223+
assert sm.current_state == AsyncTrafficLightMachine.yellow
224+
225+
sm_copy = copy_method(sm)
226+
227+
await sm_copy.activate_initial_state()
228+
assert sm_copy.current_state == AsyncTrafficLightMachine.yellow
229+
await sm_copy.cycle()
230+
assert sm_copy.current_state == AsyncTrafficLightMachine.red
231+
232+
asyncio.run(setup_and_verify())

tests/test_signature.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from functools import partial
33

44
import pytest
5-
65
from statemachine.dispatcher import callable_method
76
from statemachine.signature import SignatureAdapter
87

tests/test_spec_parser.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import logging
23

34
import pytest
@@ -247,6 +248,101 @@ def variable_hook(var_name):
247248
("height > 1 and height < 2", True, ["height"]),
248249
],
249250
)
251+
def async_variable_hook(var_name):
252+
"""Variable hook that returns async callables, for testing issue #535."""
253+
values = {
254+
"cond_true": True,
255+
"cond_false": False,
256+
"val_10": 10,
257+
"val_20": 20,
258+
}
259+
260+
value = values.get(var_name, False)
261+
262+
async def decorated(*args, **kwargs):
263+
await asyncio.sleep(0)
264+
return value
265+
266+
decorated.__name__ = var_name
267+
return decorated
268+
269+
270+
@pytest.mark.parametrize(
271+
("expression", "expected"),
272+
[
273+
("not cond_false", True),
274+
("not cond_true", False),
275+
("cond_true and cond_true", True),
276+
("cond_true and cond_false", False),
277+
("cond_false and cond_true", False),
278+
("cond_false or cond_true", True),
279+
("cond_true or cond_false", True),
280+
("cond_false or cond_false", False),
281+
("not cond_false and cond_true", True),
282+
("not (cond_true and cond_false)", True),
283+
("not (cond_false or cond_false)", True),
284+
("cond_true and not cond_false", True),
285+
("val_10 == 10", True),
286+
("val_10 != 20", True),
287+
("val_10 < val_20", True),
288+
("val_20 > val_10", True),
289+
("val_10 >= 10", True),
290+
("val_10 <= val_20", True),
291+
],
292+
)
293+
def test_async_expressions(expression, expected):
294+
"""Issue #535: condition expressions with async predicates must await results."""
295+
parsed_expr = parse_boolean_expr(expression, async_variable_hook, operator_mapping)
296+
result = parsed_expr()
297+
assert asyncio.iscoroutine(result), f"Expected coroutine for async expression: {expression}"
298+
assert asyncio.run(result) is expected, expression
299+
300+
301+
def mixed_variable_hook(var_name):
302+
"""Variable hook where some vars are sync and some are async."""
303+
sync_values = {"sync_true": True, "sync_false": False, "sync_10": 10}
304+
async_values = {"async_true": True, "async_false": False, "async_20": 20}
305+
306+
if var_name in async_values:
307+
value = async_values[var_name]
308+
309+
async def async_decorated(*args, **kwargs):
310+
await asyncio.sleep(0)
311+
return value
312+
313+
async_decorated.__name__ = var_name
314+
return async_decorated
315+
316+
def sync_decorated(*args, **kwargs):
317+
return sync_values.get(var_name, False)
318+
319+
sync_decorated.__name__ = var_name
320+
return sync_decorated
321+
322+
323+
@pytest.mark.parametrize(
324+
("expression", "expected"),
325+
[
326+
# async left, sync right
327+
("async_true and sync_true", True),
328+
("async_false or sync_true", True),
329+
# sync left, async right
330+
("sync_true and async_true", True),
331+
("sync_false or async_true", True),
332+
("sync_true and async_false", False),
333+
("sync_false or async_false", False),
334+
],
335+
)
336+
def test_mixed_sync_async_expressions(expression, expected):
337+
"""Expressions mixing sync and async predicates must handle both correctly."""
338+
parsed_expr = parse_boolean_expr(expression, mixed_variable_hook, operator_mapping)
339+
result = parsed_expr()
340+
if asyncio.iscoroutine(result):
341+
assert asyncio.run(result) is expected, expression
342+
else:
343+
assert result is expected, expression
344+
345+
250346
@pytest.mark.xfail(reason="TODO: Optimize so that expressios are evaluated only once")
251347
def test_should_evaluate_values_only_once(expression, expected, caplog, hooks_called):
252348
caplog.set_level(logging.DEBUG, logger="tests")

0 commit comments

Comments
 (0)