Skip to content

Commit 5079d94

Browse files
committed
refactor: update order_by handling in Window class for improved type support
1 parent 15737a2 commit 5079d94

2 files changed

Lines changed: 12 additions & 4 deletions

File tree

python/datafusion/expr.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -696,7 +696,7 @@ def over(self, window: Window) -> Expr:
696696
window: Window definition
697697
"""
698698
partition_by_raw = expr_list_to_raw_expr_list(window._partition_by)
699-
order_by_raw = sort_list_to_raw_sort_list(window._order_by)
699+
order_by_raw = window._order_by
700700
window_frame_raw = (
701701
window._window_frame.window_frame
702702
if window._window_frame is not None
@@ -1182,7 +1182,7 @@ def __init__(
11821182
self,
11831183
partition_by: Optional[list[Expr] | Expr] = None,
11841184
window_frame: Optional[WindowFrame] = None,
1185-
order_by: Optional[list[SortExpr | Expr] | Expr | SortExpr] = None,
1185+
order_by: Optional[list[SortExpr | Expr | str] | Expr | SortExpr | str] = None,
11861186
null_treatment: Optional[NullTreatment] = None,
11871187
) -> None:
11881188
"""Construct a window definition.
@@ -1195,7 +1195,7 @@ def __init__(
11951195
"""
11961196
self._partition_by = partition_by
11971197
self._window_frame = window_frame
1198-
self._order_by = order_by
1198+
self._order_by = sort_list_to_raw_sort_list(order_by)
11991199
self._null_treatment = null_treatment
12001200

12011201

python/tests/test_dataframe.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
WindowFrame,
3434
column,
3535
literal,
36-
col,
3736
)
3837
from datafusion import (
3938
functions as f,
@@ -292,6 +291,8 @@ def test_sort_unsupported(df):
292291

293292

294293
def test_aggregate_string_and_expression_equivalent(df):
294+
from datafusion import col
295+
295296
result_str = df.aggregate("a", [f.count()]).sort("a").to_pydict()
296297
result_expr = df.aggregate(col("a"), [f.count()]).sort("a").to_pydict()
297298
assert result_str == result_expr
@@ -778,6 +779,13 @@ def test_distinct():
778779
),
779780
[1, 1, 1, 1, 5, 5, 5],
780781
),
782+
(
783+
"first_value_order_by_string",
784+
f.first_value(column("a")).over(
785+
Window(partition_by=[column("c")], order_by="b")
786+
),
787+
[1, 1, 1, 1, 5, 5, 5],
788+
),
781789
(
782790
"last_value",
783791
f.last_value(column("a")).over(

0 commit comments

Comments
 (0)