Skip to content

Commit f3886fe

Browse files
author
Nijat Khanbabayev
committed
Allow .compute to take _options
Signed-off-by: Nijat Khanbabayev <nijat.khanbabayev@cubistsystematic.com>
1 parent 5e32c85 commit f3886fe

2 files changed

Lines changed: 87 additions & 5 deletions

File tree

ccflow/flow_model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from pydantic.errors import PydanticSchemaGenerationError, PydanticUndefinedAnnotation
2929

3030
from .base import BaseModel, ContextBase, ResultBase
31-
from .callable import CallableModel, Flow, GraphDepList, WrapperModel
31+
from .callable import CallableModel, Flow, FlowOptions, GraphDepList, WrapperModel
3232
from .context import FlowContext
3333
from .exttypes import PyObjectPath
3434
from .local_persistence import register_ccflow_import_path
@@ -913,20 +913,20 @@ def __init__(self, model: CallableModel):
913913
def _compute_target(self) -> CallableModel:
914914
return self._model
915915

916-
def compute(self, context: Any = _UNSET, /, **kwargs) -> Any:
916+
def compute(self, context: Any = _UNSET, /, _options: Optional[FlowOptions] = None, **kwargs) -> Any:
917917
target = self._compute_target
918918
generated = _generated_model_instance(target)
919919
if generated is not None:
920920
built_context = _build_generated_compute_context(generated, context, kwargs)
921-
return _maybe_auto_unwrap_external_result(target, target(built_context))
921+
return _maybe_auto_unwrap_external_result(target, target(built_context, _options=_options))
922922

923923
if context is not _UNSET and kwargs:
924924
raise TypeError("compute() accepts either one context object or contextual keyword arguments, but not both.")
925925
if context is _UNSET:
926926
built_context = target.context_type.model_validate(kwargs)
927927
else:
928928
built_context = context if isinstance(context, ContextBase) else target.context_type.model_validate(context)
929-
return _maybe_auto_unwrap_external_result(target, target(built_context))
929+
return _maybe_auto_unwrap_external_result(target, target(built_context, _options=_options))
930930

931931
@property
932932
def context_inputs(self) -> Dict[str, Any]:

ccflow/tests/test_flow_model.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
Lazy,
2323
ModelRegistry,
2424
)
25-
from ccflow.evaluators import GraphEvaluator
25+
from ccflow.callable import FlowOptions
26+
from ccflow.evaluators import GraphEvaluator, MemoryCacheEvaluator
2627

2728

2829
class SimpleContext(ContextBase):
@@ -996,3 +997,84 @@ class StrListContext(ContextBase):
996997
@Flow.model(context_type=StrListContext)
997998
def bad(vals: FromContext[list[int]]) -> int:
998999
return sum(vals)
1000+
1001+
1002+
def test_compute_forwards_options_with_custom_evaluator():
1003+
calls = {"count": 0}
1004+
1005+
@Flow.model
1006+
def counter(value: FromContext[int]) -> int:
1007+
calls["count"] += 1
1008+
return value
1009+
1010+
cache = MemoryCacheEvaluator()
1011+
model = counter()
1012+
1013+
result1 = model.flow.compute(value=10, _options=FlowOptions(evaluator=cache, cacheable=True))
1014+
result2 = model.flow.compute(value=10, _options=FlowOptions(evaluator=cache, cacheable=True))
1015+
1016+
assert result1.value == 10
1017+
assert result2.value == 10
1018+
assert calls["count"] == 1
1019+
1020+
1021+
def test_compute_forwards_options_with_graph_evaluator():
1022+
@Flow.model
1023+
def source(value: FromContext[int]) -> int:
1024+
return value * 10
1025+
1026+
@Flow.model
1027+
def root(x: int, bonus: FromContext[int]) -> int:
1028+
return x + bonus
1029+
1030+
model = root(x=source())
1031+
1032+
# GraphEvaluator evaluates in topo order; verify _options flows through
1033+
# and the graph evaluator is actually used (doesn't raise CycleError, computes correctly)
1034+
result = model.flow.compute(
1035+
FlowContext(value=3, bonus=7),
1036+
_options=FlowOptions(evaluator=GraphEvaluator()),
1037+
)
1038+
1039+
assert result.value == 37
1040+
1041+
1042+
def test_compute_forwards_options_through_bound_model():
1043+
calls = {"count": 0}
1044+
1045+
@Flow.model
1046+
def add(a: int, b: FromContext[int]) -> int:
1047+
calls["count"] += 1
1048+
return a + b
1049+
1050+
cache = MemoryCacheEvaluator()
1051+
bound = add(a=10).flow.with_inputs(b=5)
1052+
1053+
result1 = bound.flow.compute(_options=FlowOptions(evaluator=cache, cacheable=True))
1054+
result2 = bound.flow.compute(_options=FlowOptions(evaluator=cache, cacheable=True))
1055+
1056+
assert result1.value == 15
1057+
assert result2.value == 15
1058+
assert calls["count"] == 1
1059+
1060+
1061+
def test_compute_forwards_options_for_plain_callable_model():
1062+
calls = {"count": 0}
1063+
1064+
class Counter(CallableModel):
1065+
offset: int
1066+
1067+
@Flow.call
1068+
def __call__(self, context: SimpleContext) -> GenericResult[int]:
1069+
calls["count"] += 1
1070+
return GenericResult(value=context.value + self.offset)
1071+
1072+
cache = MemoryCacheEvaluator()
1073+
model = Counter(offset=5)
1074+
1075+
result1 = model.flow.compute(value=10, _options=FlowOptions(evaluator=cache, cacheable=True))
1076+
result2 = model.flow.compute(value=10, _options=FlowOptions(evaluator=cache, cacheable=True))
1077+
1078+
assert result1.value == 15
1079+
assert result2.value == 15
1080+
assert calls["count"] == 1

0 commit comments

Comments
 (0)