@@ -303,6 +303,18 @@ def test_aggregate_string_and_expression_equivalent(df):
303303 assert result_str == result_expr
304304
305305
def test_aggregate_tuple_group_by(df):
    """Group-by keys given as a tuple must aggregate like the same keys in a list."""
    grouped_by_list = df.aggregate(["a"], [f.count()])
    grouped_by_tuple = df.aggregate(("a",), [f.count()])

    # Sort both results on the group key so the dict comparison is order-stable.
    expected = grouped_by_list.sort("a").to_pydict()
    actual = grouped_by_tuple.sort("a").to_pydict()
    assert actual == expected
310+
311+
def test_aggregate_tuple_aggs(df):
    """Aggregate expressions given as a tuple must behave like the same list."""
    with_list_aggs = df.aggregate("a", [f.count()])
    with_tuple_aggs = df.aggregate("a", (f.count(),))

    # Sort both results on the group key so the dict comparison is order-stable.
    expected = with_list_aggs.sort("a").to_pydict()
    actual = with_tuple_aggs.sort("a").to_pydict()
    assert actual == expected
316+
317+
def test_filter_string_unsupported(df):
    """``filter`` must reject a plain string predicate with ``TypeError``.

    The error message is expected to contain ``EXPR_TYPE_ERROR`` verbatim.
    """
    # Escape the message so regex metacharacters in it are matched literally.
    expected_message = re.escape(EXPR_TYPE_ERROR)
    with pytest.raises(TypeError, match=expected_message):
        df.filter("a > 1")
@@ -416,14 +428,14 @@ def test_with_columns(df):
416428
417429
def test_with_columns_invalid_expr(df):
    """``with_columns`` must raise ``TypeError`` for string arguments.

    Covers a bare string and a list of strings, each passed both positionally
    and as a keyword argument. Every call is expected to fail with a message
    containing ``EXPR_TYPE_ERROR``.
    """
    # Escape the message so regex metacharacters in it are matched literally.
    expected_message = re.escape(EXPR_TYPE_ERROR)

    # (positional args, keyword args) for each invalid invocation, in order.
    invalid_calls = [
        (("a",), {}),
        ((), {"c": "a"}),
        ((["a"],), {}),
        ((), {"c": ["a"]}),
    ]
    for args, kwargs in invalid_calls:
        with pytest.raises(TypeError, match=expected_message):
            df.with_columns(*args, **kwargs)
427439
428440
429441def test_cast (df ):
@@ -843,6 +855,27 @@ def test_window_functions(partitioned_df, name, expr, result):
843855 assert table .sort_by ("a" ).to_pydict () == expected
844856
845857
@pytest.mark.parametrize("partition", ["c", df_col("c")])
def test_rank_partition_by_accepts_string(partitioned_df, partition):
    """``f.rank`` must give the same ranks for a string or column ``partition_by``.

    Parametrized over the name ``"c"`` and the equivalent ``df_col("c")``;
    both variants are checked against the same expected rank sequence.
    """
    rank_expr = f.rank(order_by=column("a"), partition_by=partition)
    ranked = partitioned_df.select(rank_expr.alias("r")).sort(column("a"))

    table = pa.Table.from_batches(ranked.collect())
    ranks = table.column("r").to_pylist()
    assert ranks == [1, 2, 3, 4, 1, 2, 3]
866+
867+
@pytest.mark.parametrize("partition", ["c", df_col("c")])
def test_window_partition_by_accepts_string(partitioned_df, partition):
    """``Window(partition_by=...)`` must accept a string as well as a column.

    Parametrized over the name ``"c"`` and the equivalent ``df_col("c")``;
    both variants are checked against the same expected ``first_value`` output.
    """
    window = Window(partition_by=partition, order_by=column("b"))
    first_value_expr = f.first_value(column("a")).over(window)

    selected = partitioned_df.select(first_value_expr.alias("fv"))
    table = pa.Table.from_batches(selected.sort(column("a")).collect())
    first_values = table.column("fv").to_pylist()
    assert first_values == [1, 1, 1, 1, 5, 5, 5]
877+
878+
846879@pytest .mark .parametrize (
847880 ("units" , "start_bound" , "end_bound" ),
848881 [
@@ -913,6 +946,69 @@ def test_window_frame_defaults_match_postgres(partitioned_df):
913946 assert df_2 .sort (col_a ).to_pydict () == expected
914947
915948
def _build_last_value_df(df):
    """Select ``last_value(a)`` twice, differing only in how ``order_by`` is given.

    Both windows partition by column ``c`` and use a ``"rows"`` frame whose two
    bounds are ``None``. The column aliased ``"expr"`` orders by
    ``[column("b")]``; the column aliased ``"str"`` orders by the string ``"b"``.
    """
    window_with_expr = Window(
        partition_by=[column("c")],
        order_by=[column("b")],
        window_frame=WindowFrame("rows", None, None),
    )
    window_with_str = Window(
        partition_by=[column("c")],
        order_by="b",
        window_frame=WindowFrame("rows", None, None),
    )

    via_expr = f.last_value(column("a")).over(window_with_expr).alias("expr")
    via_str = f.last_value(column("a")).over(window_with_str).alias("str")
    return df.select(via_expr, via_str)
970+
971+
def _build_nth_value_df(df):
    """Select ``nth_value(b, 3)`` twice, differing only in the ``order_by`` form.

    The column aliased ``"expr"`` uses ``Window(order_by=[column("a")])``; the
    column aliased ``"str"`` uses ``Window(order_by="a")``.
    """
    window_with_expr = Window(order_by=[column("a")])
    window_with_str = Window(order_by="a")

    via_expr = f.nth_value(column("b"), 3).over(window_with_expr).alias("expr")
    via_str = f.nth_value(column("b"), 3).over(window_with_str).alias("str")
    return df.select(via_expr, via_str)
977+
978+
def _build_rank_df(df):
    """Select ``rank`` twice, differing only in the ``order_by`` form.

    The column aliased ``"expr"`` passes ``order_by=[column("b")]``; the column
    aliased ``"str"`` passes ``order_by="b"``.
    """
    via_expr = f.rank(order_by=[column("b")]).alias("expr")
    via_str = f.rank(order_by="b").alias("str")
    return df.select(via_expr, via_str)
984+
985+
def _build_array_agg_df(df):
    """Aggregate ``array_agg(a)`` per ``c`` twice, differing in the ``order_by`` form.

    The aggregate aliased ``"expr"`` passes ``order_by=[column("a")]``; the one
    aliased ``"str"`` passes ``order_by="a"``. The grouped result is sorted by
    column ``c`` before being returned.
    """
    group_keys = [column("c")]
    via_expr = f.array_agg(column("a"), order_by=[column("a")]).alias("expr")
    via_str = f.array_agg(column("a"), order_by="a").alias("str")

    aggregated = df.aggregate(group_keys, [via_expr, via_str])
    return aggregated.sort(column("c"))
994+
995+
@pytest.mark.parametrize(
    ("builder", "expected"),
    [
        pytest.param(
            _build_last_value_df,
            [3, 3, 3, 3, 6, 6, 6],
            id="last_value",
        ),
        pytest.param(
            _build_nth_value_df,
            [None, None, 7, 7, 7, 7, 7],
            id="nth_value",
        ),
        pytest.param(
            _build_rank_df,
            [1, 1, 3, 3, 5, 6, 6],
            id="rank",
        ),
        pytest.param(
            _build_array_agg_df,
            [[0, 1, 2, 3], [4, 5, 6]],
            id="array_agg",
        ),
    ],
)
def test_order_by_string_equivalence(partitioned_df, builder, expected):
    """A string ``order_by`` must give the same values as the expression form.

    Each ``builder`` returns a DataFrame with an ``"expr"`` column (built with
    column expressions) and a ``"str"`` column (built with a string name). The
    ``"expr"`` column must equal ``expected`` and the two columns must match.
    """
    batches = builder(partitioned_df).collect()
    table = pa.Table.from_batches(batches)

    expr_values = table.column("expr").to_pylist()
    assert expr_values == expected

    str_values = table.column("str").to_pylist()
    assert expr_values == str_values
1010+
1011+
9161012def test_html_formatter_cell_dimension (df , clean_formatter_state ):
9171013 """Test configuring the HTML formatter with different options."""
9181014 # Configure with custom settings
0 commit comments