|
5 | 5 | from enum import IntEnum |
6 | 6 | from enum import IntFlag |
7 | 7 | from enum import auto |
| 8 | +from functools import partial |
| 9 | +from functools import reduce |
8 | 10 | from inspect import isawaitable |
9 | 11 | from inspect import iscoroutinefunction |
| 12 | +from typing import TYPE_CHECKING |
10 | 13 | from typing import Callable |
11 | 14 | from typing import Dict |
12 | 15 | from typing import Generator |
13 | 16 | from typing import Iterable |
14 | 17 | from typing import List |
| 18 | +from typing import Set |
15 | 19 | from typing import Type |
16 | 20 |
|
17 | 21 | from .exceptions import AttrNotFound |
| 22 | +from .exceptions import InvalidDefinition |
18 | 23 | from .i18n import _ |
| 24 | +from .spec_parser import custom_and |
| 25 | +from .spec_parser import operator_mapping |
| 26 | +from .spec_parser import parse_boolean_expr |
19 | 27 | from .utils import ensure_iterable |
20 | 28 |
|
| 29 | +if TYPE_CHECKING: |
| 30 | + from statemachine.dispatcher import Listeners |
| 31 | + |
21 | 32 |
|
22 | 33 | class CallbackPriority(IntEnum): |
23 | 34 | GENERIC = 0 |
@@ -54,6 +65,17 @@ def allways_true(*args, **kwargs): |
54 | 65 | return True |
55 | 66 |
|
56 | 67 |
|
| 68 | +def take_callback(name: str, resolver: "Listeners", not_found_handler: Callable) -> Callable: |
| 69 | + callbacks = list(resolver.search_name(name)) |
| 70 | + if len(callbacks) == 0: |
| 71 | + not_found_handler(name) |
| 72 | + return allways_true |
| 73 | + elif len(callbacks) == 1: |
| 74 | + return callbacks[0] |
| 75 | + else: |
| 76 | + return reduce(custom_and, callbacks) |
| 77 | + |
| 78 | + |
57 | 79 | class CallbackSpec: |
58 | 80 | """Specs about callbacks. |
59 | 81 |
|
@@ -110,22 +132,46 @@ def _update_func(self, func: Callable, attr_name: str): |
110 | 132 | self.reference = SpecReference.CALLABLE |
111 | 133 | self.attr_name = attr_name |
112 | 134 |
|
113 | | - def build(self, resolver) -> Generator["CallbackWrapper", None, None]: |
| 135 | + def _wrap(self, callback): |
| 136 | + condition = self.cond if self.cond is not None else allways_true |
| 137 | + return CallbackWrapper( |
| 138 | + callback=callback, |
| 139 | + condition=condition, |
| 140 | + meta=self, |
| 141 | + unique_key=callback.unique_key, |
| 142 | + ) |
| 143 | + |
| 144 | + def build(self, resolver: "Listeners") -> Generator["CallbackWrapper", None, None]: |
114 | 145 | """ |
115 | 146 | Resolves the `func` into a usable callable. |
116 | 147 |
|
117 | 148 | Args: |
118 | 149 | resolver (callable): A method responsible to build and return a valid callable that |
119 | 150 | can receive arbitrary parameters like `*args, **kwargs`. |
120 | 151 | """ |
121 | | - for callback in resolver.search(self): |
122 | | - condition = self.cond if self.cond is not None else allways_true |
123 | | - yield CallbackWrapper( |
124 | | - callback=callback, |
125 | | - condition=condition, |
126 | | - meta=self, |
127 | | - unique_key=callback.unique_key, |
| 152 | + if ( |
| 153 | + not self.is_convention |
| 154 | + and self.group == CallbackGroup.COND |
| 155 | + and self.reference == SpecReference.NAME |
| 156 | + ): |
| 157 | + names_not_found: Set[str] = set() |
| 158 | + take_callback_partial = partial( |
| 159 | + take_callback, resolver=resolver, not_found_handler=names_not_found.add |
128 | 160 | ) |
| 161 | + try: |
| 162 | + expression = parse_boolean_expr(self.func, take_callback_partial, operator_mapping) |
| 163 | + except SyntaxError as err: |
| 164 | + raise InvalidDefinition( |
| 165 | + _("Failed to parse boolean expression '{}'").format(self.func) |
| 166 | + ) from err |
| 167 | + if not expression or names_not_found: |
| 168 | + self.names_not_found = names_not_found |
| 169 | + return |
| 170 | + yield self._wrap(expression) |
| 171 | + return |
| 172 | + |
| 173 | + for callback in resolver.search(self): |
| 174 | + yield self._wrap(callback) |
129 | 175 |
|
130 | 176 |
|
131 | 177 | class SpecListGrouper: |
@@ -292,15 +338,15 @@ def __repr__(self): |
292 | 338 | def __str__(self): |
293 | 339 | return ", ".join(str(c) for c in self) |
294 | 340 |
|
295 | | - def _add(self, spec: CallbackSpec, resolver: Callable): |
| 341 | + def _add(self, spec: CallbackSpec, resolver: "Listeners"): |
296 | 342 | for callback in spec.build(resolver): |
297 | 343 | if callback.unique_key in self.items_already_seen: |
298 | 344 | continue |
299 | 345 |
|
300 | 346 | self.items_already_seen.add(callback.unique_key) |
301 | 347 | insort(self.items, callback) |
302 | 348 |
|
303 | | - def add(self, items: Iterable[CallbackSpec], resolver: Callable): |
| 349 | + def add(self, items: Iterable[CallbackSpec], resolver: "Listeners"): |
304 | 350 | """Validate configurations""" |
305 | 351 | for item in items: |
306 | 352 | self._add(item, resolver) |
@@ -356,6 +402,12 @@ def check(self, specs: CallbackSpecList): |
356 | 402 | callback for callback in self[meta.group.build_key(specs)] if callback.meta == meta |
357 | 403 | ): |
358 | 404 | continue |
| 405 | + if hasattr(meta, "names_not_found"): |
| 406 | + raise AttrNotFound( |
| 407 | + _("Did not found name '{}' from model or statemachine").format( |
| 408 | + ", ".join(meta.names_not_found) |
| 409 | + ), |
| 410 | + ) |
359 | 411 | raise AttrNotFound( |
360 | 412 | _("Did not found name '{}' from model or statemachine").format(meta.func) |
361 | 413 | ) |
|
0 commit comments