Skip to content

Commit 0ef75b1

Browse files
committed
feat: enhance WindowFrame to accept both scalar and array inputs in bounds
1 parent a3a3979 commit 0ef75b1

3 files changed

Lines changed: 29 additions & 6 deletions

File tree

python/datafusion/expr.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424

2525
from typing import TYPE_CHECKING, Any, Optional
2626

27-
import pyarrow as pa
28-
2927
try:
3028
from warnings import deprecated # Python 3.13+
3129
except ImportError:
@@ -1206,12 +1204,20 @@ def __init__(
12061204
will be set to unbounded. If unit type is ``groups``, this
12071205
parameter must be set.
12081206
"""
1209-
if not isinstance(start_bound, pa.Scalar) and start_bound is not None:
1210-
start_bound = pa.scalar(start_bound)
1207+
if start_bound is not None:
1208+
if not isinstance(start_bound, (pa.Array, pa.ChunkedArray)):
1209+
if isinstance(start_bound, pa.Scalar):
1210+
start_bound = pa.array([start_bound.as_py()])
1211+
else:
1212+
start_bound = pa.array([start_bound])
12111213
if units in ("rows", "groups"):
12121214
start_bound = start_bound.cast(pa.uint64())
1213-
if not isinstance(end_bound, pa.Scalar) and end_bound is not None:
1214-
end_bound = pa.scalar(end_bound)
1215+
if end_bound is not None:
1216+
if not isinstance(end_bound, (pa.Array, pa.ChunkedArray)):
1217+
if isinstance(end_bound, pa.Scalar):
1218+
end_bound = pa.array([end_bound.as_py()])
1219+
else:
1220+
end_bound = pa.array([end_bound])
12151221
if units in ("rows", "groups"):
12161222
end_bound = end_bound.cast(pa.uint64())
12171223
self.window_frame = expr_internal.WindowFrame(units, start_bound, end_bound)

python/tests/test_dataframe.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -767,6 +767,11 @@ def test_window_functions(partitioned_df, name, expr, result):
767767
]
768768
+ [
769769
("groups", 0, 0),
770+
(
771+
"rows",
772+
pa.array([0], type=pa.uint64()),
773+
pa.array([1], type=pa.uint64()),
774+
),
770775
],
771776
)
772777
def test_valid_window_frame(units, start_bound, end_bound):

python/tests/test_udwf.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,18 @@ def test_register_udwf(ctx, count_window_df):
312312
assert result.column(0) == pa.array([0, 1, 2])
313313

314314

315+
def test_window_frame_accepts_scalar_and_none():
316+
wf = WindowFrame("rows", pa.scalar(1), None)
317+
assert wf.get_lower_bound().get_offset() == 1
318+
assert wf.get_upper_bound().is_unbounded()
319+
320+
321+
def test_window_frame_accepts_arrays():
322+
wf = WindowFrame("rows", pa.array([1]), pa.array([2]))
323+
assert wf.get_lower_bound().get_offset() == 1
324+
assert wf.get_upper_bound().get_offset() == 2
325+
326+
315327
smooth_default = udwf(
316328
ExponentialSmoothDefault,
317329
pa.float64(),

0 commit comments

Comments
 (0)