Skip to content

Commit e80d931

Browse files
committed
refactor: consolidate order by string equivalence tests into parameterized test
1 parent c58b324 commit e80d931

1 file changed

Lines changed: 44 additions & 29 deletions

File tree

python/tests/test_dataframe.py

Lines changed: 44 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -925,51 +925,66 @@ def test_window_frame_defaults_match_postgres(partitioned_df):
925925
assert df_2.sort(col_a).to_pydict() == expected
926926

927927

928-
def test_last_value_order_by_string_equivalence(partitioned_df):
929-
expr = f.last_value(column("a")).over(
930-
Window(
931-
partition_by=[column("c")],
932-
order_by=[column("b")],
933-
window_frame=WindowFrame("rows", None, None),
928+
def _build_last_value_df(df):
929+
return df.select(
930+
f.last_value(column("a"))
931+
.over(
932+
Window(
933+
partition_by=[column("c")],
934+
order_by=[column("b")],
935+
window_frame=WindowFrame("rows", None, None),
936+
)
934937
)
935-
)
936-
string = f.last_value(column("a")).over(
937-
Window(
938-
partition_by=[column("c")],
939-
order_by="b",
940-
window_frame=WindowFrame("rows", None, None),
938+
.alias("expr"),
939+
f.last_value(column("a"))
940+
.over(
941+
Window(
942+
partition_by=[column("c")],
943+
order_by="b",
944+
window_frame=WindowFrame("rows", None, None),
945+
)
941946
)
947+
.alias("str"),
942948
)
943-
df = partitioned_df.select(expr.alias("expr"), string.alias("str"))
944-
table = pa.Table.from_batches(df.collect())
945-
assert table.column("expr").to_pylist() == table.column("str").to_pylist()
946949

947950

948-
def test_nth_value_order_by_string_equivalence(partitioned_df):
949-
expr = f.nth_value(column("b"), 3).over(Window(order_by=[column("a")]))
950-
string = f.nth_value(column("b"), 3).over(Window(order_by="a"))
951-
df = partitioned_df.select(expr.alias("expr"), string.alias("str"))
952-
table = pa.Table.from_batches(df.collect())
953-
assert table.column("expr").to_pylist() == table.column("str").to_pylist()
951+
def _build_nth_value_df(df):
952+
return df.select(
953+
f.nth_value(column("b"), 3).over(Window(order_by=[column("a")])).alias("expr"),
954+
f.nth_value(column("b"), 3).over(Window(order_by="a")).alias("str"),
955+
)
954956

955957

956-
def test_rank_order_by_string_equivalence(partitioned_df):
957-
expr = f.rank(order_by=[column("b")])
958-
string = f.rank(order_by="b")
959-
df = partitioned_df.select(expr.alias("expr"), string.alias("str"))
960-
table = pa.Table.from_batches(df.collect())
961-
assert table.column("expr").to_pylist() == table.column("str").to_pylist()
958+
def _build_rank_df(df):
959+
return df.select(
960+
f.rank(order_by=[column("b")]).alias("expr"),
961+
f.rank(order_by="b").alias("str"),
962+
)
962963

963964

964-
def test_array_agg_order_by_string_equivalence(partitioned_df):
965-
df = partitioned_df.aggregate(
965+
def _build_array_agg_df(df):
966+
return df.aggregate(
966967
[column("c")],
967968
[
968969
f.array_agg(column("a"), order_by=[column("a")]).alias("expr"),
969970
f.array_agg(column("a"), order_by="a").alias("str"),
970971
],
971972
).sort(column("c"))
973+
974+
975+
@pytest.mark.parametrize(
976+
("builder", "expected"),
977+
[
978+
pytest.param(_build_last_value_df, [3, 3, 3, 3, 6, 6, 6], id="last_value"),
979+
pytest.param(_build_nth_value_df, [None, None, 7, 7, 7, 7, 7], id="nth_value"),
980+
pytest.param(_build_rank_df, [1, 1, 3, 3, 5, 6, 6], id="rank"),
981+
pytest.param(_build_array_agg_df, [[0, 1, 2, 3], [4, 5, 6]], id="array_agg"),
982+
],
983+
)
984+
def test_order_by_string_equivalence(partitioned_df, builder, expected):
985+
df = builder(partitioned_df)
972986
table = pa.Table.from_batches(df.collect())
987+
assert table.column("expr").to_pylist() == expected
973988
assert table.column("expr").to_pylist() == table.column("str").to_pylist()
974989

975990

0 commit comments

Comments
 (0)