Skip to content

Commit d22ffc4

Browse files
author
Nijat Khanbabayev
committed
Simplify Flow.model code, update docs
Signed-off-by: Nijat Khanbabayev <nijat.khanbabayev@cubistsystematic.com>
1 parent f3886fe commit d22ffc4

File tree

8 files changed

+389
-514
lines changed

8 files changed

+389
-514
lines changed

ccflow/callable.py

Lines changed: 57 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,24 @@
1616
import logging
1717
from functools import lru_cache, wraps
1818
from inspect import Signature, isclass, signature
19-
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, cast, get_args, get_origin
19+
from typing import (
20+
TYPE_CHECKING,
21+
Annotated,
22+
Any,
23+
Callable,
24+
ClassVar,
25+
Dict,
26+
Generic,
27+
List,
28+
Optional,
29+
Tuple,
30+
Type,
31+
TypeVar,
32+
Union,
33+
cast,
34+
get_args,
35+
get_origin,
36+
)
2037

2138
from pydantic import BaseModel as PydanticBaseModel, ConfigDict, Field, InstanceOf, PrivateAttr, TypeAdapter, field_validator, model_validator
2239
from typing_extensions import override
@@ -65,11 +82,11 @@ def _cached_signature(fn):
6582
return signature(fn)
6683

6784

68-
def _callable_qualname(fn: Callable[..., Any]) -> str:
69-
return getattr(fn, "__qualname__", type(fn).__qualname__)
70-
71-
7285
def _declared_type_matches(actual: Any, expected: Any) -> bool:
86+
while get_origin(actual) is Annotated:
87+
actual = get_args(actual)[0]
88+
while get_origin(expected) is Annotated:
89+
expected = get_args(expected)[0]
7390
if isinstance(expected, TypeVar):
7491
return True
7592
if get_origin(expected) is Union:
@@ -293,25 +310,13 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] =
293310
if not isinstance(model, CallableModel):
294311
raise TypeError(f"Can only decorate methods on CallableModels (not {type(model)}) with the flow decorator.")
295312

296-
# Check if this is an auto_context decorated method
297-
has_auto_context = hasattr(fn, "__auto_context__")
298-
if has_auto_context:
299-
method_context_type = fn.__auto_context__
300-
else:
301-
method_context_type = model.context_type
302-
303-
# Validate context type (skip for auto contexts which are always valid ContextBase subclasses)
304-
if not has_auto_context:
305-
if (not isclass(model.context_type) or not issubclass(model.context_type, ContextBase)) and not (
306-
get_origin(model.context_type) is Union and type(None) in get_args(model.context_type)
307-
):
308-
raise TypeError(f"Context type {model.context_type} must be a subclass of ContextBase")
309-
310-
# Validate result type - use __result_type__ for auto contexts if available
311-
if has_auto_context and hasattr(fn, "__result_type__"):
312-
method_result_type = fn.__result_type__
313-
else:
314-
method_result_type = model.result_type
313+
method_context_type = getattr(fn, "__auto_context__", model.context_type)
314+
method_result_type = getattr(fn, "__result_type__", model.result_type)
315+
316+
if (not isclass(method_context_type) or not issubclass(method_context_type, ContextBase)) and not (
317+
get_origin(method_context_type) is Union and type(None) in get_args(method_context_type)
318+
):
319+
raise TypeError(f"Context type {method_context_type} must be a subclass of ContextBase")
315320
if (not isclass(method_result_type) or not issubclass(method_result_type, ResultBase)) and not (
316321
get_origin(method_result_type) is Union and all(isclass(t) and issubclass(t, ResultBase) for t in get_args(method_result_type))
317322
):
@@ -332,12 +337,13 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] =
332337
raise TypeError(f"{fn.__name__}() was passed a context and got an unexpected keyword argument '{next(iter(kwargs.keys()))}'")
333338

334339
# Type coercion on input. We do this here (rather than relying on ModelEvaluationContext) as it produces a nicer traceback/error message
335-
if not isinstance(context, method_context_type):
336-
if get_origin(method_context_type) is Union and type(None) in get_args(method_context_type):
337-
coerce_context_type = [t for t in get_args(method_context_type) if t is not type(None)][0]
338-
else:
339-
coerce_context_type = method_context_type
340-
context = coerce_context_type.model_validate(context)
340+
if get_origin(method_context_type) is Union and type(None) in get_args(method_context_type):
341+
if context is not None:
342+
method_context_type = [t for t in get_args(method_context_type) if t is not type(None)][0]
343+
if not isinstance(context, method_context_type):
344+
context = method_context_type.model_validate(context)
345+
elif not isinstance(context, method_context_type):
346+
context = method_context_type.model_validate(context)
341347

342348
if fn != getattr(model.__class__, fn.__name__).__wrapped__:
343349
# This happens when super().__call__ is used when implementing a CallableModel that derives from another one.
@@ -356,7 +362,6 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] =
356362
wrap_any.get_options = self.get_options
357363
wrap_any.get_evaluation_context = get_evaluation_context
358364

359-
# Preserve auto context attributes for introspection
360365
if hasattr(fn, "__auto_context__"):
361366
wrap_any.__auto_context__ = fn.__auto_context__
362367
if hasattr(fn, "__result_type__"):
@@ -476,19 +481,12 @@ def __call__(self, *, date: date, extra: int = 0) -> MyResult:
476481
# with infrastructure expecting DateContext instances.
477482
478483
"""
479-
# Extract auto_context option (not part of FlowOptions)
480-
# Can be: False, True, or a ContextBase subclass
481484
auto_context = kwargs.pop("auto_context", False)
482-
483-
# Determine if auto_context is enabled and extract parent class if provided
484485
if auto_context is False:
485-
auto_context_enabled = False
486486
context_parent = None
487487
elif auto_context is True:
488-
auto_context_enabled = True
489-
context_parent = None
488+
context_parent = ContextBase
490489
elif isclass(auto_context) and issubclass(auto_context, ContextBase):
491-
auto_context_enabled = True
492490
context_parent = auto_context
493491
else:
494492
raise TypeError(f"auto_context must be False, True, or a ContextBase subclass, got {auto_context!r}")
@@ -501,7 +499,7 @@ def __call__(self, *, date: date, extra: int = 0) -> MyResult:
501499
else:
502500
# Arguments to decorator, this is just returning the decorator
503501
# Note that the code below is executed only once
504-
if auto_context_enabled:
502+
if context_parent is not None:
505503
# Return a decorator that first applies auto_context, then FlowOptions
506504
def auto_context_decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
507505
wrapped = _apply_auto_context(fn, parent=context_parent)
@@ -589,6 +587,13 @@ def load_prices(
589587

590588
return flow_model(*args, **kwargs)
591589

590+
@staticmethod
591+
def transform(*args, **kwargs):
592+
"""Decorator that turns a top-level function into a serializable with_inputs() transform factory."""
593+
from .flow_model import flow_transform
594+
595+
return flow_transform(*args, **kwargs)
596+
592597

593598
# *****************************************************************************
594599
# Define "Evaluators" and associated types
@@ -965,31 +970,27 @@ def __call__(self, *, x: int, y: str = "default") -> GenericResult:
965970
model = MyCallable()
966971
model(x=42, y="hello") # Works with kwargs
967972
"""
968-
sig = signature(func)
973+
from .flow_model import _callable_qualname, _resolved_flow_signature
974+
975+
sig = _resolved_flow_signature(
976+
func,
977+
skip_self=True,
978+
require_return_annotation=True,
979+
annotation_error_suffix=" when auto_context=True",
980+
return_error_suffix=" when auto_context=True",
981+
function_name=_callable_qualname(func),
982+
)
969983
base_class = parent or ContextBase
970984

971-
if sig.return_annotation is inspect.Signature.empty:
972-
raise TypeError(f"Function {_callable_qualname(func)} must have a return type annotation when auto_context=True")
973-
974985
# Validate parent fields are in function signature
975986
if parent is not None:
976987
parent_fields = set(parent.model_fields.keys()) - set(ContextBase.model_fields.keys())
977-
sig_params = set(sig.parameters.keys()) - {"self"}
988+
sig_params = set(sig.parameters)
978989
missing = parent_fields - sig_params
979990
if missing:
980991
raise TypeError(f"Parent context fields {missing} must be included in function signature")
981992

982-
# Build fields from parameters (skip 'self'), pydantic validates types
983-
fields = {}
984-
for name, param in sig.parameters.items():
985-
if name == "self":
986-
continue
987-
if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
988-
raise TypeError(f"Function {_callable_qualname(func)} does not support {param.kind.description} when auto_context=True")
989-
if param.annotation is inspect.Parameter.empty:
990-
raise TypeError(f"Parameter '{name}' must have a type annotation when auto_context=True")
991-
default = ... if param.default is inspect.Parameter.empty else param.default
992-
fields[name] = (param.annotation, default)
993+
fields = {name: (param.annotation, ... if param.default is inspect.Parameter.empty else param.default) for name, param in sig.parameters.items()}
993994

994995
# Create auto context class
995996
auto_context_class = create_ccflow_model(f"{_callable_qualname(func)}_AutoContext", __base__=base_class, **fields)

0 commit comments

Comments
 (0)