@@ -925,51 +925,66 @@ def test_window_frame_defaults_match_postgres(partitioned_df):
925925 assert df_2 .sort (col_a ).to_pydict () == expected
926926
927927
928- def test_last_value_order_by_string_equivalence (partitioned_df ):
929- expr = f .last_value (column ("a" )).over (
930- Window (
931- partition_by = [column ("c" )],
932- order_by = [column ("b" )],
933- window_frame = WindowFrame ("rows" , None , None ),
928+ def _build_last_value_df (df ):
929+ return df .select (
930+ f .last_value (column ("a" ))
931+ .over (
932+ Window (
933+ partition_by = [column ("c" )],
934+ order_by = [column ("b" )],
935+ window_frame = WindowFrame ("rows" , None , None ),
936+ )
934937 )
935- )
936- string = f .last_value (column ("a" )).over (
937- Window (
938- partition_by = [column ("c" )],
939- order_by = "b" ,
940- window_frame = WindowFrame ("rows" , None , None ),
938+ .alias ("expr" ),
939+ f .last_value (column ("a" ))
940+ .over (
941+ Window (
942+ partition_by = [column ("c" )],
943+ order_by = "b" ,
944+ window_frame = WindowFrame ("rows" , None , None ),
945+ )
941946 )
947+ .alias ("str" ),
942948 )
943- df = partitioned_df .select (expr .alias ("expr" ), string .alias ("str" ))
944- table = pa .Table .from_batches (df .collect ())
945- assert table .column ("expr" ).to_pylist () == table .column ("str" ).to_pylist ()
946949
947950
948- def test_nth_value_order_by_string_equivalence (partitioned_df ):
949- expr = f .nth_value (column ("b" ), 3 ).over (Window (order_by = [column ("a" )]))
950- string = f .nth_value (column ("b" ), 3 ).over (Window (order_by = "a" ))
951- df = partitioned_df .select (expr .alias ("expr" ), string .alias ("str" ))
952- table = pa .Table .from_batches (df .collect ())
953- assert table .column ("expr" ).to_pylist () == table .column ("str" ).to_pylist ()
951+ def _build_nth_value_df (df ):
952+ return df .select (
953+ f .nth_value (column ("b" ), 3 ).over (Window (order_by = [column ("a" )])).alias ("expr" ),
954+ f .nth_value (column ("b" ), 3 ).over (Window (order_by = "a" )).alias ("str" ),
955+ )
954956
955957
956- def test_rank_order_by_string_equivalence (partitioned_df ):
957- expr = f .rank (order_by = [column ("b" )])
958- string = f .rank (order_by = "b" )
959- df = partitioned_df .select (expr .alias ("expr" ), string .alias ("str" ))
960- table = pa .Table .from_batches (df .collect ())
961- assert table .column ("expr" ).to_pylist () == table .column ("str" ).to_pylist ()
958+ def _build_rank_df (df ):
959+ return df .select (
960+ f .rank (order_by = [column ("b" )]).alias ("expr" ),
961+ f .rank (order_by = "b" ).alias ("str" ),
962+ )
962963
963964
964- def test_array_agg_order_by_string_equivalence ( partitioned_df ):
965- df = partitioned_df .aggregate (
965+ def _build_array_agg_df ( df ):
966+ return df .aggregate (
966967 [column ("c" )],
967968 [
968969 f .array_agg (column ("a" ), order_by = [column ("a" )]).alias ("expr" ),
969970 f .array_agg (column ("a" ), order_by = "a" ).alias ("str" ),
970971 ],
971972 ).sort (column ("c" ))
973+
974+
975+ @pytest .mark .parametrize (
976+ ("builder" , "expected" ),
977+ [
978+ pytest .param (_build_last_value_df , [3 , 3 , 3 , 3 , 6 , 6 , 6 ], id = "last_value" ),
979+ pytest .param (_build_nth_value_df , [None , None , 7 , 7 , 7 , 7 , 7 ], id = "nth_value" ),
980+ pytest .param (_build_rank_df , [1 , 1 , 3 , 3 , 5 , 6 , 6 ], id = "rank" ),
981+ pytest .param (_build_array_agg_df , [[0 , 1 , 2 , 3 ], [4 , 5 , 6 ]], id = "array_agg" ),
982+ ],
983+ )
984+ def test_order_by_string_equivalence (partitioned_df , builder , expected ):
985+ df = builder (partitioned_df )
972986 table = pa .Table .from_batches (df .collect ())
987+ assert table .column ("expr" ).to_pylist () == expected
973988 assert table .column ("expr" ).to_pylist () == table .column ("str" ).to_pylist ()
974989
975990
0 commit comments