Skip to content

Commit 72a1207

Browse files
l46kokcopybara-github
authored andcommitted
Prevent non-foldable functions from being folded in comprehensions
PiperOrigin-RevId: 865997662
1 parent 026cae6 commit 72a1207

2 files changed

Lines changed: 12 additions & 5 deletions

File tree

optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ private boolean canFold(CelNavigableMutableExpr navigableExpr) {
201201
CelNavigableMutableExpr operand = navigableExpr.children().collect(onlyElement());
202202
return areChildrenArgConstant(operand);
203203
case COMPREHENSION:
204-
return !isNestedComprehension(navigableExpr);
204+
return !isNestedComprehension(navigableExpr) && containsFoldableFunctionOnly(navigableExpr);
205205
default:
206206
return false;
207207
}

optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,19 +47,24 @@
4747
@RunWith(TestParameterInjector.class)
4848
public class ConstantFoldingOptimizerTest {
4949
private static final CelOptions CEL_OPTIONS =
50-
CelOptions.current()
51-
.enableTimestampEpoch(true)
52-
.build();
50+
CelOptions.current().populateMacroCalls(true).enableTimestampEpoch(true).build();
5351
private static final Cel CEL =
5452
CelFactory.standardCelBuilder()
5553
.addVar("x", SimpleType.DYN)
5654
.addVar("y", SimpleType.DYN)
5755
.addVar("list_var", ListType.create(SimpleType.STRING))
5856
.addVar("map_var", MapType.create(SimpleType.STRING, SimpleType.STRING))
57+
.setStandardMacros(CelStandardMacro.STANDARD_MACROS)
5958
.addFunctionDeclarations(
6059
CelFunctionDecl.newFunctionDeclaration(
6160
"get_true",
62-
CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)))
61+
CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)),
62+
CelFunctionDecl.newFunctionDeclaration(
63+
"get_list",
64+
CelOverloadDecl.newGlobalOverload(
65+
"get_list_overload",
66+
ListType.create(SimpleType.INT),
67+
ListType.create(SimpleType.INT))))
6368
.addFunctionBindings(
6469
CelFunctionBinding.from("get_true_overload", ImmutableList.of(), unused -> true))
6570
.addMessageTypes(TestAllTypes.getDescriptor())
@@ -371,6 +376,8 @@ public void constantFold_macros_withoutMacroCallMetadata(String source) throws E
371376
@TestParameters("{source: 'x == 42'}")
372377
@TestParameters("{source: 'timestamp(100)'}")
373378
@TestParameters("{source: 'duration(\"1h\")'}")
379+
@TestParameters("{source: '[true].exists(x, x == get_true())'}")
380+
@TestParameters("{source: 'get_list([1, 2]).map(x, x * 2)'}")
374381
public void constantFold_noOp(String source) throws Exception {
375382
CelAbstractSyntaxTree ast = CEL.compile(source).getAst();
376383

0 commit comments

Comments
 (0)