|
| 1 | +(weighted-transitions)= |
| 2 | + |
| 3 | +# Weighted transitions |
| 4 | + |
| 5 | +```{versionadded} 3.0.0 |
| 6 | +``` |
| 7 | + |
| 8 | +The `weighted_transitions` utility lets you define **probabilistic transitions** — where |
| 9 | +each transition from a state has a relative weight that determines how likely it is to be |
| 10 | +selected when the event fires. |
| 11 | + |
| 12 | +This is a contrib module that works entirely through the existing {ref}`guards` system. |
| 13 | +No engine modifications are needed. |
| 14 | + |
| 15 | +## Installation |
| 16 | + |
| 17 | +The module is included in the `python-statemachine` package. Import it from the contrib |
| 18 | +namespace: |
| 19 | + |
| 20 | +```python |
| 21 | +from statemachine.contrib.weighted import weighted_transitions |
| 22 | + |
| 23 | +# Only needed when passing transition kwargs (cond, on, etc.) |
| 24 | +from statemachine.contrib.weighted import to |
| 25 | +``` |
| 26 | + |
| 27 | +## Basic usage |
| 28 | + |
| 29 | +Pass a **source state** followed by `(target, weight)` tuples. The result is a regular |
| 30 | +{ref}`TransitionList` that you assign to a class attribute as an event: |
| 31 | + |
| 32 | +```{testsetup} |
| 33 | +
|
| 34 | +>>> from statemachine import State, StateChart |
| 35 | +>>> from statemachine.contrib.weighted import to, weighted_transitions |
| 36 | +
|
| 37 | +``` |
| 38 | + |
| 39 | +```py |
| 40 | +>>> class GameCharacter(StateChart): |
| 41 | +... standing = State(initial=True) |
| 42 | +... shift_weight = State() |
| 43 | +... adjust_hair = State() |
| 44 | +... bang_shield = State() |
| 45 | +... |
| 46 | +... idle = weighted_transitions( |
| 47 | +... standing, |
| 48 | +... (shift_weight, 70), |
| 49 | +... (adjust_hair, 20), |
| 50 | +... (bang_shield, 10), |
| 51 | +... seed=42, |
| 52 | +... ) |
| 53 | +... |
| 54 | +... finish = ( |
| 55 | +... shift_weight.to(standing) |
| 56 | +... | adjust_hair.to(standing) |
| 57 | +... | bang_shield.to(standing) |
| 58 | +... ) |
| 59 | + |
| 60 | +>>> sm = GameCharacter() |
| 61 | +>>> sm.send("idle") |
| 62 | +>>> any( |
| 63 | +... s in sm.configuration_values |
| 64 | +... for s in ("shift_weight", "adjust_hair", "bang_shield") |
| 65 | +... ) |
| 66 | +True |
| 67 | + |
| 68 | +``` |
| 69 | + |
| 70 | +When `idle` fires, the engine randomly selects one of the three transitions based on |
| 71 | +their relative weights: 70% chance for `shift_weight`, 20% for `adjust_hair`, |
| 72 | +10% for `bang_shield`. |
| 73 | + |
| 74 | +## Weights |
| 75 | + |
| 76 | +Weights can be any **positive number** — integers, floats, or a mix of both. They are |
| 77 | +relative, not absolute percentages: |
| 78 | + |
| 79 | +```python |
| 80 | +# These are equivalent (same 70/20/10 ratio): |
| 81 | +idle = weighted_transitions( |
| 82 | + standing, |
| 83 | + (shift_weight, 70), |
| 84 | + (adjust_hair, 20), |
| 85 | + (bang_shield, 10), |
| 86 | +) |
| 87 | + |
| 88 | +idle = weighted_transitions( |
| 89 | + standing, |
| 90 | + (shift_weight, 7), |
| 91 | + (adjust_hair, 2), |
| 92 | + (bang_shield, 1), |
| 93 | +) |
| 94 | + |
| 95 | +idle = weighted_transitions( |
| 96 | + standing, |
| 97 | + (shift_weight, 0.7), |
| 98 | + (adjust_hair, 0.2), |
| 99 | + (bang_shield, 0.1), |
| 100 | +) |
| 101 | +``` |
| 102 | + |
| 103 | +The tuple format `(target, weight)` follows the standard Python pattern used by |
| 104 | +{py:func}`random.choices`. |
| 105 | + |
| 106 | +## Reproducibility with `seed` |
| 107 | + |
| 108 | +Pass a `seed` parameter for deterministic, reproducible sequences — useful for testing: |
| 109 | + |
| 110 | +```python |
| 111 | +go = weighted_transitions( |
| 112 | + s1, |
| 113 | + (s2, 50), |
| 114 | + (s3, 50), |
| 115 | + seed=42, # same seed always produces the same sequence |
| 116 | +) |
| 117 | +``` |
| 118 | + |
| 119 | +```{note} |
| 120 | +The seed initializes a per-group `random.Random` instance that is shared across all |
| 121 | +instances of the same state machine class. This means the sequence is deterministic |
| 122 | +for a given program execution, but different instances advance the same RNG. |
| 123 | +``` |
| 124 | + |
| 125 | +## Per-transition options |
| 126 | + |
| 127 | +Use the {func}`~statemachine.contrib.weighted.to` helper to pass transition keyword |
| 128 | +arguments (``cond``, ``unless``, ``before``, ``on``, ``after``, …) as natural kwargs. |
| 129 | +For simple destinations without extra options, a plain ``(target, weight)`` tuple is |
| 130 | +enough — ``to()`` is only needed when you want to customize the transition: |
| 131 | + |
| 132 | +```py |
| 133 | +>>> class GuardedWeighted(StateChart): |
| 134 | +... idle = State(initial=True) |
| 135 | +... walk = State() |
| 136 | +... run = State() |
| 137 | +... |
| 138 | +... move = weighted_transitions( |
| 139 | +... idle, |
| 140 | +... (walk, 70), |
| 141 | +... to(run, 30, cond="has_energy"), |
| 142 | +... ) |
| 143 | +... stop = walk.to(idle) | run.to(idle) |
| 144 | +... |
| 145 | +... has_energy = True |
| 146 | + |
| 147 | +>>> sm = GuardedWeighted() |
| 148 | + |
| 149 | +``` |
| 150 | + |
| 151 | +```{important} |
| 152 | +**No fallback when a guard fails.** If the weighted selection picks a transition whose |
| 153 | +guard evaluates to ``False``, the event fails — the engine does **not** silently fall back |
| 154 | +to another transition. This preserves the probability semantics: a 70/30 split means |
| 155 | +exactly that, not "70/30 unless the 30% is blocked, in which case always 100% for |
| 156 | +the other". |
| 157 | +
|
| 158 | +This behavior follows {ref}`conditions` evaluation: the first transition whose **all** |
| 159 | +conditions pass is executed. |
| 160 | +``` |
| 161 | + |
| 162 | +## Combining with callbacks |
| 163 | + |
| 164 | +All standard {ref}`actions` work with weighted events — `before`, `on`, `after` callbacks |
| 165 | +and naming conventions like `on_<event>()`: |
| 166 | + |
| 167 | +```python |
| 168 | +class WithCallbacks(StateChart): |
| 169 | + s1 = State(initial=True) |
| 170 | + s2 = State() |
| 171 | + s3 = State() |
| 172 | + |
| 173 | + go = weighted_transitions(s1, (s2, 60), (s3, 40)) |
| 174 | + back = s2.to(s1) | s3.to(s1) |
| 175 | + |
| 176 | + def on_go(self): |
| 177 | + print("go event fired!") |
| 178 | + |
| 179 | + def after_go(self): |
| 180 | + print("after go!") |
| 181 | +``` |
| 182 | + |
| 183 | +## Multiple independent groups |
| 184 | + |
| 185 | +Each call to `weighted_transitions()` creates an independent weighted group with its |
| 186 | +own RNG. You can have multiple weighted events on the same state machine: |
| 187 | + |
| 188 | +```python |
| 189 | +class MultiGroup(StateChart): |
| 190 | + idle = State(initial=True) |
| 191 | + walk = State() |
| 192 | + run = State() |
| 193 | + wave = State() |
| 194 | + bow = State() |
| 195 | + |
| 196 | + move = weighted_transitions(idle, (walk, 70), (run, 30), seed=1) |
| 197 | + greet = weighted_transitions(idle, (wave, 80), (bow, 20), seed=2) |
| 198 | + back = walk.to(idle) | run.to(idle) | wave.to(idle) | bow.to(idle) |
| 199 | +``` |
| 200 | + |
| 201 | +The `move` and `greet` events use separate RNGs and don't interfere with each other. |
| 202 | + |
| 203 | +## Validation |
| 204 | + |
| 205 | +`weighted_transitions()` validates inputs at class definition time: |
| 206 | + |
| 207 | +- The first argument must be a `State` (the source). |
| 208 | +- Each destination must be a `(target_state, weight)` or |
| 209 | + `(target_state, weight, kwargs_dict)` tuple. |
| 210 | +- Weights must be positive numbers (`int` or `float`). |
| 211 | +- At least one destination is required. |
| 212 | + |
| 213 | +```py |
| 214 | +>>> weighted_transitions(State(initial=True)) |
| 215 | +Traceback (most recent call last): |
| 216 | + ... |
| 217 | +ValueError: weighted_transitions() requires at least one (target, weight) destination |
| 218 | + |
| 219 | +>>> s1, s2 = State(initial=True), State() |
| 220 | +>>> weighted_transitions(s1, (s2, -5)) |
| 221 | +Traceback (most recent call last): |
| 222 | + ... |
| 223 | +ValueError: Destination 0: weight must be positive, got -5 |
| 224 | + |
| 225 | +>>> weighted_transitions(s1, (s2, "ten")) |
| 226 | +Traceback (most recent call last): |
| 227 | + ... |
| 228 | +TypeError: Destination 0: weight must be a positive number, got str |
| 229 | + |
| 230 | +``` |
| 231 | + |
| 232 | +## How it works |
| 233 | + |
| 234 | +Under the hood, `weighted_transitions()`: |
| 235 | + |
| 236 | +1. Creates a `_WeightedGroup` holding the weights and a `random.Random` instance. |
| 237 | +2. Calls `source.to(target, **kwargs)` for each destination, creating standard |
| 238 | + transitions. |
| 239 | +3. Attaches a lightweight condition callable to each transition's `cond` list. |
| 240 | +4. When the event fires, the engine evaluates conditions in order. The first condition |
| 241 | + to run rolls the dice (using `random.choices`) and caches the result. Subsequent |
| 242 | + conditions check against the cache. |
| 243 | +5. Only the selected transition's condition returns `True` — the engine picks it. |
| 244 | + |
| 245 | +This means weighted transitions are fully compatible with all engine features: |
| 246 | +{ref}`actions`, {ref}`validators-and-guards`, {ref}`listeners`, async engines, |
| 247 | +and {ref}`diagram generation <diagram>`. |
0 commit comments