Skip to content

Commit 01df975

Browse files
[branch-53] fix: InList Dictionary filter pushdown type mismatch (apache#20962) (apache#20996)
- Part of apache#19692 - Closes apache#20996 on branch-53 This PR: - Backports apache#20962 from @erratic-pattern to the branch-53 line - Backports the related tests from apache#20960 Co-authored-by: Adam Curtis <adam.curtis.dev@gmail.com>
1 parent 5746048 commit 01df975

File tree

2 files changed

+405
-3
lines changed

2 files changed

+405
-3
lines changed

datafusion/physical-expr/src/expressions/in_list.rs

Lines changed: 354 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,18 @@ impl StaticFilter for ArrayStaticFilter {
9999
));
100100
}
101101

102+
// Unwrap dictionary-encoded needles when the value type matches
103+
// in_array, evaluating against the dictionary values and mapping
104+
// back via keys.
102105
downcast_dictionary_array! {
103106
v => {
104-
let values_contains = self.contains(v.values().as_ref(), negated)?;
105-
let result = take(&values_contains, v.keys(), None)?;
106-
return Ok(downcast_array(result.as_ref()))
107+
// Only unwrap when the haystack (in_array) type matches
108+
// the dictionary value type
109+
if v.values().data_type() == self.in_array.data_type() {
110+
let values_contains = self.contains(v.values().as_ref(), negated)?;
111+
let result = take(&values_contains, v.keys(), None)?;
112+
return Ok(downcast_array(result.as_ref()));
113+
}
107114
}
108115
_ => {}
109116
}
@@ -3724,4 +3731,348 @@ mod tests {
37243731
assert_eq!(result, &BooleanArray::from(vec![true, false, false]));
37253732
Ok(())
37263733
}
3734+
/// Tests that short-circuit evaluation produces correct results.
3735+
/// When all rows match after the first list item, remaining items
3736+
/// should be skipped without affecting correctness.
3737+
#[test]
3738+
fn test_in_list_with_columns_short_circuit() -> Result<()> {
3739+
// a IN (b, c) where b already matches every row of a
3740+
// The short-circuit should skip evaluating c
3741+
let schema = Schema::new(vec![
3742+
Field::new("a", DataType::Int32, false),
3743+
Field::new("b", DataType::Int32, false),
3744+
Field::new("c", DataType::Int32, false),
3745+
]);
3746+
let batch = RecordBatch::try_new(
3747+
Arc::new(schema.clone()),
3748+
vec![
3749+
Arc::new(Int32Array::from(vec![1, 2, 3])),
3750+
Arc::new(Int32Array::from(vec![1, 2, 3])), // b == a for all rows
3751+
Arc::new(Int32Array::from(vec![99, 99, 99])),
3752+
],
3753+
)?;
3754+
3755+
let col_a = col("a", &schema)?;
3756+
let list = vec![col("b", &schema)?, col("c", &schema)?];
3757+
let expr = make_in_list_with_columns(col_a, list, false);
3758+
3759+
let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
3760+
let result = as_boolean_array(&result);
3761+
assert_eq!(result, &BooleanArray::from(vec![true, true, true]));
3762+
Ok(())
3763+
}
3764+
3765+
/// Short-circuit must NOT skip when nulls are present (three-valued logic).
3766+
/// Even if all non-null values are true, null rows keep the result as null.
3767+
#[test]
3768+
fn test_in_list_with_columns_short_circuit_with_nulls() -> Result<()> {
3769+
// a IN (b, c) where a has nulls
3770+
// Even if b matches all non-null rows, result should preserve nulls
3771+
let schema = Schema::new(vec![
3772+
Field::new("a", DataType::Int32, true),
3773+
Field::new("b", DataType::Int32, false),
3774+
Field::new("c", DataType::Int32, false),
3775+
]);
3776+
let batch = RecordBatch::try_new(
3777+
Arc::new(schema.clone()),
3778+
vec![
3779+
Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])),
3780+
Arc::new(Int32Array::from(vec![1, 2, 3])), // matches non-null rows
3781+
Arc::new(Int32Array::from(vec![99, 99, 99])),
3782+
],
3783+
)?;
3784+
3785+
let col_a = col("a", &schema)?;
3786+
let list = vec![col("b", &schema)?, col("c", &schema)?];
3787+
let expr = make_in_list_with_columns(col_a, list, false);
3788+
3789+
let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
3790+
let result = as_boolean_array(&result);
3791+
// row 0: 1 IN (1, 99) → true
3792+
// row 1: NULL IN (2, 99) → NULL
3793+
// row 2: 3 IN (3, 99) → true
3794+
assert_eq!(
3795+
result,
3796+
&BooleanArray::from(vec![Some(true), None, Some(true)])
3797+
);
3798+
Ok(())
3799+
}
3800+
3801+
/// Tests the make_comparator + collect_bool fallback path using
3802+
/// struct column references (nested types don't support arrow_eq).
3803+
#[test]
3804+
fn test_in_list_with_columns_struct() -> Result<()> {
3805+
let struct_fields = Fields::from(vec![
3806+
Field::new("x", DataType::Int32, false),
3807+
Field::new("y", DataType::Utf8, false),
3808+
]);
3809+
let struct_dt = DataType::Struct(struct_fields.clone());
3810+
3811+
let schema = Schema::new(vec![
3812+
Field::new("a", struct_dt.clone(), true),
3813+
Field::new("b", struct_dt.clone(), false),
3814+
Field::new("c", struct_dt.clone(), false),
3815+
]);
3816+
3817+
// a: [{1,"a"}, {2,"b"}, NULL, {4,"d"}]
3818+
// b: [{1,"a"}, {9,"z"}, {3,"c"}, {4,"d"}]
3819+
// c: [{9,"z"}, {2,"b"}, {9,"z"}, {9,"z"}]
3820+
let a = Arc::new(StructArray::new(
3821+
struct_fields.clone(),
3822+
vec![
3823+
Arc::new(Int32Array::from(vec![1, 2, 3, 4])),
3824+
Arc::new(StringArray::from(vec!["a", "b", "c", "d"])),
3825+
],
3826+
Some(vec![true, true, false, true].into()),
3827+
));
3828+
let b = Arc::new(StructArray::new(
3829+
struct_fields.clone(),
3830+
vec![
3831+
Arc::new(Int32Array::from(vec![1, 9, 3, 4])),
3832+
Arc::new(StringArray::from(vec!["a", "z", "c", "d"])),
3833+
],
3834+
None,
3835+
));
3836+
let c = Arc::new(StructArray::new(
3837+
struct_fields.clone(),
3838+
vec![
3839+
Arc::new(Int32Array::from(vec![9, 2, 9, 9])),
3840+
Arc::new(StringArray::from(vec!["z", "b", "z", "z"])),
3841+
],
3842+
None,
3843+
));
3844+
3845+
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a, b, c])?;
3846+
3847+
let col_a = col("a", &schema)?;
3848+
let list = vec![col("b", &schema)?, col("c", &schema)?];
3849+
let expr = make_in_list_with_columns(col_a, list, false);
3850+
3851+
let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
3852+
let result = as_boolean_array(&result);
3853+
// row 0: {1,"a"} IN ({1,"a"}, {9,"z"}) → true (matches b)
3854+
// row 1: {2,"b"} IN ({9,"z"}, {2,"b"}) → true (matches c)
3855+
// row 2: NULL IN ({3,"c"}, {9,"z"}) → NULL
3856+
// row 3: {4,"d"} IN ({4,"d"}, {9,"z"}) → true (matches b)
3857+
assert_eq!(
3858+
result,
3859+
&BooleanArray::from(vec![Some(true), Some(true), None, Some(true)])
3860+
);
3861+
3862+
// Also test NOT IN
3863+
let col_a = col("a", &schema)?;
3864+
let list = vec![col("b", &schema)?, col("c", &schema)?];
3865+
let expr = make_in_list_with_columns(col_a, list, true);
3866+
3867+
let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
3868+
let result = as_boolean_array(&result);
3869+
// row 0: {1,"a"} NOT IN ({1,"a"}, {9,"z"}) → false
3870+
// row 1: {2,"b"} NOT IN ({9,"z"}, {2,"b"}) → false
3871+
// row 2: NULL NOT IN ({3,"c"}, {9,"z"}) → NULL
3872+
// row 3: {4,"d"} NOT IN ({4,"d"}, {9,"z"}) → false
3873+
assert_eq!(
3874+
result,
3875+
&BooleanArray::from(vec![Some(false), Some(false), None, Some(false)])
3876+
);
3877+
Ok(())
3878+
}
3879+
3880+
// -----------------------------------------------------------------------
3881+
// Tests for try_new_from_array: evaluates `needle IN in_array`.
3882+
//
3883+
// This exercises the code path used by HashJoin dynamic filter pushdown,
3884+
// where in_array is built directly from the join's build-side arrays.
3885+
// Unlike try_new (used by SQL IN expressions), which always produces a
3886+
// non-Dictionary in_array because evaluate_list() flattens Dictionary
3887+
// scalars, try_new_from_array passes the array directly and can produce
3888+
// a Dictionary in_array.
3889+
// -----------------------------------------------------------------------
3890+
3891+
fn wrap_in_dict(array: ArrayRef) -> ArrayRef {
3892+
let keys = Int32Array::from((0..array.len() as i32).collect::<Vec<_>>());
3893+
Arc::new(DictionaryArray::new(keys, array))
3894+
}
3895+
3896+
/// Evaluates `needle IN in_array` via try_new_from_array, the same
3897+
/// path used by HashJoin dynamic filter pushdown (not the SQL literal
3898+
/// IN path which goes through try_new).
3899+
fn eval_in_list_from_array(
3900+
needle: ArrayRef,
3901+
in_array: ArrayRef,
3902+
) -> Result<BooleanArray> {
3903+
let schema =
3904+
Schema::new(vec![Field::new("a", needle.data_type().clone(), false)]);
3905+
let col_a = col("a", &schema)?;
3906+
let expr = Arc::new(InListExpr::try_new_from_array(col_a, in_array, false)?)
3907+
as Arc<dyn PhysicalExpr>;
3908+
let batch = RecordBatch::try_new(Arc::new(schema), vec![needle])?;
3909+
let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
3910+
Ok(as_boolean_array(&result).clone())
3911+
}
3912+
3913+
#[test]
3914+
fn test_in_list_from_array_type_combinations() -> Result<()> {
3915+
use arrow::compute::cast;
3916+
3917+
// All cases: needle[0] and needle[2] match, needle[1] does not.
3918+
let expected = BooleanArray::from(vec![Some(true), Some(false), Some(true)]);
3919+
3920+
// Base arrays cast to each target type
3921+
let base_in = Arc::new(Int64Array::from(vec![1i64, 2, 3])) as ArrayRef;
3922+
let base_needle = Arc::new(Int64Array::from(vec![1i64, 4, 2])) as ArrayRef;
3923+
3924+
// Test all specializations in instantiate_static_filter
3925+
let primitive_types = vec![
3926+
DataType::Int8,
3927+
DataType::Int16,
3928+
DataType::Int32,
3929+
DataType::Int64,
3930+
DataType::UInt8,
3931+
DataType::UInt16,
3932+
DataType::UInt32,
3933+
DataType::UInt64,
3934+
DataType::Float32,
3935+
DataType::Float64,
3936+
];
3937+
3938+
for dt in &primitive_types {
3939+
let in_array = cast(&base_in, dt)?;
3940+
let needle = cast(&base_needle, dt)?;
3941+
3942+
// T in_array, T needle
3943+
assert_eq!(
3944+
expected,
3945+
eval_in_list_from_array(Arc::clone(&needle), Arc::clone(&in_array))?,
3946+
"same-type failed for {dt:?}"
3947+
);
3948+
3949+
// T in_array, Dict(Int32, T) needle
3950+
assert_eq!(
3951+
expected,
3952+
eval_in_list_from_array(wrap_in_dict(needle), in_array)?,
3953+
"dict-needle failed for {dt:?}"
3954+
);
3955+
}
3956+
3957+
// Utf8 (falls through to ArrayStaticFilter)
3958+
let utf8_in = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef;
3959+
let utf8_needle = Arc::new(StringArray::from(vec!["a", "d", "b"])) as ArrayRef;
3960+
3961+
// Utf8 in_array, Utf8 needle
3962+
assert_eq!(
3963+
expected,
3964+
eval_in_list_from_array(Arc::clone(&utf8_needle), Arc::clone(&utf8_in),)?
3965+
);
3966+
3967+
// Utf8 in_array, Dict(Utf8) needle
3968+
assert_eq!(
3969+
expected,
3970+
eval_in_list_from_array(
3971+
wrap_in_dict(Arc::clone(&utf8_needle)),
3972+
Arc::clone(&utf8_in),
3973+
)?
3974+
);
3975+
3976+
// Dict(Utf8) in_array, Dict(Utf8) needle: the #20937 bug
3977+
assert_eq!(
3978+
expected,
3979+
eval_in_list_from_array(
3980+
wrap_in_dict(Arc::clone(&utf8_needle)),
3981+
wrap_in_dict(Arc::clone(&utf8_in)),
3982+
)?
3983+
);
3984+
3985+
// Struct in_array, Struct needle: multi-column join
3986+
let struct_fields = Fields::from(vec![
3987+
Field::new("c0", DataType::Utf8, true),
3988+
Field::new("c1", DataType::Int64, true),
3989+
]);
3990+
let make_struct = |c0: ArrayRef, c1: ArrayRef| -> ArrayRef {
3991+
let pairs: Vec<(FieldRef, ArrayRef)> =
3992+
struct_fields.iter().cloned().zip([c0, c1]).collect();
3993+
Arc::new(StructArray::from(pairs))
3994+
};
3995+
assert_eq!(
3996+
expected,
3997+
eval_in_list_from_array(
3998+
make_struct(
3999+
Arc::clone(&utf8_needle),
4000+
Arc::new(Int64Array::from(vec![1, 4, 2])),
4001+
),
4002+
make_struct(
4003+
Arc::clone(&utf8_in),
4004+
Arc::new(Int64Array::from(vec![1, 2, 3])),
4005+
),
4006+
)?
4007+
);
4008+
4009+
// Struct with Dict fields: multi-column Dict join
4010+
let dict_struct_fields = Fields::from(vec![
4011+
Field::new(
4012+
"c0",
4013+
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
4014+
true,
4015+
),
4016+
Field::new("c1", DataType::Int64, true),
4017+
]);
4018+
let make_dict_struct = |c0: ArrayRef, c1: ArrayRef| -> ArrayRef {
4019+
let pairs: Vec<(FieldRef, ArrayRef)> =
4020+
dict_struct_fields.iter().cloned().zip([c0, c1]).collect();
4021+
Arc::new(StructArray::from(pairs))
4022+
};
4023+
assert_eq!(
4024+
expected,
4025+
eval_in_list_from_array(
4026+
make_dict_struct(
4027+
wrap_in_dict(Arc::clone(&utf8_needle)),
4028+
Arc::new(Int64Array::from(vec![1, 4, 2])),
4029+
),
4030+
make_dict_struct(
4031+
wrap_in_dict(Arc::clone(&utf8_in)),
4032+
Arc::new(Int64Array::from(vec![1, 2, 3])),
4033+
),
4034+
)?
4035+
);
4036+
4037+
Ok(())
4038+
}
4039+
4040+
#[test]
4041+
fn test_in_list_from_array_type_mismatch_errors() -> Result<()> {
4042+
// Utf8 needle, Dict(Utf8) in_array
4043+
let err = eval_in_list_from_array(
4044+
Arc::new(StringArray::from(vec!["a", "d", "b"])),
4045+
wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))),
4046+
)
4047+
.unwrap_err()
4048+
.to_string();
4049+
assert!(
4050+
err.contains("Can't compare arrays of different types"),
4051+
"{err}"
4052+
);
4053+
4054+
// Dict(Utf8) needle, Int64 in_array: specialized Int64StaticFilter
4055+
// rejects the Utf8 dictionary values at construction time
4056+
let err = eval_in_list_from_array(
4057+
wrap_in_dict(Arc::new(StringArray::from(vec!["a", "d", "b"]))),
4058+
Arc::new(Int64Array::from(vec![1, 2, 3])),
4059+
)
4060+
.unwrap_err()
4061+
.to_string();
4062+
assert!(err.contains("Failed to downcast"), "{err}");
4063+
4064+
// Dict(Int64) needle, Dict(Utf8) in_array: both Dict but different
4065+
// value types, make_comparator rejects the comparison
4066+
let err = eval_in_list_from_array(
4067+
wrap_in_dict(Arc::new(Int64Array::from(vec![1, 4, 2]))),
4068+
wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))),
4069+
)
4070+
.unwrap_err()
4071+
.to_string();
4072+
assert!(
4073+
err.contains("Can't compare arrays of different types"),
4074+
"{err}"
4075+
);
4076+
Ok(())
4077+
}
37274078
}

0 commit comments

Comments
 (0)