Skip to content

Commit 4e29771

Browse files
authored
refac: Improved isolation of components; caching results of built-in iscoroutinefunction (#493)
1 parent 5528c3e commit 4e29771

12 files changed

Lines changed: 87 additions & 63 deletions

File tree

.github/workflows/python-package.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
strategy:
1616
fail-fast: false
1717
matrix:
18-
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13.0"]
18+
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]
1919

2020
steps:
2121
- uses: actions/checkout@v3
@@ -41,7 +41,7 @@ jobs:
4141
# run ruff
4242
#----------------------------------------------
4343
- name: Linter with ruff
44-
if: matrix.python-version == 3.12
44+
if: matrix.python-version == 3.13
4545
run: |
4646
uv run ruff check .
4747
uv run ruff format --check .
@@ -57,7 +57,7 @@ jobs:
5757
#----------------------------------------------
5858
- name: Upload coverage to Codecov
5959
uses: codecov/codecov-action@v4
60-
if: matrix.python-version == 3.12
60+
if: matrix.python-version == 3.13
6161
with:
6262
token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos
6363
directory: .

.github/workflows/release.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ jobs:
1010
strategy:
1111
fail-fast: false
1212
matrix:
13-
python-version: ["3.12"]
13+
python-version: ["3.13"]
1414

1515
# Specifying a GitHub environment is optional, but strongly encouraged
1616
environment: release

statemachine/callbacks.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from enum import IntFlag
77
from enum import auto
88
from inspect import isawaitable
9-
from inspect import iscoroutinefunction
109
from typing import TYPE_CHECKING
1110
from typing import Callable
1211
from typing import Dict
@@ -233,7 +232,7 @@ def __init__(
233232
unique_key: str,
234233
) -> None:
235234
self._callback = callback
236-
self._iscoro = iscoroutinefunction(callback)
235+
self._iscoro = getattr(callback, "is_coroutine", False)
237236
self.condition = condition
238237
self.meta = meta
239238
self.unique_key = unique_key
@@ -361,3 +360,24 @@ def async_or_sync(self):
361360
self.has_async_callbacks = any(
362361
callback._iscoro for executor in self._registry.values() for callback in executor
363362
)
363+
364+
def call(self, key: str, *args, **kwargs):
365+
if key not in self._registry:
366+
return []
367+
return self._registry[key].call(*args, **kwargs)
368+
369+
def async_call(self, key: str, *args, **kwargs):
370+
return self._registry[key].async_call(*args, **kwargs)
371+
372+
def all(self, key: str, *args, **kwargs):
373+
if key not in self._registry:
374+
return True
375+
return self._registry[key].all(*args, **kwargs)
376+
377+
def async_all(self, key: str, *args, **kwargs):
378+
return self._registry[key].async_all(*args, **kwargs)
379+
380+
def str(self, key: str) -> str:
381+
if key not in self._registry:
382+
return ""
383+
return str(self._registry[key])

statemachine/contrib/diagram.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class DotGraphMachine:
2929
transition_font_size = "9"
3030
"""Transition font size in points"""
3131

32-
def __init__(self, machine):
32+
def __init__(self, machine: StateMachine):
3333
self.machine = machine
3434

3535
def _get_graph(self):
@@ -69,11 +69,11 @@ def _initial_edge(self):
6969
def _actions_getter(self):
7070
if isinstance(self.machine, StateMachine):
7171

72-
def getter(grouper):
73-
return self.machine._get_callbacks(grouper.key)
72+
def getter(grouper) -> str:
73+
return self.machine._callbacks_registry.str(grouper.key)
7474
else:
7575

76-
def getter(grouper):
76+
def getter(grouper) -> str:
7777
all_names = set(dir(self.machine))
7878
return ", ".join(
7979
str(c) for c in grouper if not c.is_convention or c.func in all_names

statemachine/dispatcher.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,10 +187,27 @@ def search_name(self, name):
187187

188188

189189
def callable_method(a_callable) -> Callable:
190-
method = SignatureAdapter.wrap(a_callable)
191-
method.__name__ = a_callable.__name__
192-
method.__doc__ = a_callable.__doc__
193-
return method
190+
sig = SignatureAdapter.from_callable(a_callable)
191+
sig_bind_expected = sig.bind_expected
192+
193+
metadata_to_copy = a_callable.func if isinstance(a_callable, partial) else a_callable
194+
195+
if sig.is_coroutine:
196+
197+
async def signature_adapter(*args: Any, **kwargs: Any) -> Any:
198+
ba = sig_bind_expected(*args, **kwargs)
199+
return await a_callable(*ba.args, **ba.kwargs)
200+
else:
201+
202+
def signature_adapter(*args: Any, **kwargs: Any) -> Any: # type: ignore[misc]
203+
ba = sig_bind_expected(*args, **kwargs)
204+
return a_callable(*ba.args, **ba.kwargs)
205+
206+
signature_adapter.__name__ = metadata_to_copy.__name__
207+
signature_adapter.__doc__ = metadata_to_copy.__doc__
208+
signature_adapter.is_coroutine = sig.is_coroutine # type: ignore[attr-defined]
209+
210+
return signature_adapter
194211

195212

196213
def attr_method(attribute, obj) -> Callable:

statemachine/engines/async_.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,12 @@ async def _trigger(self, trigger_data: TriggerData):
9595

9696
event_data = EventData(trigger_data=trigger_data, transition=transition)
9797
args, kwargs = event_data.args, event_data.extended_kwargs
98-
await self.sm._get_callbacks(transition.validators.key).async_call(*args, **kwargs)
99-
if not await self.sm._get_callbacks(transition.cond.key).async_all(*args, **kwargs):
98+
await self.sm._callbacks_registry.async_call(
99+
transition.validators.key, *args, **kwargs
100+
)
101+
if not await self.sm._callbacks_registry.async_all(
102+
transition.cond.key, *args, **kwargs
103+
):
100104
continue
101105

102106
result = await self._activate(event_data)
@@ -115,19 +119,21 @@ async def _activate(self, event_data: EventData):
115119
source = event_data.state
116120
target = transition.target
117121

118-
result = await self.sm._get_callbacks(transition.before.key).async_call(*args, **kwargs)
122+
result = await self.sm._callbacks_registry.async_call(
123+
transition.before.key, *args, **kwargs
124+
)
119125
if source is not None and not transition.internal:
120-
await self.sm._get_callbacks(source.exit.key).async_call(*args, **kwargs)
126+
await self.sm._callbacks_registry.async_call(source.exit.key, *args, **kwargs)
121127

122-
result += await self.sm._get_callbacks(transition.on.key).async_call(*args, **kwargs)
128+
result += await self.sm._callbacks_registry.async_call(transition.on.key, *args, **kwargs)
123129

124130
self.sm.current_state = target
125131
event_data.state = target
126132
kwargs["state"] = target
127133

128134
if not transition.internal:
129-
await self.sm._get_callbacks(target.enter.key).async_call(*args, **kwargs)
130-
await self.sm._get_callbacks(transition.after.key).async_call(*args, **kwargs)
135+
await self.sm._callbacks_registry.async_call(target.enter.key, *args, **kwargs)
136+
await self.sm._callbacks_registry.async_call(transition.after.key, *args, **kwargs)
131137

132138
if len(result) == 0:
133139
result = None

statemachine/engines/sync.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,8 @@ def _trigger(self, trigger_data: TriggerData):
9898

9999
event_data = EventData(trigger_data=trigger_data, transition=transition)
100100
args, kwargs = event_data.args, event_data.extended_kwargs
101-
self.sm._get_callbacks(transition.validators.key).call(*args, **kwargs)
102-
if not self.sm._get_callbacks(transition.cond.key).all(*args, **kwargs):
101+
self.sm._callbacks_registry.call(transition.validators.key, *args, **kwargs)
102+
if not self.sm._callbacks_registry.all(transition.cond.key, *args, **kwargs):
103103
continue
104104

105105
result = self._activate(event_data)
@@ -118,19 +118,19 @@ def _activate(self, event_data: EventData):
118118
source = event_data.state
119119
target = transition.target
120120

121-
result = self.sm._get_callbacks(transition.before.key).call(*args, **kwargs)
121+
result = self.sm._callbacks_registry.call(transition.before.key, *args, **kwargs)
122122
if source is not None and not transition.internal:
123-
self.sm._get_callbacks(source.exit.key).call(*args, **kwargs)
123+
self.sm._callbacks_registry.call(source.exit.key, *args, **kwargs)
124124

125-
result += self.sm._get_callbacks(transition.on.key).call(*args, **kwargs)
125+
result += self.sm._callbacks_registry.call(transition.on.key, *args, **kwargs)
126126

127127
self.sm.current_state = target
128128
event_data.state = target
129129
kwargs["state"] = target
130130

131131
if not transition.internal:
132-
self.sm._get_callbacks(target.enter.key).call(*args, **kwargs)
133-
self.sm._get_callbacks(transition.after.key).call(*args, **kwargs)
132+
self.sm._callbacks_registry.call(target.enter.key, *args, **kwargs)
133+
self.sm._callbacks_registry.call(transition.after.key, *args, **kwargs)
134134

135135
if len(result) == 0:
136136
result = None

statemachine/signature.py

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from itertools import chain
77
from types import MethodType
88
from typing import Any
9-
from typing import Callable
109

1110

1211
def _make_key(method):
@@ -44,40 +43,22 @@ def cached_function(cls, method):
4443

4544

4645
class SignatureAdapter(Signature):
47-
@classmethod
48-
def wrap(cls, method) -> Callable:
49-
"""Build a wrapper that adapts the received arguments to the inner ``method`` signature"""
50-
51-
sig = cls.from_callable(method)
52-
sig_bind_expected = sig.bind_expected
53-
54-
metadata_to_copy = method.func if isinstance(method, partial) else method
55-
56-
if iscoroutinefunction(method):
57-
58-
async def signature_adapter(*args: Any, **kwargs: Any) -> Any:
59-
ba = sig_bind_expected(*args, **kwargs)
60-
return await method(*ba.args, **ba.kwargs)
61-
else:
62-
63-
def signature_adapter(*args: Any, **kwargs: Any) -> Any: # type: ignore[misc]
64-
ba = sig_bind_expected(*args, **kwargs)
65-
return method(*ba.args, **ba.kwargs)
66-
67-
signature_adapter.__name__ = metadata_to_copy.__name__
68-
69-
return signature_adapter
46+
is_coroutine: bool = False
7047

7148
@classmethod
7249
@signature_cache
7350
def from_callable(cls, method):
7451
if hasattr(method, "__signature__"):
7552
sig = method.__signature__
76-
return SignatureAdapter(
53+
adapter = SignatureAdapter(
7754
sig.parameters.values(),
7855
return_annotation=sig.return_annotation,
7956
)
80-
return super().from_callable(method)
57+
else:
58+
adapter = super().from_callable(method)
59+
60+
adapter.is_coroutine = iscoroutinefunction(method)
61+
return adapter
8162

8263
def bind_expected(self, *args: Any, **kwargs: Any) -> BoundArguments: # noqa: C901
8364
"""Get a BoundArguments object, that maps the passed `args`

statemachine/spec_parser.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ def build_expression(node, variable_hook, operator_mapping):
7171

7272
def parse_boolean_expr(expr, variable_hook, operator_mapping):
7373
"""Parses the expression into an AST and build a custom expression tree"""
74+
if expr.strip() == "":
75+
raise SyntaxError("Empty expression")
76+
if "!" not in expr and " " not in expr:
77+
return variable_hook(expr)
7478
expr = replace_operators(expr)
7579
tree = ast.parse(expr, mode="eval")
7680
return build_expression(tree.body, variable_hook, operator_mapping)

statemachine/statemachine.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
from .callbacks import SPECS_ALL
1212
from .callbacks import SPECS_SAFE
13-
from .callbacks import CallbacksExecutor
1413
from .callbacks import CallbacksRegistry
1514
from .callbacks import SpecReference
1615
from .dispatcher import Listener
@@ -322,6 +321,3 @@ def send(self, event: str, *args, **kwargs):
322321
if not isawaitable(result):
323322
return result
324323
return run_async_from_sync(result)
325-
326-
def _get_callbacks(self, key) -> CallbacksExecutor:
327-
return self._callbacks_registry[key]

0 commit comments

Comments
 (0)