Skip to content

Commit 6570061

Browse files
committed
refactor: replace Expr and SortExpr with SortKey in file_sort_order and related functions
1 parent 9aa50ec commit 6570061

4 files changed

Lines changed: 26 additions & 22 deletions

File tree

python/datafusion/context.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
from datafusion.catalog import Catalog, CatalogProvider, Table
3333
from datafusion.dataframe import DataFrame
34-
from datafusion.expr import Expr, SortExpr, sort_list_to_raw_sort_list
34+
from datafusion.expr import SortKey, sort_list_to_raw_sort_list
3535
from datafusion.record_batch import RecordBatchStream
3636
from datafusion.user_defined import AggregateUDF, ScalarUDF, TableFunction, WindowUDF
3737

@@ -553,7 +553,7 @@ def register_listing_table(
553553
table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
554554
file_extension: str = ".parquet",
555555
schema: pa.Schema | None = None,
556-
file_sort_order: list[list[Expr | SortExpr | str]] | None = None,
556+
file_sort_order: list[list[SortKey]] | None = None,
557557
) -> None:
558558
"""Register multiple files as a single table.
559559
@@ -805,7 +805,7 @@ def register_parquet(
805805
file_extension: str = ".parquet",
806806
skip_metadata: bool = True,
807807
schema: pa.Schema | None = None,
808-
file_sort_order: list[list[Expr | SortExpr | str]] | None = None,
808+
file_sort_order: list[list[SortKey]] | None = None,
809809
) -> None:
810810
"""Register a Parquet file as a table.
811811
@@ -1096,7 +1096,7 @@ def read_parquet(
10961096
file_extension: str = ".parquet",
10971097
skip_metadata: bool = True,
10981098
schema: pa.Schema | None = None,
1099-
file_sort_order: list[list[Expr | SortExpr | str]] | None = None,
1099+
file_sort_order: list[list[SortKey]] | None = None,
11001100
) -> DataFrame:
11011101
"""Read a Parquet source into a :py:class:`~datafusion.dataframe.DataFrame`.
11021102

python/datafusion/dataframe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
from datafusion.expr import (
4444
EXPR_TYPE_ERROR,
4545
Expr,
46-
SortExpr,
46+
SortKey,
4747
expr_list_to_raw_expr_list,
4848
sort_list_to_raw_sort_list,
4949
)
@@ -540,7 +540,7 @@ def aggregate(
540540
aggs_exprs.append(agg.expr)
541541
return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs))
542542

543-
def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame:
543+
def sort(self, *exprs: SortKey) -> DataFrame:
544544
"""Sort the DataFrame by the specified sorting expressions or column names.
545545
546546
Note that any expression can be turned into a sort expression by

python/datafusion/expr.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from __future__ import annotations
2424

25-
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence
25+
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence, Union
2626

2727
import pyarrow as pa
2828

@@ -43,6 +43,8 @@
4343
# Standard error message for invalid expression types
4444
EXPR_TYPE_ERROR = "Use col() or lit() to construct expressions"
4545

46+
SortKey = Union["Expr", "SortExpr", str]
47+
4648
# The following are imported from the internal representation. We may choose to
4749
# give these all proper wrappers, or to simply leave as is. These were added
4850
# in order to support passing the `test_imports` unit test.
@@ -199,6 +201,7 @@
199201
"SimilarTo",
200202
"Sort",
201203
"SortExpr",
204+
"SortKey",
202205
"Subquery",
203206
"SubqueryAlias",
204207
"TableScan",
@@ -250,7 +253,7 @@ def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr:
250253

251254

252255
def sort_list_to_raw_sort_list(
253-
sort_list: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str],
256+
sort_list: Optional[list[SortKey] | SortKey],
254257
) -> Optional[list[expr_internal.SortExpr]]:
255258
"""Helper function to return an optional sort list to raw variant."""
256259
if isinstance(sort_list, (Expr, SortExpr, str)):

python/datafusion/functions.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
CaseBuilder,
2929
Expr,
3030
SortExpr,
31+
SortKey,
3132
WindowFrame,
3233
expr_list_to_raw_expr_list,
3334
sort_list_to_raw_sort_list,
@@ -429,7 +430,7 @@ def window(
429430
name: str,
430431
args: list[Expr],
431432
partition_by: list[Expr] | Expr | None = None,
432-
order_by: list[Expr | SortExpr | str] | Expr | SortExpr | str | None = None,
433+
order_by: list[SortKey] | SortKey | None = None,
433434
window_frame: WindowFrame | None = None,
434435
ctx: SessionContext | None = None,
435436
) -> Expr:
@@ -1723,7 +1724,7 @@ def array_agg(
17231724
expression: Expr,
17241725
distinct: bool = False,
17251726
filter: Optional[Expr] = None,
1726-
order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None,
1727+
order_by: Optional[list[SortKey] | SortKey] = None,
17271728
) -> Expr:
17281729
"""Aggregate values into an array.
17291730
@@ -2222,7 +2223,7 @@ def regr_syy(
22222223
def first_value(
22232224
expression: Expr,
22242225
filter: Optional[Expr] = None,
2225-
order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None,
2226+
order_by: Optional[list[SortKey] | SortKey] = None,
22262227
null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS,
22272228
) -> Expr:
22282229
"""Returns the first value in a group of values.
@@ -2254,7 +2255,7 @@ def first_value(
22542255
def last_value(
22552256
expression: Expr,
22562257
filter: Optional[Expr] = None,
2257-
order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None,
2258+
order_by: Optional[list[SortKey] | SortKey] = None,
22582259
null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS,
22592260
) -> Expr:
22602261
"""Returns the last value in a group of values.
@@ -2287,7 +2288,7 @@ def nth_value(
22872288
expression: Expr,
22882289
n: int,
22892290
filter: Optional[Expr] = None,
2290-
order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None,
2291+
order_by: Optional[list[SortKey] | SortKey] = None,
22912292
null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS,
22922293
) -> Expr:
22932294
"""Returns the n-th value in a group of values.
@@ -2408,7 +2409,7 @@ def lead(
24082409
shift_offset: int = 1,
24092410
default_value: Optional[Any] = None,
24102411
partition_by: Optional[list[Expr] | Expr] = None,
2411-
order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None,
2412+
order_by: Optional[list[SortKey] | SortKey] = None,
24122413
) -> Expr:
24132414
"""Create a lead window function.
24142415
@@ -2461,7 +2462,7 @@ def lag(
24612462
shift_offset: int = 1,
24622463
default_value: Optional[Any] = None,
24632464
partition_by: Optional[list[Expr] | Expr] = None,
2464-
order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None,
2465+
order_by: Optional[list[SortKey] | SortKey] = None,
24652466
) -> Expr:
24662467
"""Create a lag window function.
24672468
@@ -2508,7 +2509,7 @@ def lag(
25082509

25092510
def row_number(
25102511
partition_by: Optional[list[Expr] | Expr] = None,
2511-
order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None,
2512+
order_by: Optional[list[SortKey] | SortKey] = None,
25122513
) -> Expr:
25132514
"""Create a row number window function.
25142515
@@ -2542,7 +2543,7 @@ def row_number(
25422543

25432544
def rank(
25442545
partition_by: Optional[list[Expr] | Expr] = None,
2545-
order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None,
2546+
order_by: Optional[list[SortKey] | SortKey] = None,
25462547
) -> Expr:
25472548
"""Create a rank window function.
25482549
@@ -2581,7 +2582,7 @@ def rank(
25812582

25822583
def dense_rank(
25832584
partition_by: Optional[list[Expr] | Expr] = None,
2584-
order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None,
2585+
order_by: Optional[list[SortKey] | SortKey] = None,
25852586
) -> Expr:
25862587
"""Create a dense_rank window function.
25872588
@@ -2615,7 +2616,7 @@ def dense_rank(
26152616

26162617
def percent_rank(
26172618
partition_by: Optional[list[Expr] | Expr] = None,
2618-
order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None,
2619+
order_by: Optional[list[SortKey] | SortKey] = None,
26192620
) -> Expr:
26202621
"""Create a percent_rank window function.
26212622
@@ -2650,7 +2651,7 @@ def percent_rank(
26502651

26512652
def cume_dist(
26522653
partition_by: Optional[list[Expr] | Expr] = None,
2653-
order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None,
2654+
order_by: Optional[list[SortKey] | SortKey] = None,
26542655
) -> Expr:
26552656
"""Create a cumulative distribution window function.
26562657
@@ -2686,7 +2687,7 @@ def cume_dist(
26862687
def ntile(
26872688
groups: int,
26882689
partition_by: Optional[list[Expr] | Expr] = None,
2689-
order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None,
2690+
order_by: Optional[list[SortKey] | SortKey] = None,
26902691
) -> Expr:
26912692
"""Create an n-tile window function.
26922693
@@ -2727,7 +2728,7 @@ def string_agg(
27272728
expression: Expr,
27282729
delimiter: str,
27292730
filter: Optional[Expr] = None,
2730-
order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None,
2731+
order_by: Optional[list[SortKey] | SortKey] = None,
27312732
) -> Expr:
27322733
"""Concatenates the input strings.
27332734

0 commit comments

Comments
 (0)