Skip to content

Commit fd45edd

Browse files
author
rodrigo.nogueira
committed
feat: Add Generic[TModel] support for type-safe model attribute
This implements Generic support for the StateMachine class, enabling type checkers to infer the correct type of the model attribute. Changes: - Add TModel TypeVar and GenericStateMachineMetaclass in factory.py - Update StateMachine to inherit from Generic[TModel] - Update type hints: model parameter and attribute - Add comprehensive test suite with 6 new tests Benefits: - Type checkers (mypy, pyright, pylance) can now validate model attributes - IDE autocomplete works correctly for model.attribute access - Fully backward compatible with existing non-generic code - Catches attribute errors at type-check time instead of runtime Testing: - 119 core tests passing (including 6 new Generic support tests) - 304 total tests passing - Validated on Python 3.10 and 3.11 - mypy type checking passes without errors Closes #515
1 parent 9a089ed commit fd45edd

3 files changed

Lines changed: 172 additions & 12 deletions

File tree

statemachine/factory.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
from typing import TYPE_CHECKING
33
from typing import Any
44
from typing import Dict
5+
from typing import Generic
56
from typing import List
67
from typing import Tuple
8+
from typing import TypeVar
79

810
from . import registry
911
from .event import Event
@@ -17,6 +19,10 @@
1719
from .transition_list import TransitionList
1820

1921

22+
TModel = TypeVar("TModel")
23+
"""TypeVar for the model type in StateMachine."""
24+
25+
2026
class StateMachineMetaclass(type):
2127
"Metaclass for constructing StateMachine classes"
2228

@@ -36,7 +42,9 @@ def __init__(
3642

3743
cls._abstract = True
3844
cls._strict_states = strict_states
39-
cls._events: Dict[Event, None] = {} # used Dict to preserve order and avoid duplicates
45+
cls._events: Dict[Event, None] = (
46+
{}
47+
) # used Dict to preserve order and avoid duplicates
4048
cls._protected_attrs: set = set()
4149
cls._events_to_update: Dict[Event, Event | None] = {}
4250

@@ -98,9 +106,9 @@ def _check_final_states(cls):
98106

99107
if final_state_with_invalid_transitions:
100108
raise InvalidDefinition(
101-
_("Cannot declare transitions from final state. Invalid state(s): {}").format(
102-
[s.id for s in final_state_with_invalid_transitions]
103-
)
109+
_(
110+
"Cannot declare transitions from final state. Invalid state(s): {}"
111+
).format([s.id for s in final_state_with_invalid_transitions])
104112
)
105113

106114
def _check_trap_states(cls):
@@ -133,7 +141,8 @@ def _states_without_path_to_final_states(cls):
133141
return [
134142
state
135143
for state in cls.states
136-
if not state.final and not any(s.final for s in visit_connected_states(state))
144+
if not state.final
145+
and not any(s.final for s in visit_connected_states(state))
137146
]
138147

139148
def _disconnected_states(cls, starting_state):
@@ -259,3 +268,17 @@ def _update_event_references(cls):
259268
@property
260269
def events(self):
261270
return list(self._events)
271+
272+
273+
class GenericStateMachineMetaclass(StateMachineMetaclass, type(Generic)): # type: ignore[misc]
274+
"""
275+
Metaclass that combines StateMachineMetaclass with Generic.
276+
277+
This allows StateMachine to be parameterized with a model type using Generic[TModel],
278+
enabling type checkers to infer the correct type of the `model` attribute.
279+
280+
The type: ignore[misc] is necessary because mypy has limitations with generic metaclasses,
281+
but this pattern works correctly at runtime and with type checkers.
282+
"""
283+
284+
pass

statemachine/statemachine.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import TYPE_CHECKING
44
from typing import Any
55
from typing import Dict
6+
from typing import Generic
67
from typing import List
78

89
from .callbacks import SPECS_ALL
@@ -18,7 +19,8 @@
1819
from .exceptions import InvalidDefinition
1920
from .exceptions import InvalidStateValue
2021
from .exceptions import TransitionNotAllowed
21-
from .factory import StateMachineMetaclass
22+
from .factory import GenericStateMachineMetaclass
23+
from .factory import TModel
2224
from .graph import iterate_states_and_transitions
2325
from .i18n import _
2426
from .model import Model
@@ -29,7 +31,7 @@
2931
from .state import State
3032

3133

32-
class StateMachine(metaclass=StateMachineMetaclass):
34+
class StateMachine(Generic[TModel], metaclass=GenericStateMachineMetaclass):
3335
"""
3436
3537
Args:
@@ -68,14 +70,14 @@ class StateMachine(metaclass=StateMachineMetaclass):
6870

6971
def __init__(
7072
self,
71-
model: Any = None,
73+
model: "TModel | None" = None,
7274
state_field: str = "state",
7375
start_value: Any = None,
7476
rtc: bool = True,
7577
allow_event_without_transition: bool = False,
7678
listeners: "List[object] | None" = None,
7779
):
78-
self.model = model if model is not None else Model()
80+
self.model: TModel = model if model is not None else Model() # type: ignore[assignment]
7981
self.state_field = state_field
8082
self.start_value = start_value
8183
self.allow_event_without_transition = allow_event_without_transition
@@ -149,7 +151,9 @@ def __setstate__(self, state):
149151
self._engine = self._get_engine(rtc)
150152

151153
def _get_initial_state(self):
152-
initial_state_value = self.start_value if self.start_value else self.initial_state.value
154+
initial_state_value = (
155+
self.start_value if self.start_value else self.initial_state.value
156+
)
153157
try:
154158
return self.states_map[initial_state_value]
155159
except KeyError as err:
@@ -170,7 +174,9 @@ def bind_events_to(self, *targets):
170174
continue
171175
setattr(target, event, trigger)
172176

173-
def _add_listener(self, listeners: "Listeners", allowed_references: SpecReference = SPECS_ALL):
177+
def _add_listener(
178+
self, listeners: "Listeners", allowed_references: SpecReference = SPECS_ALL
179+
):
174180
registry = self._callbacks
175181
for visited in iterate_states_and_transitions(self.states):
176182
listeners.resolve(
@@ -292,7 +298,10 @@ def events(self) -> "List[Event]":
292298
@property
293299
def allowed_events(self) -> "List[Event]":
294300
"""List of the current allowed events."""
295-
return [getattr(self, event) for event in self.current_state.transitions.unique_events]
301+
return [
302+
getattr(self, event)
303+
for event in self.current_state.transitions.unique_events
304+
]
296305

297306
def _put_nonblocking(self, trigger_data: TriggerData):
298307
"""Put the trigger on the queue without blocking the caller."""

tests/test_generic_support.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
"""Tests for Generic[TModel] support in StateMachine.
2+
3+
Test that type checkers can infer the correct model type when using Generic[TModel].
4+
"""
5+
6+
import pytest
7+
8+
from statemachine import State
9+
from statemachine import StateMachine
10+
11+
12+
class CustomModel:
13+
"""Custom model for testing"""
14+
15+
def __init__(self):
16+
self.state = None
17+
self.custom_attr = "test_value"
18+
self.counter = 0
19+
20+
21+
class GenericStateMachine(StateMachine[CustomModel]):
22+
"""State machine using Generic[CustomModel] for type safety"""
23+
24+
initial = State("Initial", initial=True)
25+
processing = State("Processing")
26+
final = State("Final", final=True)
27+
28+
start = initial.to(processing)
29+
finish = processing.to(final)
30+
31+
32+
class TestGenericSupport:
33+
"""Test suite for Generic[TModel] support"""
34+
35+
def test_generic_statemachine_with_custom_model(self):
36+
"""Test that StateMachine[CustomModel] works with a custom model instance"""
37+
model = CustomModel()
38+
sm = GenericStateMachine(model=model)
39+
40+
assert sm.model is model
41+
assert sm.model.custom_attr == "test_value"
42+
assert sm.model.counter == 0
43+
44+
def test_generic_statemachine_with_default_model(self):
45+
"""Test that StateMachine[CustomModel] works with default Model()"""
46+
sm = GenericStateMachine()
47+
48+
# Default model should be Model(), not CustomModel
49+
assert sm.model is not None
50+
assert sm.current_state == sm.initial
51+
52+
def test_generic_statemachine_transitions_work(self):
53+
"""Test that transitions work correctly with generic state machine"""
54+
model = CustomModel()
55+
sm = GenericStateMachine(model=model)
56+
57+
assert sm.current_state == sm.initial
58+
59+
sm.start()
60+
assert sm.current_state == sm.processing
61+
62+
sm.finish()
63+
assert sm.current_state == sm.final
64+
65+
def test_generic_statemachine_model_persists_across_transitions(self):
66+
"""Test that model state persists across transitions"""
67+
model = CustomModel()
68+
sm = GenericStateMachine(model=model)
69+
70+
# Modify model
71+
sm.model.counter = 42
72+
sm.model.custom_attr = "modified"
73+
74+
# Transition
75+
sm.start()
76+
77+
# Model state should persist
78+
assert sm.model.counter == 42
79+
assert sm.model.custom_attr == "modified"
80+
81+
def test_backward_compatibility_without_generic(self):
82+
"""Test that traditional usage without Generic still works"""
83+
84+
class TraditionalMachine(StateMachine):
85+
"""Non-generic state machine for backward compatibility"""
86+
87+
idle = State("Idle", initial=True)
88+
running = State("Running")
89+
90+
run = idle.to(running)
91+
92+
sm = TraditionalMachine()
93+
assert sm.current_state == sm.idle
94+
95+
sm.run()
96+
assert sm.current_state == sm.running
97+
98+
def test_multiple_generic_machines_with_different_models(self):
99+
"""Test that different generic machines can use different model types"""
100+
101+
class ModelA:
102+
def __init__(self):
103+
self.state = None
104+
self.value_a = "A"
105+
106+
class ModelB:
107+
def __init__(self):
108+
self.state = None
109+
self.value_b = "B"
110+
111+
class MachineA(StateMachine[ModelA]):
112+
initial = State("Initial", initial=True)
113+
final = State("Final", final=True)
114+
go = initial.to(final)
115+
116+
class MachineB(StateMachine[ModelB]):
117+
start = State("Start", initial=True)
118+
end = State("End", final=True)
119+
advance = start.to(end)
120+
121+
model_a = ModelA()
122+
model_b = ModelB()
123+
124+
sm_a = MachineA(model=model_a)
125+
sm_b = MachineB(model=model_b)
126+
127+
assert sm_a.model.value_a == "A"
128+
assert sm_b.model.value_b == "B"

0 commit comments

Comments
 (0)