Skip to content

Commit 1275cd4

Browse files
authored
feat: Conditionals with boolean algebra (#487)
1 parent ff14d62 commit 1275cd4

7 files changed

Lines changed: 462 additions & 12 deletions

File tree

docs/guards.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,31 @@ unless
4242
* Single condition: `unless="condition"`
4343
* Multiple conditions: `unless=["condition1", "condition2"]`
4444

45+
Conditions also support [Boolean algebra](https://en.wikipedia.org/wiki/Boolean_algebra) expressions, allowing you to use compound logic within transition guards. You can use both standard Python logical operators (`not`, `and`, `or`) as well as classic Boolean algebra symbols:
46+
47+
- `!` for `not`
48+
- `^` for `and`
49+
- `v` for `or`
50+
51+
For example:
52+
53+
```python
54+
start.to(end, cond="frodo_has_ring and gandalf_present or !sauron_alive")
55+
```
56+
57+
Both formats can be used interchangeably, so `!sauron_alive` and `not sauron_alive` are equivalent.
58+
59+
4560
```{seealso}
4661
See {ref}`sphx_glr_auto_examples_air_conditioner_machine.py` for an example of
4762
combining multiple transitions to the same event.
4863
```
4964

65+
```{seealso}
66+
See {ref}`sphx_glr_auto_examples_lor_machine.py` for an example of
67+
using boolean algebra in conditions.
68+
```
69+
5070
```{hint}
5171
In Python, a boolean value is either `True` or `False`. However, there are also specific values that
5272
are considered "**falsy**" and will evaluate as `False` when used in a boolean context.

statemachine/callbacks.py

Lines changed: 62 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,30 @@
55
from enum import IntEnum
66
from enum import IntFlag
77
from enum import auto
8+
from functools import partial
9+
from functools import reduce
810
from inspect import isawaitable
911
from inspect import iscoroutinefunction
12+
from typing import TYPE_CHECKING
1013
from typing import Callable
1114
from typing import Dict
1215
from typing import Generator
1316
from typing import Iterable
1417
from typing import List
18+
from typing import Set
1519
from typing import Type
1620

1721
from .exceptions import AttrNotFound
22+
from .exceptions import InvalidDefinition
1823
from .i18n import _
24+
from .spec_parser import custom_and
25+
from .spec_parser import operator_mapping
26+
from .spec_parser import parse_boolean_expr
1927
from .utils import ensure_iterable
2028

29+
if TYPE_CHECKING:
30+
from statemachine.dispatcher import Listeners
31+
2132

2233
class CallbackPriority(IntEnum):
2334
GENERIC = 0
@@ -54,6 +65,17 @@ def allways_true(*args, **kwargs):
5465
return True
5566

5667

68+
def take_callback(name: str, resolver: "Listeners", not_found_handler: Callable) -> Callable:
69+
callbacks = list(resolver.search_name(name))
70+
if len(callbacks) == 0:
71+
not_found_handler(name)
72+
return allways_true
73+
elif len(callbacks) == 1:
74+
return callbacks[0]
75+
else:
76+
return reduce(custom_and, callbacks)
77+
78+
5779
class CallbackSpec:
5880
"""Specs about callbacks.
5981
@@ -110,22 +132,46 @@ def _update_func(self, func: Callable, attr_name: str):
110132
self.reference = SpecReference.CALLABLE
111133
self.attr_name = attr_name
112134

113-
def build(self, resolver) -> Generator["CallbackWrapper", None, None]:
135+
def _wrap(self, callback):
136+
condition = self.cond if self.cond is not None else allways_true
137+
return CallbackWrapper(
138+
callback=callback,
139+
condition=condition,
140+
meta=self,
141+
unique_key=callback.unique_key,
142+
)
143+
144+
def build(self, resolver: "Listeners") -> Generator["CallbackWrapper", None, None]:
114145
"""
115146
Resolves the `func` into a usable callable.
116147
117148
Args:
118149
resolver (callable): A method responsible to build and return a valid callable that
119150
can receive arbitrary parameters like `*args, **kwargs`.
120151
"""
121-
for callback in resolver.search(self):
122-
condition = self.cond if self.cond is not None else allways_true
123-
yield CallbackWrapper(
124-
callback=callback,
125-
condition=condition,
126-
meta=self,
127-
unique_key=callback.unique_key,
152+
if (
153+
not self.is_convention
154+
and self.group == CallbackGroup.COND
155+
and self.reference == SpecReference.NAME
156+
):
157+
names_not_found: Set[str] = set()
158+
take_callback_partial = partial(
159+
take_callback, resolver=resolver, not_found_handler=names_not_found.add
128160
)
161+
try:
162+
expression = parse_boolean_expr(self.func, take_callback_partial, operator_mapping)
163+
except SyntaxError as err:
164+
raise InvalidDefinition(
165+
_("Failed to parse boolean expression '{}'").format(self.func)
166+
) from err
167+
if not expression or names_not_found:
168+
self.names_not_found = names_not_found
169+
return
170+
yield self._wrap(expression)
171+
return
172+
173+
for callback in resolver.search(self):
174+
yield self._wrap(callback)
129175

130176

131177
class SpecListGrouper:
@@ -292,15 +338,15 @@ def __repr__(self):
292338
def __str__(self):
293339
return ", ".join(str(c) for c in self)
294340

295-
def _add(self, spec: CallbackSpec, resolver: Callable):
341+
def _add(self, spec: CallbackSpec, resolver: "Listeners"):
296342
for callback in spec.build(resolver):
297343
if callback.unique_key in self.items_already_seen:
298344
continue
299345

300346
self.items_already_seen.add(callback.unique_key)
301347
insort(self.items, callback)
302348

303-
def add(self, items: Iterable[CallbackSpec], resolver: Callable):
349+
def add(self, items: Iterable[CallbackSpec], resolver: "Listeners"):
304350
"""Validate configurations"""
305351
for item in items:
306352
self._add(item, resolver)
@@ -356,6 +402,12 @@ def check(self, specs: CallbackSpecList):
356402
callback for callback in self[meta.group.build_key(specs)] if callback.meta == meta
357403
):
358404
continue
405+
if hasattr(meta, "names_not_found"):
406+
raise AttrNotFound(
407+
_("Did not found name '{}' from model or statemachine").format(
408+
", ".join(meta.names_not_found)
409+
),
410+
)
359411
raise AttrNotFound(
360412
_("Did not found name '{}' from model or statemachine").format(meta.func)
361413
)

statemachine/dispatcher.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def resolve(
7575

7676
def search(self, spec: "CallbackSpec") -> Generator["Callable", None, None]:
7777
if spec.reference is SpecReference.NAME:
78-
yield from self._search_name(spec.func)
78+
yield from self.search_name(spec.func)
7979
return
8080
elif spec.reference is SpecReference.CALLABLE:
8181
yield self._search_callable(spec)
@@ -111,7 +111,7 @@ def _search_callable(self, spec) -> "Callable":
111111

112112
return callable_method(spec.attr_name, spec.func, None)
113113

114-
def _search_name(self, name) -> Generator["Callable", None, None]:
114+
def search_name(self, name) -> Generator["Callable", None, None]:
115115
for config in self.items:
116116
if name not in config.all_attrs:
117117
continue
@@ -143,6 +143,7 @@ def method(*args, **kwargs):
143143
return getter(obj)
144144

145145
method.unique_key = f"{attribute}@{resolver_id}" # type: ignore[attr-defined]
146+
method.__name__ = attribute
146147
return method
147148

148149

statemachine/spec_parser.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import ast
2+
import re
3+
from typing import Callable
4+
5+
replacements = {"!": "not ", "^": " and ", "v": " or "}
6+
7+
pattern = re.compile(r"\!|\^|\bv\b")
8+
9+
10+
def replace_operators(expr: str) -> str:
11+
# preprocess the expression adding support for classical logical operators
12+
def match_func(match):
13+
return replacements[match.group(0)]
14+
15+
return pattern.sub(match_func, expr)
16+
17+
18+
def custom_not(predicate: Callable) -> Callable:
19+
def decorated(*args, **kwargs) -> bool:
20+
return not predicate(*args, **kwargs)
21+
22+
decorated.__name__ = f"not({predicate.__name__})"
23+
unique_key = getattr(predicate, "unique_key", "")
24+
decorated.unique_key = f"not({unique_key})" # type: ignore[attr-defined]
25+
return decorated
26+
27+
28+
def _unique_key(left, right, operator) -> str:
29+
left_key = getattr(left, "unique_key", "")
30+
right_key = getattr(right, "unique_key", "")
31+
return f"{left_key} {operator} {right_key}"
32+
33+
34+
def custom_and(left: Callable, right: Callable) -> Callable:
35+
def decorated(*args, **kwargs) -> bool:
36+
return left(*args, **kwargs) and right(*args, **kwargs) # type: ignore[no-any-return]
37+
38+
decorated.__name__ = f"({left.__name__} and {right.__name__})"
39+
decorated.unique_key = _unique_key(left, right, "and") # type: ignore[attr-defined]
40+
return decorated
41+
42+
43+
def custom_or(left: Callable, right: Callable) -> Callable:
44+
def decorated(*args, **kwargs) -> bool:
45+
return left(*args, **kwargs) or right(*args, **kwargs) # type: ignore[no-any-return]
46+
47+
decorated.__name__ = f"({left.__name__} or {right.__name__})"
48+
decorated.unique_key = _unique_key(left, right, "or") # type: ignore[attr-defined]
49+
return decorated
50+
51+
52+
def build_expression(node, variable_hook, operator_mapping):
53+
if isinstance(node, ast.BoolOp):
54+
# Handle `and` / `or` operations
55+
operator_fn = operator_mapping[type(node.op)]
56+
left_expr = build_expression(node.values[0], variable_hook, operator_mapping)
57+
for right in node.values[1:]:
58+
right_expr = build_expression(right, variable_hook, operator_mapping)
59+
left_expr = operator_fn(left_expr, right_expr)
60+
return left_expr
61+
elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not):
62+
# Handle `not` operation
63+
operand_expr = build_expression(node.operand, variable_hook, operator_mapping)
64+
return operator_mapping[type(node.op)](operand_expr)
65+
elif isinstance(node, ast.Name):
66+
# Handle variables by calling the variable_hook
67+
return variable_hook(node.id)
68+
else:
69+
raise ValueError(f"Unsupported expression structure: {node.__class__.__name__}")
70+
71+
72+
def parse_boolean_expr(expr, variable_hook, operator_mapping):
73+
"""Parses the expression into an AST and build a custom expression tree"""
74+
expr = replace_operators(expr)
75+
tree = ast.parse(expr, mode="eval")
76+
return build_expression(tree.body, variable_hook, operator_mapping)
77+
78+
79+
operator_mapping = {ast.Or: custom_or, ast.And: custom_and, ast.Not: custom_not}

tests/examples/lor_machine.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
"""
2+
Lord of the Rings Quest - Boolean algebra
3+
=========================================
4+
5+
Example that demonstrates the use of Boolean algebra in conditions.
6+
7+
"""
8+
9+
from statemachine import State
10+
from statemachine import StateMachine
11+
from statemachine.exceptions import TransitionNotAllowed
12+
13+
14+
class LordOfTheRingsQuestStateMachine(StateMachine):
15+
# Define the states
16+
shire = State("In the Shire", initial=True)
17+
bree = State("In Bree")
18+
rivendell = State("At Rivendell")
19+
moria = State("In Moria")
20+
lothlorien = State("In Lothlorien")
21+
mordor = State("In Mordor")
22+
mount_doom = State("At Mount Doom", final=True)
23+
24+
# Define transitions with Boolean conditions
25+
start_journey = shire.to(bree, cond="frodo_has_ring and !sauron_alive")
26+
meet_elves = bree.to(rivendell, cond="gandalf_present and frodo_has_ring")
27+
enter_moria = rivendell.to(moria, cond="orc_army_nearby or frodo_has_ring")
28+
reach_lothlorien = moria.to(lothlorien, cond="!orc_army_nearby")
29+
journey_to_mordor = lothlorien.to(mordor, cond="frodo_has_ring and sam_is_loyal")
30+
destroy_ring = mordor.to(mount_doom, cond="frodo_has_ring and frodo_resists_ring")
31+
32+
# Conditions (attributes representing the state of conditions)
33+
frodo_has_ring: bool = True
34+
sauron_alive: bool = True # Initially, Sauron is alive
35+
gandalf_present: bool = False # Gandalf is not present at the start
36+
orc_army_nearby: bool = False
37+
sam_is_loyal: bool = True
38+
frodo_resists_ring: bool = False # Initially, Frodo is not resisting the ring
39+
40+
41+
# %%
42+
# Playing
43+
44+
quest = LordOfTheRingsQuestStateMachine()
45+
46+
# Track state changes
47+
print(f"Current State: {quest.current_state.id}") # Should start at "shire"
48+
49+
# Step 1: Start the journey
50+
quest.sauron_alive = False # Assume Sauron is no longer alive
51+
try:
52+
quest.start_journey()
53+
print(f"Current State: {quest.current_state.id}") # Should be "bree"
54+
except TransitionNotAllowed:
55+
print("Unable to start journey: conditions not met.")
56+
57+
# Step 2: Meet the elves in Rivendell
58+
quest.gandalf_present = True # Gandalf is now present
59+
try:
60+
quest.meet_elves()
61+
print(f"Current State: {quest.current_state.id}") # Should be "rivendell"
62+
except TransitionNotAllowed:
63+
print("Unable to meet elves: conditions not met.")
64+
65+
# Step 3: Enter Moria
66+
quest.orc_army_nearby = True # Orc army is nearby
67+
try:
68+
quest.enter_moria()
69+
print(f"Current State: {quest.current_state.id}") # Should be "moria"
70+
except TransitionNotAllowed:
71+
print("Unable to enter Moria: conditions not met.")
72+
73+
# Step 4: Reach Lothlorien
74+
quest.orc_army_nearby = False # Orcs are no longer nearby
75+
try:
76+
quest.reach_lothlorien()
77+
print(f"Current State: {quest.current_state.id}") # Should be "lothlorien"
78+
except TransitionNotAllowed:
79+
print("Unable to reach Lothlorien: conditions not met.")
80+
81+
# Step 5: Journey to Mordor
82+
try:
83+
quest.journey_to_mordor()
84+
print(f"Current State: {quest.current_state.id}") # Should be "mordor"
85+
except TransitionNotAllowed:
86+
print("Unable to journey to Mordor: conditions not met.")
87+
88+
# Step 6: Fight with Smeagol
89+
try:
90+
quest.destroy_ring()
91+
print(f"Current State: {quest.current_state.id}") # Should be "mount_doom"
92+
except TransitionNotAllowed:
93+
print("Unable to destroy the ring: conditions not met.")
94+
95+
96+
# Step 7: Destroy the ring at Mount Doom
97+
quest.frodo_resists_ring = True # Frodo is now resisting the ring
98+
try:
99+
quest.destroy_ring()
100+
print(f"Current State: {quest.current_state.id}") # Should be "mount_doom"
101+
except TransitionNotAllowed:
102+
print("Unable to destroy the ring: conditions not met.")

0 commit comments

Comments
 (0)