@@ -925,66 +925,51 @@ def test_window_frame_defaults_match_postgres(partitioned_df):
925925 assert df_2 .sort (col_a ).to_pydict () == expected
926926
927927
928- def _build_last_value_df (df ):
929- return df .select (
930- f .last_value (column ("a" ))
931- .over (
932- Window (
933- partition_by = [column ("c" )],
934- order_by = [column ("b" )],
935- window_frame = WindowFrame ("rows" , None , None ),
936- )
928+ def test_last_value_order_by_string_equivalence (partitioned_df ):
929+ expr = f .last_value (column ("a" )).over (
930+ Window (
931+ partition_by = [column ("c" )],
932+ order_by = [column ("b" )],
933+ window_frame = WindowFrame ("rows" , None , None ),
937934 )
938- .alias ("expr" ),
939- f .last_value (column ("a" ))
940- .over (
941- Window (
942- partition_by = [column ("c" )],
943- order_by = "b" ,
944- window_frame = WindowFrame ("rows" , None , None ),
945- )
935+ )
936+ string = f .last_value (column ("a" )).over (
937+ Window (
938+ partition_by = [column ("c" )],
939+ order_by = "b" ,
940+ window_frame = WindowFrame ("rows" , None , None ),
946941 )
947- .alias ("str" ),
948942 )
943+ df = partitioned_df .select (expr .alias ("expr" ), string .alias ("str" ))
944+ table = pa .Table .from_batches (df .collect ())
945+ assert table .column ("expr" ).to_pylist () == table .column ("str" ).to_pylist ()
949946
950947
951- def _build_nth_value_df (df ):
952- return df .select (
953- f .nth_value (column ("b" ), 3 ).over (Window (order_by = [column ("a" )])).alias ("expr" ),
954- f .nth_value (column ("b" ), 3 ).over (Window (order_by = "a" )).alias ("str" ),
955- )
948+ def test_nth_value_order_by_string_equivalence (partitioned_df ):
949+ expr = f .nth_value (column ("b" ), 3 ).over (Window (order_by = [column ("a" )]))
950+ string = f .nth_value (column ("b" ), 3 ).over (Window (order_by = "a" ))
951+ df = partitioned_df .select (expr .alias ("expr" ), string .alias ("str" ))
952+ table = pa .Table .from_batches (df .collect ())
953+ assert table .column ("expr" ).to_pylist () == table .column ("str" ).to_pylist ()
956954
957955
958- def _build_rank_df (df ):
959- return df .select (
960- f .rank (order_by = [column ("b" )]).alias ("expr" ),
961- f .rank (order_by = "b" ).alias ("str" ),
962- )
956+ def test_rank_order_by_string_equivalence (partitioned_df ):
957+ expr = f .rank (order_by = [column ("b" )])
958+ string = f .rank (order_by = "b" )
959+ df = partitioned_df .select (expr .alias ("expr" ), string .alias ("str" ))
960+ table = pa .Table .from_batches (df .collect ())
961+ assert table .column ("expr" ).to_pylist () == table .column ("str" ).to_pylist ()
963962
964963
965- def _build_array_agg_df ( df ):
966- return df .aggregate (
964+ def test_array_agg_order_by_string_equivalence ( partitioned_df ):
965+ df = partitioned_df .aggregate (
967966 [column ("c" )],
968967 [
969968 f .array_agg (column ("a" ), order_by = [column ("a" )]).alias ("expr" ),
970969 f .array_agg (column ("a" ), order_by = "a" ).alias ("str" ),
971970 ],
972971 ).sort (column ("c" ))
973-
974-
975- @pytest .mark .parametrize (
976- ("builder" , "expected" ),
977- [
978- pytest .param (_build_last_value_df , [3 , 3 , 3 , 3 , 6 , 6 , 6 ], id = "last_value" ),
979- pytest .param (_build_nth_value_df , [None , None , 7 , 7 , 7 , 7 , 7 ], id = "nth_value" ),
980- pytest .param (_build_rank_df , [3 , 1 , 3 , 5 , 6 , 1 , 6 ], id = "rank" ),
981- pytest .param (_build_array_agg_df , [[0 , 1 , 2 , 3 ], [4 , 5 , 6 ]], id = "array_agg" ),
982- ],
983- )
984- def test_order_by_string_equivalence (partitioned_df , builder , expected ):
985- df = builder (partitioned_df )
986972 table = pa .Table .from_batches (df .collect ())
987- assert table .column ("expr" ).to_pylist () == expected
988973 assert table .column ("expr" ).to_pylist () == table .column ("str" ).to_pylist ()
989974
990975
0 commit comments