1616import logging
1717from functools import lru_cache , wraps
1818from inspect import Signature , isclass , signature
19- from typing import TYPE_CHECKING , Any , Callable , ClassVar , Dict , Generic , List , Optional , Tuple , Type , TypeVar , Union , cast , get_args , get_origin
19+ from typing import (
20+ TYPE_CHECKING ,
21+ Annotated ,
22+ Any ,
23+ Callable ,
24+ ClassVar ,
25+ Dict ,
26+ Generic ,
27+ List ,
28+ Optional ,
29+ Tuple ,
30+ Type ,
31+ TypeVar ,
32+ Union ,
33+ cast ,
34+ get_args ,
35+ get_origin ,
36+ )
2037
2138from pydantic import BaseModel as PydanticBaseModel , ConfigDict , Field , InstanceOf , PrivateAttr , TypeAdapter , field_validator , model_validator
2239from typing_extensions import override
@@ -65,11 +82,11 @@ def _cached_signature(fn):
6582 return signature (fn )
6683
6784
68- def _callable_qualname (fn : Callable [..., Any ]) -> str :
69- return getattr (fn , "__qualname__" , type (fn ).__qualname__ )
70-
71-
7285def _declared_type_matches (actual : Any , expected : Any ) -> bool :
86+ while get_origin (actual ) is Annotated :
87+ actual = get_args (actual )[0 ]
88+ while get_origin (expected ) is Annotated :
89+ expected = get_args (expected )[0 ]
7390 if isinstance (expected , TypeVar ):
7491 return True
7592 if get_origin (expected ) is Union :
@@ -293,25 +310,13 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] =
293310 if not isinstance (model , CallableModel ):
294311 raise TypeError (f"Can only decorate methods on CallableModels (not { type (model )} ) with the flow decorator." )
295312
296- # Check if this is an auto_context decorated method
297- has_auto_context = hasattr (fn , "__auto_context__" )
298- if has_auto_context :
299- method_context_type = fn .__auto_context__
300- else :
301- method_context_type = model .context_type
302-
303- # Validate context type (skip for auto contexts which are always valid ContextBase subclasses)
304- if not has_auto_context :
305- if (not isclass (model .context_type ) or not issubclass (model .context_type , ContextBase )) and not (
306- get_origin (model .context_type ) is Union and type (None ) in get_args (model .context_type )
307- ):
308- raise TypeError (f"Context type { model .context_type } must be a subclass of ContextBase" )
309-
310- # Validate result type - use __result_type__ for auto contexts if available
311- if has_auto_context and hasattr (fn , "__result_type__" ):
312- method_result_type = fn .__result_type__
313- else :
314- method_result_type = model .result_type
313+ method_context_type = getattr (fn , "__auto_context__" , model .context_type )
314+ method_result_type = getattr (fn , "__result_type__" , model .result_type )
315+
316+ if (not isclass (method_context_type ) or not issubclass (method_context_type , ContextBase )) and not (
317+ get_origin (method_context_type ) is Union and type (None ) in get_args (method_context_type )
318+ ):
319+ raise TypeError (f"Context type { method_context_type } must be a subclass of ContextBase" )
315320 if (not isclass (method_result_type ) or not issubclass (method_result_type , ResultBase )) and not (
316321 get_origin (method_result_type ) is Union and all (isclass (t ) and issubclass (t , ResultBase ) for t in get_args (method_result_type ))
317322 ):
@@ -332,12 +337,13 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] =
332337 raise TypeError (f"{ fn .__name__ } () was passed a context and got an unexpected keyword argument '{ next (iter (kwargs .keys ()))} '" )
333338
334339 # Type coercion on input. We do this here (rather than relying on ModelEvaluationContext) as it produces a nicer traceback/error message
335- if not isinstance (context , method_context_type ):
336- if get_origin (method_context_type ) is Union and type (None ) in get_args (method_context_type ):
337- coerce_context_type = [t for t in get_args (method_context_type ) if t is not type (None )][0 ]
338- else :
339- coerce_context_type = method_context_type
340- context = coerce_context_type .model_validate (context )
340+ if get_origin (method_context_type ) is Union and type (None ) in get_args (method_context_type ):
341+ if context is not None :
342+ method_context_type = [t for t in get_args (method_context_type ) if t is not type (None )][0 ]
343+ if not isinstance (context , method_context_type ):
344+ context = method_context_type .model_validate (context )
345+ elif not isinstance (context , method_context_type ):
346+ context = method_context_type .model_validate (context )
341347
342348 if fn != getattr (model .__class__ , fn .__name__ ).__wrapped__ :
343349 # This happens when super().__call__ is used when implementing a CallableModel that derives from another one.
@@ -356,7 +362,6 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] =
356362 wrap_any .get_options = self .get_options
357363 wrap_any .get_evaluation_context = get_evaluation_context
358364
359- # Preserve auto context attributes for introspection
360365 if hasattr (fn , "__auto_context__" ):
361366 wrap_any .__auto_context__ = fn .__auto_context__
362367 if hasattr (fn , "__result_type__" ):
@@ -476,19 +481,12 @@ def __call__(self, *, date: date, extra: int = 0) -> MyResult:
476481 # with infrastructure expecting DateContext instances.
477482
478483 """
479- # Extract auto_context option (not part of FlowOptions)
480- # Can be: False, True, or a ContextBase subclass
481484 auto_context = kwargs .pop ("auto_context" , False )
482-
483- # Determine if auto_context is enabled and extract parent class if provided
484485 if auto_context is False :
485- auto_context_enabled = False
486486 context_parent = None
487487 elif auto_context is True :
488- auto_context_enabled = True
489- context_parent = None
488+ context_parent = ContextBase
490489 elif isclass (auto_context ) and issubclass (auto_context , ContextBase ):
491- auto_context_enabled = True
492490 context_parent = auto_context
493491 else :
494492 raise TypeError (f"auto_context must be False, True, or a ContextBase subclass, got { auto_context !r} " )
@@ -501,7 +499,7 @@ def __call__(self, *, date: date, extra: int = 0) -> MyResult:
501499 else :
502500 # Arguments to decorator, this is just returning the decorator
503501 # Note that the code below is executed only once
504- if auto_context_enabled :
502+ if context_parent is not None :
505503 # Return a decorator that first applies auto_context, then FlowOptions
506504 def auto_context_decorator (fn : Callable [..., Any ]) -> Callable [..., Any ]:
507505 wrapped = _apply_auto_context (fn , parent = context_parent )
@@ -589,6 +587,13 @@ def load_prices(
589587
590588 return flow_model (* args , ** kwargs )
591589
590+ @staticmethod
591+ def transform (* args , ** kwargs ):
592+ """Decorator that turns a top-level function into a serializable with_inputs() transform factory."""
593+ from .flow_model import flow_transform
594+
595+ return flow_transform (* args , ** kwargs )
596+
592597
593598# *****************************************************************************
594599# Define "Evaluators" and associated types
@@ -965,31 +970,27 @@ def __call__(self, *, x: int, y: str = "default") -> GenericResult:
965970 model = MyCallable()
966971 model(x=42, y="hello") # Works with kwargs
967972 """
968- sig = signature (func )
973+ from .flow_model import _callable_qualname , _resolved_flow_signature
974+
975+ sig = _resolved_flow_signature (
976+ func ,
977+ skip_self = True ,
978+ require_return_annotation = True ,
979+ annotation_error_suffix = " when auto_context=True" ,
980+ return_error_suffix = " when auto_context=True" ,
981+ function_name = _callable_qualname (func ),
982+ )
969983 base_class = parent or ContextBase
970984
971- if sig .return_annotation is inspect .Signature .empty :
972- raise TypeError (f"Function { _callable_qualname (func )} must have a return type annotation when auto_context=True" )
973-
974985 # Validate parent fields are in function signature
975986 if parent is not None :
976987 parent_fields = set (parent .model_fields .keys ()) - set (ContextBase .model_fields .keys ())
977- sig_params = set (sig .parameters . keys ()) - { "self" }
988+ sig_params = set (sig .parameters )
978989 missing = parent_fields - sig_params
979990 if missing :
980991 raise TypeError (f"Parent context fields { missing } must be included in function signature" )
981992
982- # Build fields from parameters (skip 'self'), pydantic validates types
983- fields = {}
984- for name , param in sig .parameters .items ():
985- if name == "self" :
986- continue
987- if param .kind in (inspect .Parameter .VAR_POSITIONAL , inspect .Parameter .VAR_KEYWORD ):
988- raise TypeError (f"Function { _callable_qualname (func )} does not support { param .kind .description } when auto_context=True" )
989- if param .annotation is inspect .Parameter .empty :
990- raise TypeError (f"Parameter '{ name } ' must have a type annotation when auto_context=True" )
991- default = ... if param .default is inspect .Parameter .empty else param .default
992- fields [name ] = (param .annotation , default )
993+ fields = {name : (param .annotation , ... if param .default is inspect .Parameter .empty else param .default ) for name , param in sig .parameters .items ()}
993994
994995 # Create auto context class
995996 auto_context_class = create_ccflow_model (f"{ _callable_qualname (func )} _AutoContext" , __base__ = base_class , ** fields )
0 commit comments