Skip to content

Commit 646d0ab

Browse files
committed
feat: add tests for string equivalence in window functions and aggregations
1 parent 5598f90 commit 646d0ab

1 file changed

Lines changed: 48 additions & 0 deletions

File tree

python/tests/test_dataframe.py

Lines changed: 48 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -933,6 +933,54 @@ def test_window_frame_defaults_match_postgres(partitioned_df):
933933
assert df_2.sort(col_a).to_pydict() == expected
934934

935935

936+
def test_last_value_order_by_string_equivalence(partitioned_df):
    """Check that ``last_value`` agrees for Expr-list and string ``order_by``.

    Two window expressions are built that differ only in how ``order_by`` is
    spelled (``[column("b")]`` versus ``"b"``); both are selected from the
    same DataFrame and their output columns must be equal.
    """

    def last_a_over(ordering):
        # Everything except the ordering spec is shared by the two variants.
        return f.last_value(column("a")).over(
            Window(
                partition_by=[column("c")],
                order_by=ordering,
                window_frame=WindowFrame("rows", None, None),
            )
        )

    via_expr = last_a_over([column("b")])
    via_str = last_a_over("b")

    selected = partitioned_df.select(via_expr.alias("expr"), via_str.alias("str"))
    result = pa.Table.from_batches(selected.collect())

    expr_values = result.column("expr").to_pylist()
    str_values = result.column("str").to_pylist()
    assert expr_values == str_values
954+
955+
956+
def test_nth_value_order_by_string_equivalence(partitioned_df):
    """Check that ``nth_value`` agrees for Expr-list and string ``order_by``.

    The same ``nth_value(column("b"), 3)`` call is evaluated over a window
    ordered by ``[column("a")]`` and over one ordered by ``"a"``; the two
    resulting columns must match.
    """
    nth_via_expr = f.nth_value(column("b"), 3).over(
        Window(order_by=[column("a")])
    )
    nth_via_str = f.nth_value(column("b"), 3).over(Window(order_by="a"))

    selected = partitioned_df.select(
        nth_via_expr.alias("expr"),
        nth_via_str.alias("str"),
    )
    result = pa.Table.from_batches(selected.collect())

    expr_values = result.column("expr").to_pylist()
    str_values = result.column("str").to_pylist()
    assert expr_values == str_values
962+
963+
964+
def test_rank_order_by_string_equivalence(partitioned_df):
    """Check that ``rank`` agrees for Expr-list and string ``order_by``.

    ``f.rank`` is called once with ``order_by=[column("b")]`` and once with
    ``order_by="b"``; both results are selected side by side and compared.
    """
    rank_via_expr = f.rank(order_by=[column("b")]).alias("expr")
    rank_via_str = f.rank(order_by="b").alias("str")

    batches = partitioned_df.select(rank_via_expr, rank_via_str).collect()
    result = pa.Table.from_batches(batches)

    expr_values = result.column("expr").to_pylist()
    str_values = result.column("str").to_pylist()
    assert expr_values == str_values
970+
971+
972+
def test_array_agg_order_by_string_equivalence(partitioned_df):
    """Check that ``array_agg`` agrees for Expr-list and string ``order_by``.

    Rows are grouped by column "c" and column "a" is aggregated twice, once
    with ``order_by=[column("a")]`` and once with ``order_by="a"``. The
    aggregated result is sorted by "c" before the two columns are compared.
    """
    group_keys = [column("c")]
    aggregates = [
        f.array_agg(column("a"), order_by=[column("a")]).alias("expr"),
        f.array_agg(column("a"), order_by="a").alias("str"),
    ]

    grouped = partitioned_df.aggregate(group_keys, aggregates).sort(column("c"))
    result = pa.Table.from_batches(grouped.collect())

    expr_values = result.column("expr").to_pylist()
    str_values = result.column("str").to_pylist()
    assert expr_values == str_values
982+
983+
936984
def test_html_formatter_cell_dimension(df, clean_formatter_state):
937985
"""Test configuring the HTML formatter with different options."""
938986
# Configure with custom settings

0 commit comments

Comments
 (0)