Skip to content

Commit 300cb34

Browse files
authored
fix: widen type annotations for compound/parallel state transitions (#580)
* fix: widen type annotations for compound/parallel state transitions NestedStateFactory.__new__ returns a State instance at runtime, but mypy sees class Foo(State.Compound) as type[Foo]. Accept NestedStateFactory in the union type of _ToState, _FromState, and the metaclass stubs so that compound/parallel states work as transition targets without type: ignore. Also includes: ai_shell example, invoke/error.execution engine support, event id helpers, and related test coverage.
1 parent f60429d commit 300cb34

13 files changed

Lines changed: 793 additions & 34 deletions

AGENTS.md

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,21 +180,28 @@ uv run mypy statemachine/ tests/
180180

181181
- **Formatter/Linter:** ruff (line length 99, target Python 3.9)
182182
- **Rules:** pycodestyle, pyflakes, isort, pyupgrade, flake8-comprehensions, flake8-bugbear, flake8-pytest-style
183-
- **Imports:** single-line, sorted by isort
183+
- **Imports:** single-line, sorted by isort. **Always prefer top-level imports** — only use
184+
lazy (in-function) imports when strictly necessary to break circular dependencies
184185
- **Docstrings:** Google convention
185186
- **Naming:** PascalCase for classes, snake_case for functions/methods, UPPER_SNAKE_CASE for constants
186187
- **Type hints:** used throughout; `TYPE_CHECKING` for circular imports
187188
- Pre-commit hooks enforce ruff + mypy + pytest
188189

189190
## Design principles
190191

191-
- **Follow SOLID principles.** In particular:
192+
- **Use GRASP/SOLID patterns to guide decisions.** When refactoring or designing, explicitly
193+
apply patterns like Information Expert, Single Responsibility, and Law of Demeter to decide
194+
where logic belongs — don't just pick a convenient location.
195+
- **Information Expert (GRASP):** Place logic in the module/class that already has the
196+
knowledge it needs. If a method computes a result, it should signal or return it rather
197+
than forcing another method to recompute the same thing.
192198
- **Law of Demeter:** Methods should depend only on the data they need, not on the
193199
objects that contain it. Pass the specific value (e.g., a `Future`) rather than the
194200
parent object (e.g., `TriggerData`) — this reduces coupling and removes the need for
195201
null-checks on intermediate accessors.
196202
- **Single Responsibility:** Each module, class, and function should have one clear reason
197-
to change.
203+
to change. Functions and types belong in the module that owns their domain (e.g.,
204+
event-name helpers belong in `event.py`, not in `factory.py`).
198205
- **Interface Segregation:** Depend on narrow interfaces. If a helper only needs one field
199206
from a dataclass, accept that field directly.
200207
- **Decouple infrastructure from domain:** Modules like `signature.py` and `dispatcher.py` are

statemachine/engines/async_.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,7 @@ async def processing_loop( # noqa: C901
402402

403403
# Spawn invoke handlers for states entered during this macrostep.
404404
await self._invoke_manager.spawn_pending_async()
405+
self._check_root_final_state()
405406

406407
# Phase 2: remaining internal events
407408
while not self.internal_queue.is_empty(): # pragma: no cover

statemachine/engines/base.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def __init__(self, sm: "StateChart"):
9999
self._macrostep_count: int = 0
100100
self._microstep_count: int = 0
101101
self._log_id = f"[{type(sm).__name__}]"
102+
self._root_parallel_final_pending: "State | None" = None
102103

103104
def empty(self): # pragma: no cover
104105
return self.external_queue.is_empty()
@@ -614,6 +615,8 @@ def _handle_final_state(self, target: State, on_entry_result: list):
614615
BoundEvent(f"done.state.{grandparent.id}", _sm=self.sm, internal=True).put(
615616
*donedata_args, **donedata_kwargs
616617
)
618+
if grandparent.parent is None:
619+
self._root_parallel_final_pending = grandparent
617620

618621
def _enter_states( # noqa: C901
619622
self,
@@ -908,6 +911,29 @@ def add_ancestor_states_to_enter(
908911
default_history_content,
909912
)
910913

914+
def _check_root_final_state(self):
915+
"""SCXML spec: terminate when the root configuration is final.
916+
917+
For top-level parallel states, the machine terminates when all child
918+
regions have reached their final states — equivalent to the SCXML
919+
algorithm's ``isInFinalState(scxml_element)`` check.
920+
921+
Uses a flag set by ``_handle_final_state`` (Information Expert) to
922+
avoid re-scanning top-level states on every macrostep. The flag is
923+
deferred because ``done.state`` events queued by ``_handle_final_state``
924+
may trigger transitions that exit the parallel, so we verify the
925+
parallel is still in the configuration before terminating.
926+
"""
927+
state = self._root_parallel_final_pending
928+
if state is None:
929+
return
930+
self._root_parallel_final_pending = None
931+
# A done.state transition may have exited the parallel; verify it's
932+
# still in the configuration before terminating.
933+
if state in self.sm.configuration and self.is_in_final_state(state):
934+
self._invoke_manager.cancel_all()
935+
self.running = False
936+
911937
def is_in_final_state(self, state: State) -> bool:
912938
if state.is_compound:
913939
return any(s.final and s in self.sm.configuration for s in state.states)

statemachine/engines/sync.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def processing_loop(self, caller_future=None): # noqa: C901
118118

119119
# Spawn invoke handlers for states entered during this macrostep.
120120
self._invoke_manager.spawn_pending_sync()
121+
self._check_root_final_state()
121122

122123
# Process remaining internal events before external events.
123124
# Note: the macrostep loop above already drains the internal queue,

statemachine/event.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,24 @@
1616
from .transition_list import TransitionList
1717

1818

19+
def _expand_event_id(key: str) -> str:
20+
"""Apply naming conventions for special event prefixes.
21+
22+
Converts underscore-based Python attribute names to their dot-separated
23+
event equivalents. Returns a space-separated string so ``Events.add()``
24+
registers both forms.
25+
"""
26+
if key.startswith("done_invoke_"):
27+
suffix = key[len("done_invoke_") :]
28+
return f"{key} done.invoke.{suffix}"
29+
if key.startswith("done_state_"):
30+
suffix = key[len("done_state_") :]
31+
return f"{key} done.state.{suffix}"
32+
if key.startswith("error_"):
33+
return f"{key} {key.replace('_', '.')}"
34+
return key
35+
36+
1937
_event_data_kwargs = {
2038
"event_data",
2139
"machine",

statemachine/factory.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .callbacks import CallbackPriority
1010
from .callbacks import CallbackSpecList
1111
from .event import Event
12+
from .event import _expand_event_id
1213
from .exceptions import InvalidDefinition
1314
from .graph import disconnected_states
1415
from .graph import iterate_states
@@ -271,29 +272,13 @@ def add_from_attributes(cls, attrs): # noqa: C901
271272
if isinstance(value, State):
272273
cls.add_state(key, value)
273274
elif isinstance(value, (Transition, TransitionList)):
274-
event_id = key
275-
if key.startswith("error_"):
276-
event_id = f"{key} {key.replace('_', '.')}"
277-
elif key.startswith("done_invoke_"):
278-
suffix = key[len("done_invoke_") :]
279-
event_id = f"{key} done.invoke.{suffix}"
280-
elif key.startswith("done_state_"):
281-
suffix = key[len("done_state_") :]
282-
event_id = f"{key} done.state.{suffix}"
275+
event_id = _expand_event_id(key)
283276
cls.add_event(event=Event(transitions=value, id=event_id, name=key))
284277
elif isinstance(value, (Event,)):
285278
if value._has_real_id:
286279
event_id = value.id
287-
elif key.startswith("error_"):
288-
event_id = f"{key} {key.replace('_', '.')}"
289-
elif key.startswith("done_invoke_"):
290-
suffix = key[len("done_invoke_") :]
291-
event_id = f"{key} done.invoke.{suffix}"
292-
elif key.startswith("done_state_"):
293-
suffix = key[len("done_state_") :]
294-
event_id = f"{key} done.state.{suffix}"
295280
else:
296-
event_id = key
281+
event_id = _expand_event_id(key)
297282
new_event = Event(
298283
transitions=value._transitions,
299284
id=event_id,

statemachine/state.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
from typing import Dict
55
from typing import Generator
66
from typing import List
7+
from typing import cast
78
from weakref import ref
89

910
from .callbacks import CallbackGroup
1011
from .callbacks import CallbackPriority
1112
from .callbacks import CallbackSpecList
13+
from .event import _expand_event_id
1214
from .exceptions import InvalidDefinition
1315
from .exceptions import StateMachineError
1416
from .i18n import _
@@ -32,8 +34,10 @@ def __call__(self, *states: "State", **kwargs):
3234

3335

3436
class _ToState(_TransitionBuilder):
35-
def __call__(self, *states: "State | None", **kwargs):
36-
transitions = TransitionList(Transition(self._state, state, **kwargs) for state in states)
37+
def __call__(self, *states: "State | NestedStateFactory | None", **kwargs):
38+
transitions = TransitionList(
39+
Transition(self._state, cast("State | None", state), **kwargs) for state in states
40+
)
3741
self._state.transitions.add_transitions(transitions)
3842
return transitions
3943

@@ -43,11 +47,12 @@ def any(self, **kwargs):
4347
"""Create transitions from all non-final states (reversed)."""
4448
return self.__call__(AnyState(), **kwargs)
4549

46-
def __call__(self, *states: "State", **kwargs):
50+
def __call__(self, *states: "State | NestedStateFactory", **kwargs):
4751
transitions = TransitionList()
4852
for origin in states:
49-
transition = Transition(origin, self._state, **kwargs)
50-
origin.transitions.add_transitions(transition)
53+
state = cast(State, origin)
54+
transition = Transition(state, self._state, **kwargs)
55+
state.transitions.add_transitions(transition)
5156
transitions.add_transitions(transition)
5257
return transitions
5358

@@ -78,7 +83,7 @@ def __new__( # type: ignore [misc]
7883
value._set_id(key)
7984
states.append(value)
8085
elif isinstance(value, TransitionList):
81-
value.add_event(key)
86+
value.add_event(_expand_event_id(key))
8287
elif callable(value):
8388
callbacks[key] = value
8489

@@ -87,15 +92,17 @@ def __new__( # type: ignore [misc]
8792
)
8893

8994
@classmethod
90-
def to(cls, *args: "State", **kwargs) -> "_ToState": # pragma: no cover
95+
def to(cls, *args: "State | NestedStateFactory", **kwargs) -> "_ToState": # pragma: no cover
9196
"""Create transitions to the given target states.
9297
.. note: This method is only a type hint for mypy.
9398
The actual implementation belongs to the :ref:`State` class.
9499
"""
95100
return _ToState(State())
96101

97102
@classmethod
98-
def from_(cls, *args: "State", **kwargs) -> "_FromState": # pragma: no cover
103+
def from_( # pragma: no cover
104+
cls, *args: "State | NestedStateFactory", **kwargs
105+
) -> "_FromState":
99106
"""Create transitions from the given target states (reversed).
100107
.. note: This method is only a type hint for mypy.
101108
The actual implementation belongs to the :ref:`State` class.

0 commit comments

Comments
 (0)