Skip to content

Commit a214e04

Browse files
authored
feat: add pyright support, Generic[TModel], remove __getattr__ catch-all (#566)
* feat: add pyright support, Generic[TModel] for typed models, remove __getattr__ catch-all - Add pyright as dev dependency + pre-commit hook, configured with basic type checking targeting Python 3.9 - Make StateChart generic over TModel so `sm.model` gets proper type inference and IDE autocompletion when a model class is provided - Remove TYPE_CHECKING `__getattr__` stubs from both StateChart and StateMachineMetaclass — type checkers now detect misspelled attributes and unresolved references on subclasses - Add explicit type declarations for metaclass-set attributes (name, id, states, states_map, initial_state, final_states) with docstrings - Add return type annotations throughout (send, raise_, enabled_events, activate_initial_state, Event.__call__, run_async_from_sync, etc.) - Fix all 38 baseline pyright errors with proper types, assertions for weakref derefs, and targeted type: ignore for genuinely dynamic APIs (pydot, partial attributes, conditional function definitions) - Update tests/examples to use configuration_values for enum-based state checks instead of relying on dynamic attribute access - Document typed models in docs/models.md and docs/releases/3.0.0.md Closes #515
1 parent c6445c8 commit a214e04

23 files changed

Lines changed: 229 additions & 80 deletions

.pre-commit-config.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ repos:
2525
types: [python]
2626
language: system
2727
pass_filenames: false
28+
- id: pyright
29+
name: Pyright
30+
entry: uv run pyright statemachine/
31+
types: [python]
32+
language: system
33+
pass_filenames: false
2834
- id: pytest
2935
name: Pytest
3036
entry: uv run pytest -n auto

docs/models.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,45 @@ provided to the built-in {ref}`StateChart`, such as implementing all {ref}`actio
2525
{ref}`guards` on your domain model and keeping only the definition of {ref}`states` and
2626
{ref}`transitions` on the {ref}`StateChart`.
2727
```
28+
29+
## Typed models
30+
31+
`StateChart` supports a generic type parameter so that type checkers (mypy, pyright) and IDEs
32+
can infer the type of `sm.model` and provide code completion.
33+
34+
Declare your model class and pass it as a type parameter to `StateChart`:
35+
36+
```python
37+
>>> from statemachine import State, StateChart
38+
39+
>>> class OrderModel:
40+
... order_id: str = ""
41+
... total: float = 0.0
42+
... def confirm(self):
43+
... return f"Order {self.order_id} confirmed: ${self.total}"
44+
45+
>>> class OrderWorkflow(StateChart["OrderModel"]):
46+
... draft = State(initial=True)
47+
... confirmed = State(final=True)
48+
... confirm = draft.to(confirmed, on="on_confirm")
49+
... def on_confirm(self):
50+
... return self.model.confirm()
51+
52+
>>> model = OrderModel()
53+
>>> model.order_id = "A-123"
54+
>>> model.total = 49.90
55+
>>> sm = OrderWorkflow(model=model)
56+
57+
>>> sm.send("confirm")
58+
'Order A-123 confirmed: $49.9'
59+
60+
```
61+
62+
With this declaration, `sm.model` is typed as `OrderModel` instead of `Any`, so
63+
`sm.model.order_id`, `sm.model.total`, and `sm.model.confirm()` all get full
64+
autocompletion and type checking in your IDE.
65+
66+
```{note}
67+
When no type parameter is given (e.g. `class MySM(StateChart)`), the model defaults
68+
to `Any`, preserving full backward compatibility.
69+
```

docs/releases/3.0.0.md

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,41 @@ flag `validate_disconnected_states: bool = True` that can be used to disable thi
346346
It's already disabled when parsing SCXML files.
347347

348348

349+
### Typed models with `Generic[TModel]`
350+
351+
`StateChart` now supports a generic type parameter for the model, enabling full type
352+
inference and IDE autocompletion on `sm.model`:
353+
354+
```py
355+
>>> from statemachine import State, StateChart
356+
357+
>>> class MyModel:
358+
... name: str = ""
359+
... value: int = 0
360+
361+
>>> class MySM(StateChart["MyModel"]):
362+
... idle = State(initial=True)
363+
... active = State(final=True)
364+
... go = idle.to(active)
365+
366+
>>> sm = MySM(model=MyModel())
367+
>>> sm.model.name
368+
''
369+
370+
```
371+
372+
With this declaration, type checkers infer `sm.model` as `MyModel` (not `Any`), so
373+
accessing `sm.model.name` or `sm.model.value` gets full autocompletion and type safety.
374+
When no type parameter is given, `StateChart` defaults to `StateChart[Any]` for backward
375+
compatibility. See {ref}`domain models` for details.
376+
377+
### Improved type checking with pyright
378+
379+
The library now supports [pyright](https://github.com/microsoft/pyright) in addition to mypy.
380+
Type annotations have been improved throughout the codebase, and a catch-all `__getattr__`
381+
that previously returned `Any` has been removed — type checkers can now detect misspelled
382+
attribute names and unresolved references on `StateChart` subclasses.
383+
349384
### Weighted (probabilistic) transitions
350385

351386
A new contrib module `statemachine.contrib.weighted` provides `weighted_transitions()`,

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ dev = [
5656
"babel >=2.16.0; python_version >='3.8'",
5757
"pytest-xdist>=3.6.1",
5858
"pytest-timeout>=2.3.1",
59+
"pyright>=1.1.400",
5960
]
6061

6162
[build-system]
@@ -201,3 +202,6 @@ fixture-parentheses = true
201202
mark-parentheses = true
202203

203204
[tool.pyright]
205+
pythonVersion = "3.9"
206+
typeCheckingMode = "basic"
207+
include = ["statemachine"]

statemachine/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
from .state import State
44
from .statemachine import StateChart
55
from .statemachine import StateMachine
6+
from .statemachine import TModel
67

78
__author__ = """Fernando Macedo"""
89
__email__ = "fgmacedo@gmail.com"
910
__version__ = "2.6.0"
1011

11-
__all__ = ["StateChart", "StateMachine", "State", "HistoryState", "Event"]
12+
__all__ = ["StateChart", "StateMachine", "State", "HistoryState", "Event", "TModel"]

statemachine/callbacks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ def __init__(
9696
name = func.func.__name__ if is_partial else func.__name__
9797
self.attr_name = name if not self.is_event or self.is_bounded else f"_{name}_"
9898
if not self.is_bounded:
99-
func.attr_name = self.attr_name
100-
func.is_event = is_event
99+
func.attr_name = self.attr_name # type: ignore[union-attr]
100+
func.is_event = is_event # type: ignore[union-attr]
101101
else:
102102
self.reference = SpecReference.NAME
103103
self.attr_name = func
@@ -270,7 +270,7 @@ class CallbacksExecutor:
270270
"""A list of callbacks that can be executed in order."""
271271

272272
def __init__(self):
273-
self.items: List[CallbackWrapper] = deque()
273+
self.items: "deque[CallbackWrapper]" = deque()
274274
self.items_already_seen = set()
275275

276276
def __iter__(self):

statemachine/contrib/diagram.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def _initial_node(self, state):
6969
width=0.2,
7070
height=0.2,
7171
)
72-
node.set_fillcolor("black")
72+
node.set_fillcolor("black") # type: ignore[attr-defined]
7373
return node
7474

7575
def _initial_edge(self, initial_node, state):
@@ -89,7 +89,7 @@ def _initial_edge(self, initial_node, state):
8989
def _actions_getter(self):
9090
if isinstance(self.machine, StateChart):
9191

92-
def getter(grouper):
92+
def getter(grouper): # pyright: ignore[reportRedeclaration]
9393
return self.machine._callbacks.str(grouper.key)
9494
else:
9595

@@ -162,10 +162,10 @@ def _state_as_node(self, state):
162162
isinstance(self.machine, StateChart)
163163
and state.value in self.machine.configuration_values
164164
):
165-
node.set_penwidth(self.state_active_penwidth)
166-
node.set_fillcolor(self.state_active_fillcolor)
165+
node.set_penwidth(self.state_active_penwidth) # type: ignore[attr-defined]
166+
node.set_fillcolor(self.state_active_fillcolor) # type: ignore[attr-defined]
167167
else:
168-
node.set_fillcolor("white")
168+
node.set_fillcolor("white") # type: ignore[attr-defined]
169169
return node
170170

171171
def _transition_as_edges(self, transition):

statemachine/dispatcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def callable_method(a_callable) -> Callable:
194194

195195
if sig.is_coroutine:
196196

197-
async def signature_adapter(*args: Any, **kwargs: Any) -> Any:
197+
async def signature_adapter(*args: Any, **kwargs: Any) -> Any: # pyright: ignore[reportRedeclaration]
198198
ba = sig_bind_expected(*args, **kwargs)
199199
return await a_callable(*ba.args, **ba.kwargs)
200200
else:

statemachine/event.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import TYPE_CHECKING
2+
from typing import Any
23
from typing import List
34
from typing import cast
45
from uuid import uuid4
@@ -11,6 +12,7 @@
1112

1213
if TYPE_CHECKING:
1314
from .statemachine import StateChart
15+
from .transition import Transition
1416
from .transition_list import TransitionList
1517

1618

@@ -57,7 +59,7 @@ class Event(AddCallbacksMixin, str):
5759

5860
def __new__(
5961
cls,
60-
transitions: "str | TransitionList | None" = None,
62+
transitions: "str | Transition | TransitionList | None" = None,
6163
id: "str | None" = None,
6264
name: "str | None" = None,
6365
delay: float = 0,
@@ -82,7 +84,7 @@ def __new__(
8284
else:
8385
instance.name = ""
8486
if transitions:
85-
instance._transitions = transitions
87+
instance._transitions = transitions # type: ignore[assignment]
8688
instance._has_real_id = _has_real_id
8789
instance._sm = _sm
8890
return instance
@@ -144,7 +146,7 @@ def build_trigger(self, *args, machine: "StateChart", send_id: "str | None" = No
144146

145147
return trigger_data
146148

147-
def __call__(self, *args, **kwargs):
149+
def __call__(self, *args, **kwargs) -> Any:
148150
"""Send this event to the current state machine.
149151
150152
Triggering an event on a state machine means invoking or sending a signal, initiating the
@@ -155,7 +157,7 @@ def __call__(self, *args, **kwargs):
155157
# an SM instance. Such SM instance is provided by `__get__` method when
156158
# used as a property descriptor.
157159
self.put(*args, **kwargs)
158-
return self._sm._processing_loop() # type: ignore
160+
return self._sm._processing_loop() # type: ignore[union-attr]
159161

160162
def split( # type: ignore[override]
161163
self, sep: "str | None" = None, maxsplit: int = -1

statemachine/events.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ def add(self, events):
3535
return self
3636

3737
def match(self, event: "str | None"):
38-
if event is None and self.is_empty:
39-
return True
38+
if event is None:
39+
return self.is_empty
4040
return any(e.match(event) for e in self)
4141

4242
def _replace(self, old, new):

0 commit comments

Comments
 (0)