|
22 | 22 | Lazy, |
23 | 23 | ModelRegistry, |
24 | 24 | ) |
25 | | -from ccflow.evaluators import GraphEvaluator |
| 25 | +from ccflow.callable import FlowOptions |
| 26 | +from ccflow.evaluators import GraphEvaluator, MemoryCacheEvaluator |
26 | 27 |
|
27 | 28 |
|
28 | 29 | class SimpleContext(ContextBase): |
@@ -996,3 +997,84 @@ class StrListContext(ContextBase): |
996 | 997 | @Flow.model(context_type=StrListContext) |
997 | 998 | def bad(vals: FromContext[list[int]]) -> int: |
998 | 999 | return sum(vals) |
| 1000 | + |
| 1001 | + |
| 1002 | +def test_compute_forwards_options_with_custom_evaluator(): |
| 1003 | + calls = {"count": 0} |
| 1004 | + |
| 1005 | + @Flow.model |
| 1006 | + def counter(value: FromContext[int]) -> int: |
| 1007 | + calls["count"] += 1 |
| 1008 | + return value |
| 1009 | + |
| 1010 | + cache = MemoryCacheEvaluator() |
| 1011 | + model = counter() |
| 1012 | + |
| 1013 | + result1 = model.flow.compute(value=10, _options=FlowOptions(evaluator=cache, cacheable=True)) |
| 1014 | + result2 = model.flow.compute(value=10, _options=FlowOptions(evaluator=cache, cacheable=True)) |
| 1015 | + |
| 1016 | + assert result1.value == 10 |
| 1017 | + assert result2.value == 10 |
| 1018 | + assert calls["count"] == 1 |
| 1019 | + |
| 1020 | + |
| 1021 | +def test_compute_forwards_options_with_graph_evaluator(): |
| 1022 | + @Flow.model |
| 1023 | + def source(value: FromContext[int]) -> int: |
| 1024 | + return value * 10 |
| 1025 | + |
| 1026 | + @Flow.model |
| 1027 | + def root(x: int, bonus: FromContext[int]) -> int: |
| 1028 | + return x + bonus |
| 1029 | + |
| 1030 | + model = root(x=source()) |
| 1031 | + |
| 1032 | + # GraphEvaluator evaluates in topo order; verify _options flows through |
| 1033 | + # and the graph evaluator is actually used (doesn't raise CycleError, computes correctly) |
| 1034 | + result = model.flow.compute( |
| 1035 | + FlowContext(value=3, bonus=7), |
| 1036 | + _options=FlowOptions(evaluator=GraphEvaluator()), |
| 1037 | + ) |
| 1038 | + |
| 1039 | + assert result.value == 37 |
| 1040 | + |
| 1041 | + |
| 1042 | +def test_compute_forwards_options_through_bound_model(): |
| 1043 | + calls = {"count": 0} |
| 1044 | + |
| 1045 | + @Flow.model |
| 1046 | + def add(a: int, b: FromContext[int]) -> int: |
| 1047 | + calls["count"] += 1 |
| 1048 | + return a + b |
| 1049 | + |
| 1050 | + cache = MemoryCacheEvaluator() |
| 1051 | + bound = add(a=10).flow.with_inputs(b=5) |
| 1052 | + |
| 1053 | + result1 = bound.flow.compute(_options=FlowOptions(evaluator=cache, cacheable=True)) |
| 1054 | + result2 = bound.flow.compute(_options=FlowOptions(evaluator=cache, cacheable=True)) |
| 1055 | + |
| 1056 | + assert result1.value == 15 |
| 1057 | + assert result2.value == 15 |
| 1058 | + assert calls["count"] == 1 |
| 1059 | + |
| 1060 | + |
| 1061 | +def test_compute_forwards_options_for_plain_callable_model(): |
| 1062 | + calls = {"count": 0} |
| 1063 | + |
| 1064 | + class Counter(CallableModel): |
| 1065 | + offset: int |
| 1066 | + |
| 1067 | + @Flow.call |
| 1068 | + def __call__(self, context: SimpleContext) -> GenericResult[int]: |
| 1069 | + calls["count"] += 1 |
| 1070 | + return GenericResult(value=context.value + self.offset) |
| 1071 | + |
| 1072 | + cache = MemoryCacheEvaluator() |
| 1073 | + model = Counter(offset=5) |
| 1074 | + |
| 1075 | + result1 = model.flow.compute(value=10, _options=FlowOptions(evaluator=cache, cacheable=True)) |
| 1076 | + result2 = model.flow.compute(value=10, _options=FlowOptions(evaluator=cache, cacheable=True)) |
| 1077 | + |
| 1078 | + assert result1.value == 15 |
| 1079 | + assert result2.value == 15 |
| 1080 | + assert calls["count"] == 1 |
0 commit comments