Skip to content

Commit 4654aba

Browse files
authored
feat: add weighted (probabilistic) transitions (#564)
* feat: add weighted (probabilistic) transitions contrib module Add `weighted_transitions()` utility that enables probabilistic transition selection based on relative weights. Works entirely through the existing `cond` guard system with zero engine changes. API: weighted_transitions(source, (target, weight), ..., seed=N) to(target, weight, cond=..., on=..., ...) # for transition kwargs Inspired by PR #539 (@bcorfman).
1 parent fde13d9 commit 4654aba

9 files changed

Lines changed: 1028 additions & 2 deletions

File tree

AGENTS.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,21 @@ uv run pytest -n auto
103103

104104
Coverage is enabled by default.
105105

106+
### Testing both sync and async engines
107+
108+
Use the `sm_runner` fixture (from `tests/conftest.py`) when you need to test the same
109+
statechart on both sync and async engines. It is parametrized with `["sync", "async"]`
110+
and provides `start()` / `send()` helpers that handle engine selection automatically:
111+
112+
```python
113+
async def test_something(self, sm_runner):
114+
sm = await sm_runner.start(MyStateChart)
115+
await sm_runner.send(sm, "some_event")
116+
assert "expected_state" in sm.configuration_values
117+
```
118+
119+
Do **not** manually add async no-op listeners or duplicate test classes — prefer `sm_runner`.
120+
106121
## Linting and formatting
107122

108123
```bash

docs/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ async
1717
mixins
1818
integrations
1919
diagram
20+
weighted_transitions
2021
processing_model
2122
statecharts
2223
api

docs/releases/3.0.0.md

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,35 @@ flag `validate_disconnected_states: bool = True` that can be used to disable thi
346346
It's already disabled when parsing SCXML files.
347347

348348

349+
### Weighted (probabilistic) transitions
350+
351+
A new contrib module `statemachine.contrib.weighted` provides `weighted_transitions()`,
352+
enabling probabilistic transition selection based on relative weights. This works entirely
353+
through the existing condition system — no engine changes required:
354+
355+
```python
356+
from statemachine.contrib.weighted import weighted_transitions
357+
358+
class GameCharacter(StateChart):
359+
standing = State(initial=True)
360+
shift_weight = State()
361+
adjust_hair = State()
362+
bang_shield = State()
363+
364+
idle = weighted_transitions(
365+
standing,
366+
(shift_weight, 70),
367+
(adjust_hair, 20),
368+
(bang_shield, 10),
369+
seed=42,
370+
)
371+
372+
finish = shift_weight.to(standing) | adjust_hair.to(standing) | bang_shield.to(standing)
373+
```
374+
375+
See {ref}`weighted-transitions` for full documentation.
376+
377+
349378
## Bugfixes in 3.0.0
350379

351380
- Fixes [#XXX](https://github.com/fgmacedo/python-statemachine/issues/XXX).

docs/weighted_transitions.md

Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
(weighted-transitions)=
2+
3+
# Weighted transitions
4+
5+
```{versionadded} 3.0.0
6+
```
7+
8+
The `weighted_transitions` utility lets you define **probabilistic transitions** — where
9+
each transition from a state has a relative weight that determines how likely it is to be
10+
selected when the event fires.
11+
12+
This is a contrib module that works entirely through the existing {ref}`guards` system.
13+
No engine modifications are needed.
14+
15+
## Installation
16+
17+
The module is included in the `python-statemachine` package. Import it from the contrib
18+
namespace:
19+
20+
```python
21+
from statemachine.contrib.weighted import weighted_transitions
22+
23+
# Only needed when passing transition kwargs (cond, on, etc.)
24+
from statemachine.contrib.weighted import to
25+
```
26+
27+
## Basic usage
28+
29+
Pass a **source state** followed by `(target, weight)` tuples. The result is a regular
30+
{ref}`TransitionList` that you assign to a class attribute as an event:
31+
32+
```{testsetup}
33+
34+
>>> from statemachine import State, StateChart
35+
>>> from statemachine.contrib.weighted import to, weighted_transitions
36+
37+
```
38+
39+
```py
40+
>>> class GameCharacter(StateChart):
41+
... standing = State(initial=True)
42+
... shift_weight = State()
43+
... adjust_hair = State()
44+
... bang_shield = State()
45+
...
46+
... idle = weighted_transitions(
47+
... standing,
48+
... (shift_weight, 70),
49+
... (adjust_hair, 20),
50+
... (bang_shield, 10),
51+
... seed=42,
52+
... )
53+
...
54+
... finish = (
55+
... shift_weight.to(standing)
56+
... | adjust_hair.to(standing)
57+
... | bang_shield.to(standing)
58+
... )
59+
60+
>>> sm = GameCharacter()
61+
>>> sm.send("idle")
62+
>>> any(
63+
... s in sm.configuration_values
64+
... for s in ("shift_weight", "adjust_hair", "bang_shield")
65+
... )
66+
True
67+
68+
```
69+
70+
When `idle` fires, the engine randomly selects one of the three transitions based on
71+
their relative weights: 70% chance for `shift_weight`, 20% for `adjust_hair`,
72+
10% for `bang_shield`.
73+
74+
## Weights
75+
76+
Weights can be any **positive number** — integers, floats, or a mix of both. They are
77+
relative, not absolute percentages:
78+
79+
```python
80+
# These are equivalent (same 70/20/10 ratio):
81+
idle = weighted_transitions(
82+
standing,
83+
(shift_weight, 70),
84+
(adjust_hair, 20),
85+
(bang_shield, 10),
86+
)
87+
88+
idle = weighted_transitions(
89+
standing,
90+
(shift_weight, 7),
91+
(adjust_hair, 2),
92+
(bang_shield, 1),
93+
)
94+
95+
idle = weighted_transitions(
96+
standing,
97+
(shift_weight, 0.7),
98+
(adjust_hair, 0.2),
99+
(bang_shield, 0.1),
100+
)
101+
```
102+
103+
The tuple format `(target, weight)` follows the standard Python pattern used by
104+
{py:func}`random.choices`.
105+
106+
## Reproducibility with `seed`
107+
108+
Pass a `seed` parameter for deterministic, reproducible sequences — useful for testing:
109+
110+
```python
111+
go = weighted_transitions(
112+
s1,
113+
(s2, 50),
114+
(s3, 50),
115+
seed=42, # same seed always produces the same sequence
116+
)
117+
```
118+
119+
```{note}
120+
The seed initializes a per-group `random.Random` instance that is shared across all
121+
instances of the same state machine class. This means the sequence is deterministic
122+
for a given program execution, but different instances advance the same RNG.
123+
```
124+
125+
## Per-transition options
126+
127+
Use the {func}`~statemachine.contrib.weighted.to` helper to pass transition keyword
128+
arguments (``cond``, ``unless``, ``before``, ``on``, ``after``, …) as natural kwargs.
129+
For simple destinations without extra options, a plain ``(target, weight)`` tuple is
130+
enough — ``to()`` is only needed when you want to customize the transition:
131+
132+
```py
133+
>>> class GuardedWeighted(StateChart):
134+
... idle = State(initial=True)
135+
... walk = State()
136+
... run = State()
137+
...
138+
... move = weighted_transitions(
139+
... idle,
140+
... (walk, 70),
141+
... to(run, 30, cond="has_energy"),
142+
... )
143+
... stop = walk.to(idle) | run.to(idle)
144+
...
145+
... has_energy = True
146+
147+
>>> sm = GuardedWeighted()
148+
149+
```
150+
151+
```{important}
152+
**No fallback when a guard fails.** If the weighted selection picks a transition whose
153+
guard evaluates to ``False``, the event fails — the engine does **not** silently fall back
154+
to another transition. This preserves the probability semantics: a 70/30 split means
155+
exactly that, not "70/30 unless the 30% is blocked, in which case always 100% for
156+
the other".
157+
158+
This behavior follows {ref}`conditions` evaluation: the first transition whose **all**
159+
conditions pass is executed.
160+
```
161+
162+
## Combining with callbacks
163+
164+
All standard {ref}`actions` work with weighted events — `before`, `on`, `after` callbacks
165+
and naming conventions like `on_<event>()`:
166+
167+
```python
168+
class WithCallbacks(StateChart):
169+
s1 = State(initial=True)
170+
s2 = State()
171+
s3 = State()
172+
173+
go = weighted_transitions(s1, (s2, 60), (s3, 40))
174+
back = s2.to(s1) | s3.to(s1)
175+
176+
def on_go(self):
177+
print("go event fired!")
178+
179+
def after_go(self):
180+
print("after go!")
181+
```
182+
183+
## Multiple independent groups
184+
185+
Each call to `weighted_transitions()` creates an independent weighted group with its
186+
own RNG. You can have multiple weighted events on the same state machine:
187+
188+
```python
189+
class MultiGroup(StateChart):
190+
idle = State(initial=True)
191+
walk = State()
192+
run = State()
193+
wave = State()
194+
bow = State()
195+
196+
move = weighted_transitions(idle, (walk, 70), (run, 30), seed=1)
197+
greet = weighted_transitions(idle, (wave, 80), (bow, 20), seed=2)
198+
back = walk.to(idle) | run.to(idle) | wave.to(idle) | bow.to(idle)
199+
```
200+
201+
The `move` and `greet` events use separate RNGs and don't interfere with each other.
202+
203+
## Validation
204+
205+
`weighted_transitions()` validates inputs at class definition time:
206+
207+
- The first argument must be a `State` (the source).
208+
- Each destination must be a `(target_state, weight)` or
209+
`(target_state, weight, kwargs_dict)` tuple.
210+
- Weights must be positive numbers (`int` or `float`).
211+
- At least one destination is required.
212+
213+
```py
214+
>>> weighted_transitions(State(initial=True))
215+
Traceback (most recent call last):
216+
...
217+
ValueError: weighted_transitions() requires at least one (target, weight) destination
218+
219+
>>> s1, s2 = State(initial=True), State()
220+
>>> weighted_transitions(s1, (s2, -5))
221+
Traceback (most recent call last):
222+
...
223+
ValueError: Destination 0: weight must be positive, got -5
224+
225+
>>> weighted_transitions(s1, (s2, "ten"))
226+
Traceback (most recent call last):
227+
...
228+
TypeError: Destination 0: weight must be a positive number, got str
229+
230+
```
231+
232+
## How it works
233+
234+
Under the hood, `weighted_transitions()`:
235+
236+
1. Creates a `_WeightedGroup` holding the weights and a `random.Random` instance.
237+
2. Calls `source.to(target, **kwargs)` for each destination, creating standard
238+
transitions.
239+
3. Attaches a lightweight condition callable to each transition's `cond` list.
240+
4. When the event fires, the engine evaluates conditions in order. The first condition
241+
to run rolls the dice (using `random.choices`) and caches the result. Subsequent
242+
conditions check against the cache.
243+
5. Only the selected transition's condition returns `True` — the engine picks it.
244+
245+
This means weighted transitions are fully compatible with all engine features:
246+
{ref}`actions`, {ref}`validators-and-guards`, {ref}`listeners`, async engines,
247+
and {ref}`diagram generation <diagram>`.

statemachine/contrib/diagram.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,12 +296,37 @@ def quickchart_write_svg(sm: StateChart, path: str):
296296
f.write(data)
297297

298298

299+
def _find_sm_class(module):
300+
"""Find the first StateChart subclass defined in a module."""
301+
import inspect
302+
303+
for _name, obj in inspect.getmembers(module, inspect.isclass):
304+
if (
305+
issubclass(obj, StateChart)
306+
and obj is not StateChart
307+
and obj.__module__ == module.__name__
308+
):
309+
return obj
310+
return None
311+
312+
299313
def import_sm(qualname):
300314
module_name, class_name = qualname.rsplit(".", 1)
301315
module = importlib.import_module(module_name)
302316
smclass = getattr(module, class_name, None)
303-
if not smclass or not issubclass(smclass, StateChart):
304-
raise ValueError(f"{class_name} is not a subclass of StateMachine")
317+
if smclass is not None and isinstance(smclass, type) and issubclass(smclass, StateChart):
318+
return smclass
319+
320+
# qualname may be a module path without a class name — try importing
321+
# the whole path as a module and find the first StateChart subclass.
322+
try:
323+
module = importlib.import_module(qualname)
324+
except ImportError as err:
325+
raise ValueError(f"{class_name} is not a subclass of StateMachine") from err
326+
327+
smclass = _find_sm_class(module)
328+
if smclass is None:
329+
raise ValueError(f"No StateMachine subclass found in module {qualname!r}")
305330

306331
return smclass
307332

0 commit comments

Comments
 (0)