55from enum import IntEnum
66from enum import IntFlag
77from enum import auto
8- from functools import partial
9- from functools import reduce
108from inspect import isawaitable
119from inspect import iscoroutinefunction
1210from typing import TYPE_CHECKING
1311from typing import Callable
1412from typing import Dict
15- from typing import Generator
16- from typing import Iterable
1713from typing import List
18- from typing import Set
19- from typing import Type
2014
2115from .exceptions import AttrNotFound
22- from .exceptions import InvalidDefinition
2316from .i18n import _
24- from .spec_parser import custom_and
25- from .spec_parser import operator_mapping
26- from .spec_parser import parse_boolean_expr
2717from .utils import ensure_iterable
2818
2919if TYPE_CHECKING :
30- from statemachine .dispatcher import Listeners
20+ from typing import Set
21+
22+
23+ def allways_true (* args , ** kwargs ):
24+ return True
3125
3226
3327class CallbackPriority (IntEnum ):
@@ -61,21 +55,6 @@ def build_key(self, specs: "CallbackSpecList") -> str:
6155 return f"{ self .name } @{ id (specs )} "
6256
6357
64- def allways_true (* args , ** kwargs ):
65- return True
66-
67-
68- def take_callback (name : str , resolver : "Listeners" , not_found_handler : Callable ) -> Callable :
69- callbacks = list (resolver .search_name (name ))
70- if len (callbacks ) == 0 :
71- not_found_handler (name )
72- return allways_true
73- elif len (callbacks ) == 1 :
74- return callbacks [0 ]
75- else :
76- return reduce (custom_and , callbacks )
77-
78-
7958class CallbackSpec :
8059 """Specs about callbacks.
8160
@@ -85,6 +64,9 @@ class CallbackSpec:
8564 before any real call is performed.
8665 """
8766
67+ names_not_found : "Set[str] | None" = None
68+ """List of names that were not found on the model or statemachine"""
69+
8870 def __init__ (
8971 self ,
9072 func ,
@@ -112,6 +94,12 @@ def __init__(
11294 self .reference = SpecReference .NAME
11395 self .attr_name = func
11496
97+ self .may_contain_boolean_expression = (
98+ not self .is_convention
99+ and self .group == CallbackGroup .COND
100+ and self .reference == SpecReference .NAME
101+ )
102+
115103 def __repr__ (self ):
116104 return f"{ type (self ).__name__ } ({ self .func !r} , is_convention={ self .is_convention !r} )"
117105
@@ -132,71 +120,26 @@ def _update_func(self, func: Callable, attr_name: str):
132120 self .reference = SpecReference .CALLABLE
133121 self .attr_name = attr_name
134122
135- def _wrap (self , callback ):
136- condition = self .cond if self .cond is not None else allways_true
137- return CallbackWrapper (
138- callback = callback ,
139- condition = condition ,
140- meta = self ,
141- unique_key = callback .unique_key ,
142- )
143-
144- def build (self , resolver : "Listeners" ) -> Generator ["CallbackWrapper" , None , None ]:
145- """
146- Resolves the `func` into a usable callable.
147-
148- Args:
149- resolver (callable): A method responsible to build and return a valid callable that
150- can receive arbitrary parameters like `*args, **kwargs`.
151- """
152- if (
153- not self .is_convention
154- and self .group == CallbackGroup .COND
155- and self .reference == SpecReference .NAME
156- ):
157- names_not_found : Set [str ] = set ()
158- take_callback_partial = partial (
159- take_callback , resolver = resolver , not_found_handler = names_not_found .add
160- )
161- try :
162- expression = parse_boolean_expr (self .func , take_callback_partial , operator_mapping )
163- except SyntaxError as err :
164- raise InvalidDefinition (
165- _ ("Failed to parse boolean expression '{}'" ).format (self .func )
166- ) from err
167- if not expression or names_not_found :
168- self .names_not_found = names_not_found
169- return
170- yield self ._wrap (expression )
171- return
172-
173- for callback in resolver .search (self ):
174- yield self ._wrap (callback )
175-
176123
177124class SpecListGrouper :
178- def __init__ (
179- self , list : "CallbackSpecList" , group : CallbackGroup , factory = CallbackSpec
180- ) -> None :
125+ def __init__ (self , list : "CallbackSpecList" , group : CallbackGroup ) -> None :
181126 self .list = list
182127 self .group = group
183- self .factory = factory
184128 self .key = group .build_key (list )
185129
186130 def add (self , callbacks , ** kwargs ):
187- self .list .add (callbacks , group = self .group , factory = self . factory , ** kwargs )
131+ self .list .add (callbacks , group = self .group , ** kwargs )
188132 return self
189133
190134 def __call__ (self , callback ):
191- return self .list ._add_unbounded_callback (callback , group = self .group , factory = self . factory )
135+ return self .list ._add_unbounded_callback (callback , group = self .group )
192136
193137 def _add_unbounded_callback (self , func , is_event = False , transitions = None , ** kwargs ):
194138 self .list ._add_unbounded_callback (
195139 func ,
196140 is_event = is_event ,
197141 transitions = transitions ,
198142 group = self .group ,
199- factory = self .factory ,
200143 ** kwargs ,
201144 )
202145
@@ -210,6 +153,7 @@ class CallbackSpecList:
210153 def __init__ (self , factory = CallbackSpec ):
211154 self .items : List [CallbackSpec ] = []
212155 self .conventional_specs = set ()
156+ self ._groupers : Dict [CallbackGroup , SpecListGrouper ] = {}
213157 self .factory = factory
214158
215159 def __repr__ (self ):
@@ -253,15 +197,13 @@ def __iter__(self):
253197 def clear (self ):
254198 self .items = []
255199
256- def grouper (
257- self , group : CallbackGroup , factory : Type [ CallbackSpec ] = CallbackSpec
258- ) -> SpecListGrouper :
259- return SpecListGrouper ( self , group , factory = factory )
200+ def grouper (self , group : CallbackGroup ) -> SpecListGrouper :
201+ if group not in self . _groupers :
202+ self . _groupers [ group ] = SpecListGrouper ( self , group )
203+ return self . _groupers [ group ]
260204
261- def _add (self , func , group : CallbackGroup , factory = None , ** kwargs ):
262- if factory is None :
263- factory = self .factory
264- spec = factory (func , group , ** kwargs )
205+ def _add (self , func , group : CallbackGroup , ** kwargs ):
206+ spec = self .factory (func , group , ** kwargs )
265207
266208 if spec in self .items :
267209 return
@@ -338,19 +280,21 @@ def __repr__(self):
338280 def __str__ (self ):
339281 return ", " .join (str (c ) for c in self )
340282
341- def _add (self , spec : CallbackSpec , resolver : "Listeners" ):
342- for callback in spec .build (resolver ):
343- if callback .unique_key in self .items_already_seen :
344- continue
283+ def add (self , key : str , spec : CallbackSpec , builder : Callable [[], Callable ]):
284+ if key in self .items_already_seen :
285+ return
345286
346- self .items_already_seen .add (callback .unique_key )
347- insort (self .items , callback )
287+ self .items_already_seen .add (key )
348288
349- def add (self , items : Iterable [CallbackSpec ], resolver : "Listeners" ):
350- """Validate configurations"""
351- for item in items :
352- self ._add (item , resolver )
353- return self
289+ condition = spec .cond if spec .cond is not None else allways_true
290+ wrapper = CallbackWrapper (
291+ callback = builder (),
292+ condition = condition ,
293+ meta = spec ,
294+ unique_key = key ,
295+ )
296+
297+ insort (self .items , wrapper )
354298
355299 async def async_call (self , * args , ** kwargs ):
356300 return await asyncio .gather (
@@ -402,7 +346,8 @@ def check(self, specs: CallbackSpecList):
402346 callback for callback in self [meta .group .build_key (specs )] if callback .meta == meta
403347 ):
404348 continue
405- if hasattr (meta , "names_not_found" ):
349+
350+ if meta .names_not_found :
406351 raise AttrNotFound (
407352 _ ("Did not found name '{}' from model or statemachine" ).format (
408353 ", " .join (meta .names_not_found )
0 commit comments