Skip to content

Commit a5a1392

Browse files
authored
macedo/refac dispatcher (#490)
* refac: Removing registry's dependency of Listeners * refac: Move all callback search to the dispatcher module
1 parent 8985849 commit a5a1392

6 files changed

Lines changed: 163 additions & 173 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ convention = "google"
155155

156156
[tool.coverage.run]
157157
branch = true
158+
dynamic_context = "test_function"
158159
relative_files = true
159160
data_file = ".coverage"
160161
source = ["statemachine"]
@@ -177,3 +178,4 @@ exclude_lines = [
177178

178179
[tool.coverage.html]
179180
directory = "tmp/htmlcov"
181+
show_contexts = true

statemachine/callbacks.py

Lines changed: 39 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,23 @@
55
from enum import IntEnum
66
from enum import IntFlag
77
from enum import auto
8-
from functools import partial
9-
from functools import reduce
108
from inspect import isawaitable
119
from inspect import iscoroutinefunction
1210
from typing import TYPE_CHECKING
1311
from typing import Callable
1412
from typing import Dict
15-
from typing import Generator
16-
from typing import Iterable
1713
from typing import List
18-
from typing import Set
19-
from typing import Type
2014

2115
from .exceptions import AttrNotFound
22-
from .exceptions import InvalidDefinition
2316
from .i18n import _
24-
from .spec_parser import custom_and
25-
from .spec_parser import operator_mapping
26-
from .spec_parser import parse_boolean_expr
2717
from .utils import ensure_iterable
2818

2919
if TYPE_CHECKING:
30-
from statemachine.dispatcher import Listeners
20+
from typing import Set
21+
22+
23+
def allways_true(*args, **kwargs):
24+
return True
3125

3226

3327
class CallbackPriority(IntEnum):
@@ -61,21 +55,6 @@ def build_key(self, specs: "CallbackSpecList") -> str:
6155
return f"{self.name}@{id(specs)}"
6256

6357

64-
def allways_true(*args, **kwargs):
65-
return True
66-
67-
68-
def take_callback(name: str, resolver: "Listeners", not_found_handler: Callable) -> Callable:
69-
callbacks = list(resolver.search_name(name))
70-
if len(callbacks) == 0:
71-
not_found_handler(name)
72-
return allways_true
73-
elif len(callbacks) == 1:
74-
return callbacks[0]
75-
else:
76-
return reduce(custom_and, callbacks)
77-
78-
7958
class CallbackSpec:
8059
"""Specs about callbacks.
8160
@@ -85,6 +64,9 @@ class CallbackSpec:
8564
before any real call is performed.
8665
"""
8766

67+
names_not_found: "Set[str] | None" = None
68+
"""List of names that were not found on the model or statemachine"""
69+
8870
def __init__(
8971
self,
9072
func,
@@ -112,6 +94,12 @@ def __init__(
11294
self.reference = SpecReference.NAME
11395
self.attr_name = func
11496

97+
self.may_contain_boolean_expression = (
98+
not self.is_convention
99+
and self.group == CallbackGroup.COND
100+
and self.reference == SpecReference.NAME
101+
)
102+
115103
def __repr__(self):
116104
return f"{type(self).__name__}({self.func!r}, is_convention={self.is_convention!r})"
117105

@@ -132,71 +120,26 @@ def _update_func(self, func: Callable, attr_name: str):
132120
self.reference = SpecReference.CALLABLE
133121
self.attr_name = attr_name
134122

135-
def _wrap(self, callback):
136-
condition = self.cond if self.cond is not None else allways_true
137-
return CallbackWrapper(
138-
callback=callback,
139-
condition=condition,
140-
meta=self,
141-
unique_key=callback.unique_key,
142-
)
143-
144-
def build(self, resolver: "Listeners") -> Generator["CallbackWrapper", None, None]:
145-
"""
146-
Resolves the `func` into a usable callable.
147-
148-
Args:
149-
resolver (callable): A method responsible to build and return a valid callable that
150-
can receive arbitrary parameters like `*args, **kwargs`.
151-
"""
152-
if (
153-
not self.is_convention
154-
and self.group == CallbackGroup.COND
155-
and self.reference == SpecReference.NAME
156-
):
157-
names_not_found: Set[str] = set()
158-
take_callback_partial = partial(
159-
take_callback, resolver=resolver, not_found_handler=names_not_found.add
160-
)
161-
try:
162-
expression = parse_boolean_expr(self.func, take_callback_partial, operator_mapping)
163-
except SyntaxError as err:
164-
raise InvalidDefinition(
165-
_("Failed to parse boolean expression '{}'").format(self.func)
166-
) from err
167-
if not expression or names_not_found:
168-
self.names_not_found = names_not_found
169-
return
170-
yield self._wrap(expression)
171-
return
172-
173-
for callback in resolver.search(self):
174-
yield self._wrap(callback)
175-
176123

177124
class SpecListGrouper:
178-
def __init__(
179-
self, list: "CallbackSpecList", group: CallbackGroup, factory=CallbackSpec
180-
) -> None:
125+
def __init__(self, list: "CallbackSpecList", group: CallbackGroup) -> None:
181126
self.list = list
182127
self.group = group
183-
self.factory = factory
184128
self.key = group.build_key(list)
185129

186130
def add(self, callbacks, **kwargs):
187-
self.list.add(callbacks, group=self.group, factory=self.factory, **kwargs)
131+
self.list.add(callbacks, group=self.group, **kwargs)
188132
return self
189133

190134
def __call__(self, callback):
191-
return self.list._add_unbounded_callback(callback, group=self.group, factory=self.factory)
135+
return self.list._add_unbounded_callback(callback, group=self.group)
192136

193137
def _add_unbounded_callback(self, func, is_event=False, transitions=None, **kwargs):
194138
self.list._add_unbounded_callback(
195139
func,
196140
is_event=is_event,
197141
transitions=transitions,
198142
group=self.group,
199-
factory=self.factory,
200143
**kwargs,
201144
)
202145

@@ -210,6 +153,7 @@ class CallbackSpecList:
210153
def __init__(self, factory=CallbackSpec):
211154
self.items: List[CallbackSpec] = []
212155
self.conventional_specs = set()
156+
self._groupers: Dict[CallbackGroup, SpecListGrouper] = {}
213157
self.factory = factory
214158

215159
def __repr__(self):
@@ -253,15 +197,13 @@ def __iter__(self):
253197
def clear(self):
254198
self.items = []
255199

256-
def grouper(
257-
self, group: CallbackGroup, factory: Type[CallbackSpec] = CallbackSpec
258-
) -> SpecListGrouper:
259-
return SpecListGrouper(self, group, factory=factory)
200+
def grouper(self, group: CallbackGroup) -> SpecListGrouper:
201+
if group not in self._groupers:
202+
self._groupers[group] = SpecListGrouper(self, group)
203+
return self._groupers[group]
260204

261-
def _add(self, func, group: CallbackGroup, factory=None, **kwargs):
262-
if factory is None:
263-
factory = self.factory
264-
spec = factory(func, group, **kwargs)
205+
def _add(self, func, group: CallbackGroup, **kwargs):
206+
spec = self.factory(func, group, **kwargs)
265207

266208
if spec in self.items:
267209
return
@@ -338,19 +280,21 @@ def __repr__(self):
338280
def __str__(self):
339281
return ", ".join(str(c) for c in self)
340282

341-
def _add(self, spec: CallbackSpec, resolver: "Listeners"):
342-
for callback in spec.build(resolver):
343-
if callback.unique_key in self.items_already_seen:
344-
continue
283+
def add(self, key: str, spec: CallbackSpec, builder: Callable[[], Callable]):
284+
if key in self.items_already_seen:
285+
return
345286

346-
self.items_already_seen.add(callback.unique_key)
347-
insort(self.items, callback)
287+
self.items_already_seen.add(key)
348288

349-
def add(self, items: Iterable[CallbackSpec], resolver: "Listeners"):
350-
"""Validate configurations"""
351-
for item in items:
352-
self._add(item, resolver)
353-
return self
289+
condition = spec.cond if spec.cond is not None else allways_true
290+
wrapper = CallbackWrapper(
291+
callback=builder(),
292+
condition=condition,
293+
meta=spec,
294+
unique_key=key,
295+
)
296+
297+
insort(self.items, wrapper)
354298

355299
async def async_call(self, *args, **kwargs):
356300
return await asyncio.gather(
@@ -402,7 +346,8 @@ def check(self, specs: CallbackSpecList):
402346
callback for callback in self[meta.group.build_key(specs)] if callback.meta == meta
403347
):
404348
continue
405-
if hasattr(meta, "names_not_found"):
349+
350+
if meta.names_not_found:
406351
raise AttrNotFound(
407352
_("Did not found name '{}' from model or statemachine").format(
408353
", ".join(meta.names_not_found)

0 commit comments

Comments
 (0)