Skip to content

Commit b3ee6a6

Browse files
l46kokcopybara-github
authored andcommitted
Fix wrapper types to properly unwrap in lists
PiperOrigin-RevId: 868327828
1 parent e36c49f commit b3ee6a6

4 files changed

Lines changed: 128 additions & 9 deletions

File tree

common/src/main/java/dev/cel/common/internal/ProtoAdapter.java

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,9 @@ public Optional<Object> adaptFieldToValue(FieldDescriptor fieldDescriptor, Objec
192192
if (bidiConverter == BidiConverter.IDENTITY) {
193193
return Optional.of(fieldValue);
194194
}
195-
return Optional.of(AdaptingTypes.adaptingList((List<?>) fieldValue, bidiConverter));
195+
ArrayList<?> convertedList =
196+
new ArrayList<>(AdaptingTypes.adaptingList((List<?>) fieldValue, bidiConverter));
197+
return Optional.of(convertedList);
196198
}
197199

198200
return Optional.of(
@@ -244,28 +246,48 @@ private BidiConverter fieldToValueConverter(FieldDescriptor fieldDescriptor) {
244246
case SFIXED32:
245247
case SINT32:
246248
case INT32:
247-
return INT_CONVERTER;
249+
return unwrapAndConvert(INT_CONVERTER);
248250
case FIXED32:
249251
case UINT32:
250252
if (celOptions.enableUnsignedLongs()) {
251-
return UNSIGNED_UINT32_CONVERTER;
253+
return unwrapAndConvert(UNSIGNED_UINT32_CONVERTER);
252254
}
253-
return SIGNED_UINT32_CONVERTER;
255+
return unwrapAndConvert(SIGNED_UINT32_CONVERTER);
254256
case FIXED64:
255257
case UINT64:
256258
if (celOptions.enableUnsignedLongs()) {
257-
return UNSIGNED_UINT64_CONVERTER;
259+
return unwrapAndConvert(UNSIGNED_UINT64_CONVERTER);
258260
}
259-
return BidiConverter.IDENTITY;
261+
return BidiConverter.of(
262+
BidiConverter.IDENTITY.forwardConverter(),
263+
value -> BidiConverter.IDENTITY.backwardConverter().convert(maybeUnwrap(value)));
260264
case FLOAT:
261-
return DOUBLE_CONVERTER;
265+
return unwrapAndConvert(DOUBLE_CONVERTER);
266+
case DOUBLE:
267+
case SFIXED64:
268+
case SINT64:
269+
case INT64:
270+
return BidiConverter.of(
271+
BidiConverter.IDENTITY.forwardConverter(),
272+
value -> BidiConverter.IDENTITY.backwardConverter().convert(maybeUnwrap(value)));
262273
case BYTES:
263274
if (celOptions.evaluateCanonicalTypesToNativeValues()) {
264275
return BidiConverter.<Object, Object>of(
265-
ProtoAdapter::adaptProtoByteStringToValue, ProtoAdapter::adaptCelByteStringToProto);
276+
ProtoAdapter::adaptProtoByteStringToValue,
277+
value -> adaptCelByteStringToProto(maybeUnwrap(value)));
266278
}
267279

268-
return BidiConverter.IDENTITY;
280+
return BidiConverter.of(
281+
BidiConverter.IDENTITY.forwardConverter(),
282+
value -> BidiConverter.IDENTITY.backwardConverter().convert(maybeUnwrap(value)));
283+
case STRING:
284+
return BidiConverter.of(
285+
BidiConverter.IDENTITY.forwardConverter(),
286+
value -> BidiConverter.IDENTITY.backwardConverter().convert(maybeUnwrap(value)));
287+
case BOOL:
288+
return BidiConverter.of(
289+
BidiConverter.IDENTITY.forwardConverter(),
290+
value -> BidiConverter.IDENTITY.backwardConverter().convert(maybeUnwrap(value)));
269291
case ENUM:
270292
return BidiConverter.<Object, Long>of(
271293
value -> (long) ((EnumValueDescriptor) value).getNumber(),
@@ -371,4 +393,18 @@ private static int unsignedIntCheckedCast(long value) {
371393
throw new CelNumericOverflowException(e);
372394
}
373395
}
396+
397+
private Object maybeUnwrap(Object value) {
398+
if (value instanceof Message) {
399+
return adaptProtoToValue((MessageOrBuilder) value);
400+
}
401+
return value;
402+
}
403+
404+
private BidiConverter<Number, Object> unwrapAndConvert(
405+
final BidiConverter<Number, Number> original) {
406+
return BidiConverter.of(
407+
original.forwardConverter()::convert,
408+
value -> original.backwardConverter().convert((Number) maybeUnwrap(value)));
409+
}
374410
}

runtime/src/test/java/dev/cel/runtime/PlannerInterpreterTest.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,10 @@ public void jsonFieldNames() throws Exception {
6363
// TODO: Support JSON field names for planner
6464
skipBaselineVerification();
6565
}
66+
67+
@Override
68+
public void wrappers() throws Exception {
69+
// TODO: Fix for planner
70+
skipBaselineVerification();
71+
}
6672
}

runtime/src/test/resources/wrappers.baseline

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,47 @@ declare dyn_var {
154154
bindings: {dyn_var=NULL_VALUE}
155155
result: NULL_VALUE
156156

157+
Source: TestAllTypes{repeated_int32: int32_list}.repeated_int32 == [1] && TestAllTypes{repeated_int64: int64_list}.repeated_int64 == [2] && TestAllTypes{repeated_uint32: uint32_list}.repeated_uint32 == [3u] && TestAllTypes{repeated_uint64: uint64_list}.repeated_uint64 == [4u] && TestAllTypes{repeated_float: float_list}.repeated_float == [5.5] && TestAllTypes{repeated_double: double_list}.repeated_double == [6.6] && TestAllTypes{repeated_bool: bool_list}.repeated_bool == [true] && TestAllTypes{repeated_string: string_list}.repeated_string == ['hello'] && TestAllTypes{repeated_bytes: bytes_list}.repeated_bytes == [b'world']
158+
declare int32_list {
159+
value list(int)
160+
}
161+
declare int64_list {
162+
value list(int)
163+
}
164+
declare uint32_list {
165+
value list(uint)
166+
}
167+
declare uint64_list {
168+
value list(uint)
169+
}
170+
declare float_list {
171+
value list(double)
172+
}
173+
declare double_list {
174+
value list(double)
175+
}
176+
declare bool_list {
177+
value list(bool)
178+
}
179+
declare string_list {
180+
value list(string)
181+
}
182+
declare bytes_list {
183+
value list(bytes)
184+
}
185+
=====>
186+
bindings: {int32_list=[value: 1
187+
], int64_list=[value: 2
188+
], uint32_list=[value: 3
189+
], uint64_list=[value: 4
190+
], float_list=[value: 5.5
191+
], double_list=[value: 6.6
192+
], bool_list=[value: true
193+
], string_list=[value: "hello"
194+
], bytes_list=[value: "world"
195+
]}
196+
result: true
197+
157198
Source: google.protobuf.Timestamp{ seconds: 253402300800 }
158199
=====>
159200
bindings: {}

testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2067,6 +2067,42 @@ public void wrappers() throws Exception {
20672067
source = "dyn_var";
20682068
runTest(ImmutableMap.of("dyn_var", NullValue.NULL_VALUE));
20692069

2070+
clearAllDeclarations();
2071+
declareVariable("int32_list", ListType.create(SimpleType.INT));
2072+
declareVariable("int64_list", ListType.create(SimpleType.INT));
2073+
declareVariable("uint32_list", ListType.create(SimpleType.UINT));
2074+
declareVariable("uint64_list", ListType.create(SimpleType.UINT));
2075+
declareVariable("float_list", ListType.create(SimpleType.DOUBLE));
2076+
declareVariable("double_list", ListType.create(SimpleType.DOUBLE));
2077+
declareVariable("bool_list", ListType.create(SimpleType.BOOL));
2078+
declareVariable("string_list", ListType.create(SimpleType.STRING));
2079+
declareVariable("bytes_list", ListType.create(SimpleType.BYTES));
2080+
2081+
container = CelContainer.ofName(TestAllTypes.getDescriptor().getFullName());
2082+
source =
2083+
"TestAllTypes{repeated_int32: int32_list}.repeated_int32 == [1] && "
2084+
+ "TestAllTypes{repeated_int64: int64_list}.repeated_int64 == [2] && "
2085+
+ "TestAllTypes{repeated_uint32: uint32_list}.repeated_uint32 == [3u] && "
2086+
+ "TestAllTypes{repeated_uint64: uint64_list}.repeated_uint64 == [4u] && "
2087+
+ "TestAllTypes{repeated_float: float_list}.repeated_float == [5.5] && "
2088+
+ "TestAllTypes{repeated_double: double_list}.repeated_double == [6.6] && "
2089+
+ "TestAllTypes{repeated_bool: bool_list}.repeated_bool == [true] && "
2090+
+ "TestAllTypes{repeated_string: string_list}.repeated_string == ['hello'] && "
2091+
+ "TestAllTypes{repeated_bytes: bytes_list}.repeated_bytes == [b'world']";
2092+
2093+
runTest(
2094+
ImmutableMap.<String, Object>builder()
2095+
.put("int32_list", ImmutableList.of(Int32Value.of(1)))
2096+
.put("int64_list", ImmutableList.of(Int64Value.of(2)))
2097+
.put("uint32_list", ImmutableList.of(UInt32Value.of(3)))
2098+
.put("uint64_list", ImmutableList.of(UInt64Value.of(4)))
2099+
.put("float_list", ImmutableList.of(FloatValue.of(5.5f)))
2100+
.put("double_list", ImmutableList.of(DoubleValue.of(6.6)))
2101+
.put("bool_list", ImmutableList.of(BoolValue.of(true)))
2102+
.put("string_list", ImmutableList.of(StringValue.of("hello")))
2103+
.put("bytes_list", ImmutableList.of(BytesValue.of(ByteString.copyFromUtf8("world"))))
2104+
.buildOrThrow());
2105+
20702106
clearAllDeclarations();
20712107
// Currently allowed, but will be an error
20722108
// See https://github.com/google/cel-spec/pull/501

0 commit comments

Comments
 (0)