Skip to content

Commit 5395b0d

Browse files
committed
feat: enhance aggregate method to support tuple inputs for group_by and aggs
1 parent 28a3e65 commit 5395b0d

2 files changed

Lines changed: 28 additions & 7 deletions

File tree

python/datafusion/dataframe.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from __future__ import annotations
2323

2424
import warnings
25+
from collections.abc import Sequence
2526
from typing import (
2627
TYPE_CHECKING,
2728
Any,
@@ -52,7 +53,7 @@
5253

5354
if TYPE_CHECKING:
5455
import pathlib
55-
from typing import Callable, Sequence
56+
from typing import Callable
5657

5758
import pandas as pd
5859
import polars as pl
@@ -523,20 +524,28 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame:
523524

524525
def aggregate(
525526
self,
526-
group_by: list[Expr | str] | Expr | str,
527-
aggs: list[Expr] | Expr,
527+
group_by: Sequence[Expr | str] | Expr | str,
528+
aggs: Sequence[Expr] | Expr,
528529
) -> DataFrame:
529530
"""Aggregates the rows of the current DataFrame.
530531
531532
Args:
532-
group_by: List of expressions or column names to group by.
533-
aggs: List of expressions to aggregate.
533+
group_by: Sequence of expressions or column names to group by.
534+
aggs: Sequence of expressions to aggregate.
534535
535536
Returns:
536537
DataFrame after aggregation.
537538
"""
538-
group_by_list = group_by if isinstance(group_by, list) else [group_by]
539-
aggs_list = aggs if isinstance(aggs, list) else [aggs]
539+
group_by_list = (
540+
list(group_by)
541+
if isinstance(group_by, Sequence) and not isinstance(group_by, (Expr, str))
542+
else [group_by]
543+
)
544+
aggs_list = (
545+
list(aggs)
546+
if isinstance(aggs, Sequence) and not isinstance(aggs, Expr)
547+
else [aggs]
548+
)
540549

541550
group_by_exprs = expr_list_to_raw_expr_list(group_by_list)
542551
aggs_exprs = [ensure_expr(agg) for agg in aggs_list]

python/tests/test_dataframe.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,18 @@ def test_aggregate_string_and_expression_equivalent(df):
303303
assert result_str == result_expr
304304

305305

306+
def test_aggregate_tuple_group_by(df):
307+
result_list = df.aggregate(["a"], [f.count()]).sort("a").to_pydict()
308+
result_tuple = df.aggregate(("a",), [f.count()]).sort("a").to_pydict()
309+
assert result_tuple == result_list
310+
311+
312+
def test_aggregate_tuple_aggs(df):
313+
result_list = df.aggregate("a", [f.count()]).sort("a").to_pydict()
314+
result_tuple = df.aggregate("a", (f.count(),)).sort("a").to_pydict()
315+
assert result_tuple == result_list
316+
317+
306318
def test_filter_string_unsupported(df):
307319
with pytest.raises(TypeError, match=re.escape(EXPR_TYPE_ERROR)):
308320
df.filter("a > 1")

0 commit comments

Comments
 (0)