@@ -99,11 +99,18 @@ impl StaticFilter for ArrayStaticFilter {
9999 ) ) ;
100100 }
101101
102+ // Unwrap dictionary-encoded needles when the value type matches
103+ // in_array, evaluating against the dictionary values and mapping
104+ // back via keys.
102105 downcast_dictionary_array ! {
103106 v => {
104- let values_contains = self . contains( v. values( ) . as_ref( ) , negated) ?;
105- let result = take( & values_contains, v. keys( ) , None ) ?;
106- return Ok ( downcast_array( result. as_ref( ) ) )
107+ // Only unwrap when the haystack (in_array) type matches
108+ // the dictionary value type
109+ if v. values( ) . data_type( ) == self . in_array. data_type( ) {
110+ let values_contains = self . contains( v. values( ) . as_ref( ) , negated) ?;
111+ let result = take( & values_contains, v. keys( ) , None ) ?;
112+ return Ok ( downcast_array( result. as_ref( ) ) ) ;
113+ }
107114 }
108115 _ => { }
109116 }
@@ -3724,4 +3731,348 @@ mod tests {
37243731 assert_eq ! ( result, & BooleanArray :: from( vec![ true , false , false ] ) ) ;
37253732 Ok ( ( ) )
37263733 }
/// Checks that `a IN (b, c)` evaluates correctly in the case the
/// short-circuit targets: the first list item already matches every row,
/// so the remaining list items should not be able to change the answer.
#[test]
fn test_in_list_with_columns_short_circuit() -> Result<()> {
    // Three non-nullable Int32 columns: the needle `a` and list items `b`, `c`.
    let int_col = |name: &str| Field::new(name, DataType::Int32, false);
    let schema = Schema::new(vec![int_col("a"), int_col("b"), int_col("c")]);

    // `b` equals `a` on every row, so all rows match before `c` is consulted;
    // `c` never equals `a`.
    let a_values = Int32Array::from(vec![1, 2, 3]);
    let b_values = Int32Array::from(vec![1, 2, 3]);
    let c_values = Int32Array::from(vec![99, 99, 99]);
    let batch = RecordBatch::try_new(
        Arc::new(schema.clone()),
        vec![Arc::new(a_values), Arc::new(b_values), Arc::new(c_values)],
    )?;

    let needle = col("a", &schema)?;
    let candidates = vec![col("b", &schema)?, col("c", &schema)?];
    let in_list = make_in_list_with_columns(needle, candidates, false);

    let evaluated = in_list.evaluate(&batch)?.into_array(batch.num_rows())?;
    let expected = BooleanArray::from(vec![true, true, true]);
    assert_eq!(as_boolean_array(&evaluated), &expected);
    Ok(())
}
3764+
/// Null needles must survive the short-circuit (three-valued logic): a row
/// whose needle is NULL stays NULL even when every non-null row already
/// matched on the first list item.
#[test]
fn test_in_list_with_columns_short_circuit_with_nulls() -> Result<()> {
    // Only the needle column `a` is nullable.
    let schema = Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Int32, false),
        Field::new("c", DataType::Int32, false),
    ]);

    let a_values = Int32Array::from(vec![Some(1), None, Some(3)]);
    // `b` equals `a` wherever `a` is non-null.
    let b_values = Int32Array::from(vec![1, 2, 3]);
    let c_values = Int32Array::from(vec![99, 99, 99]);
    let batch = RecordBatch::try_new(
        Arc::new(schema.clone()),
        vec![Arc::new(a_values), Arc::new(b_values), Arc::new(c_values)],
    )?;

    let needle = col("a", &schema)?;
    let candidates = vec![col("b", &schema)?, col("c", &schema)?];
    let in_list = make_in_list_with_columns(needle, candidates, false);

    let evaluated = in_list.evaluate(&batch)?.into_array(batch.num_rows())?;

    // Expected, row by row:
    //   row 0: 1 IN (1, 99)    -> true
    //   row 1: NULL IN (2, 99) -> NULL
    //   row 2: 3 IN (3, 99)    -> true
    let expected = BooleanArray::from(vec![Some(true), None, Some(true)]);
    assert_eq!(as_boolean_array(&evaluated), &expected);
    Ok(())
}
3800+
/// Exercises `IN` and `NOT IN` over struct-typed column references. Per the
/// original note, nested types don't support arrow_eq, so this covers the
/// make_comparator + collect_bool fallback path.
#[test]
fn test_in_list_with_columns_struct() -> Result<()> {
    let struct_fields = Fields::from(vec![
        Field::new("x", DataType::Int32, false),
        Field::new("y", DataType::Utf8, false),
    ]);
    let struct_dt = DataType::Struct(struct_fields.clone());

    // Only the needle column `a` is nullable.
    let schema = Schema::new(vec![
        Field::new("a", struct_dt.clone(), true),
        Field::new("b", struct_dt.clone(), false),
        Field::new("c", struct_dt.clone(), false),
    ]);

    // Builds one {x, y} struct column; `validity` is the optional per-row
    // validity mask (false marks a NULL struct).
    let build_struct = |xs: Vec<i32>, ys: Vec<&str>, validity: Option<Vec<bool>>| {
        Arc::new(StructArray::new(
            struct_fields.clone(),
            vec![
                Arc::new(Int32Array::from(xs)),
                Arc::new(StringArray::from(ys)),
            ],
            validity.map(|bits| bits.into()),
        ))
    };

    // a: [{1,"a"}, {2,"b"}, NULL,    {4,"d"}]
    // b: [{1,"a"}, {9,"z"}, {3,"c"}, {4,"d"}]
    // c: [{9,"z"}, {2,"b"}, {9,"z"}, {9,"z"}]
    let a = build_struct(
        vec![1, 2, 3, 4],
        vec!["a", "b", "c", "d"],
        Some(vec![true, true, false, true]),
    );
    let b = build_struct(vec![1, 9, 3, 4], vec!["a", "z", "c", "d"], None);
    let c = build_struct(vec![9, 2, 9, 9], vec!["z", "b", "z", "z"], None);

    let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a, b, c])?;

    // Evaluates `a [NOT] IN (b, c)` against the batch.
    let evaluate = |negated: bool| -> Result<BooleanArray> {
        let needle = col("a", &schema)?;
        let candidates = vec![col("b", &schema)?, col("c", &schema)?];
        let in_list = make_in_list_with_columns(needle, candidates, negated);
        let evaluated = in_list.evaluate(&batch)?.into_array(batch.num_rows())?;
        Ok(as_boolean_array(&evaluated).clone())
    };

    // IN:
    //   row 0: {1,"a"} IN ({1,"a"}, {9,"z"}) -> true  (matches b)
    //   row 1: {2,"b"} IN ({9,"z"}, {2,"b"}) -> true  (matches c)
    //   row 2: NULL    IN ({3,"c"}, {9,"z"}) -> NULL
    //   row 3: {4,"d"} IN ({4,"d"}, {9,"z"}) -> true  (matches b)
    assert_eq!(
        evaluate(false)?,
        BooleanArray::from(vec![Some(true), Some(true), None, Some(true)])
    );

    // NOT IN:
    //   row 0: {1,"a"} NOT IN ({1,"a"}, {9,"z"}) -> false
    //   row 1: {2,"b"} NOT IN ({9,"z"}, {2,"b"}) -> false
    //   row 2: NULL    NOT IN ({3,"c"}, {9,"z"}) -> NULL
    //   row 3: {4,"d"} NOT IN ({4,"d"}, {9,"z"}) -> false
    assert_eq!(
        evaluate(true)?,
        BooleanArray::from(vec![Some(false), Some(false), None, Some(false)])
    );
    Ok(())
}
3879+
3880+ // -----------------------------------------------------------------------
3881+ // Tests for try_new_from_array: evaluates `needle IN in_array`.
3882+ //
3883+ // This exercises the code path used by HashJoin dynamic filter pushdown,
3884+ // where in_array is built directly from the join's build-side arrays.
3885+ // Unlike try_new (used by SQL IN expressions), which always produces a
3886+ // non-Dictionary in_array because evaluate_list() flattens Dictionary
3887+ // scalars, try_new_from_array passes the array directly and can produce
3888+ // a Dictionary in_array.
3889+ // -----------------------------------------------------------------------
3890+
3891+ fn wrap_in_dict ( array : ArrayRef ) -> ArrayRef {
3892+ let keys = Int32Array :: from ( ( 0 ..array. len ( ) as i32 ) . collect :: < Vec < _ > > ( ) ) ;
3893+ Arc :: new ( DictionaryArray :: new ( keys, array) )
3894+ }
3895+
3896+ /// Evaluates `needle IN in_array` via try_new_from_array, the same
3897+ /// path used by HashJoin dynamic filter pushdown (not the SQL literal
3898+ /// IN path which goes through try_new).
3899+ fn eval_in_list_from_array (
3900+ needle : ArrayRef ,
3901+ in_array : ArrayRef ,
3902+ ) -> Result < BooleanArray > {
3903+ let schema =
3904+ Schema :: new ( vec ! [ Field :: new( "a" , needle. data_type( ) . clone( ) , false ) ] ) ;
3905+ let col_a = col ( "a" , & schema) ?;
3906+ let expr = Arc :: new ( InListExpr :: try_new_from_array ( col_a, in_array, false ) ?)
3907+ as Arc < dyn PhysicalExpr > ;
3908+ let batch = RecordBatch :: try_new ( Arc :: new ( schema) , vec ! [ needle] ) ?;
3909+ let result = expr. evaluate ( & batch) ?. into_array ( batch. num_rows ( ) ) ?;
3910+ Ok ( as_boolean_array ( & result) . clone ( ) )
3911+ }
3912+
/// Runs `needle IN in_array` via `eval_in_list_from_array` over a matrix of
/// needle / in_array type combinations (primitives, Utf8, dictionaries and
/// structs), all of which must produce the same boolean result.
#[test]
fn test_in_list_from_array_type_combinations() -> Result<()> {
    use arrow::compute::cast;

    // In every case rows 0 and 2 of the needle are present in in_array and
    // row 1 is not.
    let expected = BooleanArray::from(vec![Some(true), Some(false), Some(true)]);

    // Int64 prototypes, cast below to each primitive type under test.
    let base_in = Arc::new(Int64Array::from(vec![1i64, 2, 3])) as ArrayRef;
    let base_needle = Arc::new(Int64Array::from(vec![1i64, 4, 2])) as ArrayRef;

    // Covers all specializations in instantiate_static_filter.
    let primitive_types = [
        DataType::Int8,
        DataType::Int16,
        DataType::Int32,
        DataType::Int64,
        DataType::UInt8,
        DataType::UInt16,
        DataType::UInt32,
        DataType::UInt64,
        DataType::Float32,
        DataType::Float64,
    ];

    for dt in &primitive_types {
        let in_array = cast(&base_in, dt)?;
        let needle = cast(&base_needle, dt)?;

        // T in_array, T needle
        let same_type =
            eval_in_list_from_array(Arc::clone(&needle), Arc::clone(&in_array))?;
        assert_eq!(expected, same_type, "same-type failed for {dt:?}");

        // T in_array, Dict(Int32, T) needle
        let dict_needle = eval_in_list_from_array(wrap_in_dict(needle), in_array)?;
        assert_eq!(expected, dict_needle, "dict-needle failed for {dt:?}");
    }

    // Utf8 (falls through to ArrayStaticFilter)
    let utf8_in = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef;
    let utf8_needle = Arc::new(StringArray::from(vec!["a", "d", "b"])) as ArrayRef;

    // Utf8 in_array, Utf8 needle
    assert_eq!(
        expected,
        eval_in_list_from_array(Arc::clone(&utf8_needle), Arc::clone(&utf8_in))?
    );

    // Utf8 in_array, Dict(Utf8) needle
    assert_eq!(
        expected,
        eval_in_list_from_array(
            wrap_in_dict(Arc::clone(&utf8_needle)),
            Arc::clone(&utf8_in),
        )?
    );

    // Dict(Utf8) in_array, Dict(Utf8) needle: the #20937 bug
    assert_eq!(
        expected,
        eval_in_list_from_array(
            wrap_in_dict(Arc::clone(&utf8_needle)),
            wrap_in_dict(Arc::clone(&utf8_in)),
        )?
    );

    // Pairs the two columns with the given fields to form a struct array.
    let struct_of = |fields: &Fields, c0: ArrayRef, c1: ArrayRef| -> ArrayRef {
        let pairs: Vec<(FieldRef, ArrayRef)> =
            fields.iter().cloned().zip([c0, c1]).collect();
        Arc::new(StructArray::from(pairs))
    };

    // Struct in_array, Struct needle: multi-column join
    let struct_fields = Fields::from(vec![
        Field::new("c0", DataType::Utf8, true),
        Field::new("c1", DataType::Int64, true),
    ]);
    assert_eq!(
        expected,
        eval_in_list_from_array(
            struct_of(
                &struct_fields,
                Arc::clone(&utf8_needle),
                Arc::new(Int64Array::from(vec![1, 4, 2])),
            ),
            struct_of(
                &struct_fields,
                Arc::clone(&utf8_in),
                Arc::new(Int64Array::from(vec![1, 2, 3])),
            ),
        )?
    );

    // Struct with Dict fields: multi-column Dict join
    let dict_struct_fields = Fields::from(vec![
        Field::new(
            "c0",
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
            true,
        ),
        Field::new("c1", DataType::Int64, true),
    ]);
    assert_eq!(
        expected,
        eval_in_list_from_array(
            struct_of(
                &dict_struct_fields,
                wrap_in_dict(Arc::clone(&utf8_needle)),
                Arc::new(Int64Array::from(vec![1, 4, 2])),
            ),
            struct_of(
                &dict_struct_fields,
                wrap_in_dict(Arc::clone(&utf8_in)),
                Arc::new(Int64Array::from(vec![1, 2, 3])),
            ),
        )?
    );

    Ok(())
}
4039+
/// Incompatible needle / in_array type pairings must surface as errors
/// rather than produce a result.
#[test]
fn test_in_list_from_array_type_mismatch_errors() -> Result<()> {
    // Runs the evaluation, requires it to fail, and returns the error text.
    let error_text = |needle: ArrayRef, in_array: ArrayRef| -> String {
        eval_in_list_from_array(needle, in_array)
            .unwrap_err()
            .to_string()
    };

    // Utf8 needle, Dict(Utf8) in_array
    let err = error_text(
        Arc::new(StringArray::from(vec!["a", "d", "b"])),
        wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))),
    );
    assert!(
        err.contains("Can't compare arrays of different types"),
        "{err}"
    );

    // Dict(Utf8) needle, Int64 in_array: the Int64-specialized filter is
    // expected to reject the Utf8 dictionary values with a downcast error.
    let err = error_text(
        wrap_in_dict(Arc::new(StringArray::from(vec!["a", "d", "b"]))),
        Arc::new(Int64Array::from(vec![1, 2, 3])),
    );
    assert!(err.contains("Failed to downcast"), "{err}");

    // Dict(Int64) needle, Dict(Utf8) in_array: both are dictionaries but
    // with different value types, so the comparison itself is rejected.
    let err = error_text(
        wrap_in_dict(Arc::new(Int64Array::from(vec![1, 4, 2]))),
        wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))),
    );
    assert!(
        err.contains("Can't compare arrays of different types"),
        "{err}"
    );
    Ok(())
}
37274078}
0 commit comments