@@ -26,8 +26,9 @@ use datafusion_expr::test::function_stub::{
2626 count_udaf, max, max_udaf, min_udaf, sum, sum_udaf,
2727} ;
2828use datafusion_expr:: {
29- EmptyRelation , Expr , Extension , LogicalPlan , LogicalPlanBuilder , Union ,
30- UserDefinedLogicalNode , UserDefinedLogicalNodeCore , WindowFrame ,
29+ ColumnarValue , EmptyRelation , Expr , Extension , LogicalPlan , LogicalPlanBuilder ,
30+ ScalarFunctionArgs , ScalarUDF , ScalarUDFImpl , Signature , Union ,
31+ UserDefinedLogicalNode , UserDefinedLogicalNodeCore , Volatility , WindowFrame ,
3132 WindowFunctionDefinition , cast, col, lit, table_scan, wildcard,
3233} ;
3334use datafusion_functions:: unicode;
@@ -3032,7 +3033,7 @@ fn snowflake_unnest_to_lateral_flatten_cross_join_inline() -> Result<(), DataFus
30323033 sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) JOIN j1 ON u.c1 = j1.j1_id" ,
30333034 parser_dialect: GenericDialect { } ,
30343035 unparser_dialect: snowflake,
3035- expected: @r#"SELECT "u"."c1", "j1"."j1_id", "j1"."j1_string" FROM (SELECT "_unnest"."VALUE" AS "c1" FROM LATERAL FLATTEN(INPUT => [1, 2, 3]) AS _unnest ) AS "u" INNER JOIN "j1" ON ("u"."c1" = "j1"."j1_id")"# ,
3036+ expected: @r#"SELECT "u"."c1", "j1"."j1_id", "j1"."j1_string" FROM LATERAL FLATTEN(INPUT => [1, 2, 3]) AS "u" INNER JOIN "j1" ON ("u"."c1" = "j1"."j1_id")"# ,
30363037 ) ;
30373038 Ok ( ( ) )
30383039}
@@ -3072,7 +3073,7 @@ fn snowflake_flatten_select_unnest_with_alias() -> Result<(), DataFusionError> {
30723073 sql: "SELECT UNNEST([1,2,3]) as c1" ,
30733074 parser_dialect: GenericDialect { } ,
30743075 unparser_dialect: snowflake,
3075- expected: @r#"SELECT " _unnest" ."VALUE" AS "c1 " FROM LATERAL FLATTEN(INPUT => [1, 2, 3]) AS _unnest"# ,
3076+ expected: @r#"SELECT _unnest."VALUE" FROM LATERAL FLATTEN(INPUT => [1, 2, 3]) AS _unnest"# ,
30763077 ) ;
30773078 Ok ( ( ) )
30783079}
@@ -3096,7 +3097,89 @@ fn snowflake_flatten_from_unnest_with_table_alias() -> Result<(), DataFusionErro
30963097 sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)" ,
30973098 parser_dialect: GenericDialect { } ,
30983099 unparser_dialect: snowflake,
3099- expected: @r#"SELECT "t1"."c1" FROM (SELECT "_unnest"."VALUE" AS "c1" FROM LATERAL FLATTEN(INPUT => [1, 2, 3]) AS _unnest) AS "t1""# ,
3100+ expected: @r#"SELECT "t1"."c1" FROM LATERAL FLATTEN(INPUT => [1, 2, 3]) AS "t1""# ,
3101+ ) ;
3102+ Ok ( ( ) )
3103+ }
3104+
3105+ #[ test]
3106+ fn snowflake_flatten_unnest_from_subselect ( ) -> Result < ( ) , DataFusionError > {
3107+ // UNNEST operating on an array column produced by a subselect.
3108+ // Uses unnest_table which has array_col (List<Int64>).
3109+ // The filter uses array_col IS NOT NULL — a simple predicate
3110+ // that doesn't involve struct types (which Snowflake FLATTEN can't handle).
3111+ let snowflake = SnowflakeDialect :: new ( ) ;
3112+ roundtrip_statement_with_dialect_helper ! (
3113+ sql: "SELECT UNNEST(array_col) FROM (SELECT array_col FROM unnest_table WHERE array_col IS NOT NULL LIMIT 3)" ,
3114+ parser_dialect: GenericDialect { } ,
3115+ unparser_dialect: snowflake,
3116+ expected: @r#"SELECT _unnest."VALUE" FROM (SELECT "unnest_table"."array_col" FROM "unnest_table" WHERE "unnest_table"."array_col" IS NOT NULL LIMIT 3) CROSS JOIN LATERAL FLATTEN(INPUT => "unnest_table"."array_col") AS _unnest"# ,
31003117 ) ;
31013118 Ok ( ( ) )
31023119}
3120+
3121+ /// Dummy scalar UDF for testing — takes a string and returns List<Int64>.
3122+ #[ derive( Debug , PartialEq , Eq , Hash ) ]
3123+ struct JsonGetArrayUdf {
3124+ signature : Signature ,
3125+ }
3126+
3127+ impl JsonGetArrayUdf {
3128+ fn new ( ) -> Self {
3129+ Self {
3130+ signature : Signature :: exact ( vec ! [ DataType :: Utf8 ] , Volatility :: Immutable ) ,
3131+ }
3132+ }
3133+ }
3134+
3135+ impl ScalarUDFImpl for JsonGetArrayUdf {
3136+ fn name ( & self ) -> & str {
3137+ "json_get_array"
3138+ }
3139+ fn signature ( & self ) -> & Signature {
3140+ & self . signature
3141+ }
3142+ fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType > {
3143+ Ok ( DataType :: List ( Arc :: new ( Field :: new_list_field (
3144+ DataType :: Int64 ,
3145+ true ,
3146+ ) ) ) )
3147+ }
3148+ fn invoke_with_args ( & self , _args : ScalarFunctionArgs ) -> Result < ColumnarValue > {
3149+ unimplemented ! ( "test stub" )
3150+ }
3151+ }
3152+
3153+ #[ test]
3154+ fn snowflake_flatten_unnest_udf_result ( ) -> Result < ( ) , DataFusionError > {
3155+ // UNNEST on a UDF result: json_get_array(col) returns List<Int64>,
3156+ // then UNNEST flattens it. This simulates a common Snowflake pattern
3157+ // where a UDF parses JSON into an array, then FLATTEN expands it.
3158+ let sql = "SELECT UNNEST(json_get_array(j1_string)) AS items FROM j1 LIMIT 5" ;
3159+
3160+ let statement = Parser :: new ( & GenericDialect { } )
3161+ . try_with_sql ( sql) ?
3162+ . parse_statement ( ) ?;
3163+
3164+ let state = MockSessionState :: default ( )
3165+ . with_aggregate_function ( max_udaf ( ) )
3166+ . with_aggregate_function ( min_udaf ( ) )
3167+ . with_scalar_function ( Arc :: new ( ScalarUDF :: new_from_impl ( JsonGetArrayUdf :: new ( ) ) ) )
3168+ . with_expr_planner ( Arc :: new ( CoreFunctionPlanner :: default ( ) ) )
3169+ . with_expr_planner ( Arc :: new ( NestedFunctionPlanner ) )
3170+ . with_expr_planner ( Arc :: new ( FieldAccessPlanner ) ) ;
3171+
3172+ let context = MockContextProvider { state } ;
3173+ let sql_to_rel = SqlToRel :: new ( & context) ;
3174+ let plan = sql_to_rel
3175+ . sql_statement_to_plan ( statement)
3176+ . unwrap_or_else ( |e| panic ! ( "Failed to parse sql: {sql}\n {e}" ) ) ;
3177+
3178+ let snowflake = SnowflakeDialect :: new ( ) ;
3179+ let unparser = Unparser :: new ( & snowflake) ;
3180+ let result = unparser. plan_to_sql ( & plan) ?;
3181+ let actual = result. to_string ( ) ;
3182+
3183+ insta:: assert_snapshot!( actual, @r#"SELECT _unnest."VALUE" FROM "j1" CROSS JOIN LATERAL FLATTEN(INPUT => json_get_array("j1"."j1_string")) AS _unnest LIMIT 5"# ) ;
3184+ Ok ( ( ) )
3185+ }
0 commit comments