Skip to content

Commit 42d6468

Browse files
Fix Expression typing in relational API stubs (#343)
Fixes #341
2 parents 1002489 + d225900 commit 42d6468

2 files changed

Lines changed: 343 additions & 42 deletions

File tree

_duckdb-stubs/__init__.pyi

Lines changed: 61 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
import datetime
2+
import decimal
13
import os
24
import pathlib
35
import typing
6+
import uuid
47
from typing_extensions import Self
58

69
if typing.TYPE_CHECKING:
@@ -15,6 +18,22 @@ if typing.TYPE_CHECKING:
1518
# the field_ids argument to to_parquet and write_parquet has a recursive structure
1619
ParquetFieldIdsType = Mapping[str, typing.Union[int, "ParquetFieldIdsType"]]
1720

21+
_ExpressionLike: typing.TypeAlias = typing.Union[
22+
"Expression",
23+
str,
24+
int,
25+
float,
26+
bool,
27+
bytes,
28+
None,
29+
datetime.date,
30+
datetime.datetime,
31+
datetime.time,
32+
datetime.timedelta,
33+
decimal.Decimal,
34+
uuid.UUID,
35+
]
36+
1837
__all__: list[str] = [
1938
"BinderException",
2039
"CSVLineTerminator",
@@ -472,7 +491,7 @@ class DuckDBPyRelation:
472491
def __getitem__(self, name: str) -> DuckDBPyRelation: ...
473492
def __len__(self) -> int: ...
474493
def aggregate(
475-
self, aggr_expr: str | Iterable[Expression | str], group_expr: Expression | str = ""
494+
self, aggr_expr: str | Iterable[_ExpressionLike], group_expr: _ExpressionLike = ""
476495
) -> DuckDBPyRelation: ...
477496
def any_value(
478497
self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = ""
@@ -642,7 +661,7 @@ class DuckDBPyRelation:
642661
def product(
643662
self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = ""
644663
) -> DuckDBPyRelation: ...
645-
def project(self, *args: str | Expression, groups: str = "") -> DuckDBPyRelation: ...
664+
def project(self, *args: _ExpressionLike, groups: str = "") -> DuckDBPyRelation: ...
646665
def quantile(
647666
self,
648667
expression: str,
@@ -671,7 +690,7 @@ class DuckDBPyRelation:
671690
def rank(self, window_spec: str, projected_columns: str = "") -> DuckDBPyRelation: ...
672691
def rank_dense(self, window_spec: str, projected_columns: str = "") -> DuckDBPyRelation: ...
673692
def row_number(self, window_spec: str, projected_columns: str = "") -> DuckDBPyRelation: ...
674-
def select(self, *args: str | Expression, groups: str = "") -> DuckDBPyRelation: ...
693+
def select(self, *args: _ExpressionLike, groups: str = "") -> DuckDBPyRelation: ...
675694
def select_dtypes(self, types: typing.List[sqltypes.DuckDBPyType | str]) -> DuckDBPyRelation: ...
676695
def select_types(self, types: typing.List[sqltypes.DuckDBPyType | str]) -> DuckDBPyRelation: ...
677696
def set_alias(self, alias: str) -> DuckDBPyRelation: ...
@@ -684,7 +703,7 @@ class DuckDBPyRelation:
684703
null_value: str | None = None,
685704
render_mode: RenderMode | None = None,
686705
) -> None: ...
687-
def sort(self, *args: Expression) -> DuckDBPyRelation: ...
706+
def sort(self, *args: _ExpressionLike) -> DuckDBPyRelation: ...
688707
def sql_query(self) -> str: ...
689708
def std(
690709
self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = ""
@@ -748,7 +767,7 @@ class DuckDBPyRelation:
748767
def torch(self) -> dict[str, typing.Any]: ...
749768
def union(self, union_rel: DuckDBPyRelation) -> DuckDBPyRelation: ...
750769
def unique(self, unique_aggr: str) -> DuckDBPyRelation: ...
751-
def update(self, set: Expression | str, *, condition: Expression | str | None = None) -> None: ...
770+
def update(self, set: dict[str, _ExpressionLike], *, condition: _ExpressionLike | None = None) -> None: ...
752771
def value_counts(self, expression: str, groups: str = "") -> DuckDBPyRelation: ...
753772
def var(
754773
self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = ""
@@ -856,54 +875,54 @@ class ExplainType:
856875
def value(self) -> int: ...
857876

858877
class Expression:
859-
def __add__(self, other: Expression) -> Expression: ...
860-
def __and__(self, other: Expression) -> Expression: ...
861-
def __div__(self, other: Expression) -> Expression: ...
862-
def __eq__(self, other: Expression) -> Expression: ... # type: ignore[override]
863-
def __floordiv__(self, other: Expression) -> Expression: ...
864-
def __ge__(self, other: Expression) -> Expression: ...
865-
def __gt__(self, other: Expression) -> Expression: ...
878+
def __add__(self, other: _ExpressionLike) -> Expression: ...
879+
def __and__(self, other: _ExpressionLike) -> Expression: ...
880+
def __div__(self, other: _ExpressionLike) -> Expression: ...
881+
def __eq__(self, other: _ExpressionLike) -> Expression: ... # type: ignore[override]
882+
def __floordiv__(self, other: _ExpressionLike) -> Expression: ...
883+
def __ge__(self, other: _ExpressionLike) -> Expression: ...
884+
def __gt__(self, other: _ExpressionLike) -> Expression: ...
866885
@typing.overload
867886
def __init__(self, arg0: str) -> None: ...
868887
@typing.overload
869888
def __init__(self, arg0: typing.Any) -> None: ...
870889
def __invert__(self) -> Expression: ...
871-
def __le__(self, other: Expression) -> Expression: ...
872-
def __lt__(self, other: Expression) -> Expression: ...
873-
def __mod__(self, other: Expression) -> Expression: ...
874-
def __mul__(self, other: Expression) -> Expression: ...
875-
def __ne__(self, other: Expression) -> Expression: ... # type: ignore[override]
890+
def __le__(self, other: _ExpressionLike) -> Expression: ...
891+
def __lt__(self, other: _ExpressionLike) -> Expression: ...
892+
def __mod__(self, other: _ExpressionLike) -> Expression: ...
893+
def __mul__(self, other: _ExpressionLike) -> Expression: ...
894+
def __ne__(self, other: _ExpressionLike) -> Expression: ... # type: ignore[override]
876895
def __neg__(self) -> Expression: ...
877-
def __or__(self, other: Expression) -> Expression: ...
878-
def __pow__(self, other: Expression) -> Expression: ...
879-
def __radd__(self, other: Expression) -> Expression: ...
880-
def __rand__(self, other: Expression) -> Expression: ...
881-
def __rdiv__(self, other: Expression) -> Expression: ...
882-
def __rfloordiv__(self, other: Expression) -> Expression: ...
883-
def __rmod__(self, other: Expression) -> Expression: ...
884-
def __rmul__(self, other: Expression) -> Expression: ...
885-
def __ror__(self, other: Expression) -> Expression: ...
886-
def __rpow__(self, other: Expression) -> Expression: ...
887-
def __rsub__(self, other: Expression) -> Expression: ...
888-
def __rtruediv__(self, other: Expression) -> Expression: ...
889-
def __sub__(self, other: Expression) -> Expression: ...
890-
def __truediv__(self, other: Expression) -> Expression: ...
896+
def __or__(self, other: _ExpressionLike) -> Expression: ...
897+
def __pow__(self, other: _ExpressionLike) -> Expression: ...
898+
def __radd__(self, other: _ExpressionLike) -> Expression: ...
899+
def __rand__(self, other: _ExpressionLike) -> Expression: ...
900+
def __rdiv__(self, other: _ExpressionLike) -> Expression: ...
901+
def __rfloordiv__(self, other: _ExpressionLike) -> Expression: ...
902+
def __rmod__(self, other: _ExpressionLike) -> Expression: ...
903+
def __rmul__(self, other: _ExpressionLike) -> Expression: ...
904+
def __ror__(self, other: _ExpressionLike) -> Expression: ...
905+
def __rpow__(self, other: _ExpressionLike) -> Expression: ...
906+
def __rsub__(self, other: _ExpressionLike) -> Expression: ...
907+
def __rtruediv__(self, other: _ExpressionLike) -> Expression: ...
908+
def __sub__(self, other: _ExpressionLike) -> Expression: ...
909+
def __truediv__(self, other: _ExpressionLike) -> Expression: ...
891910
def alias(self, name: str) -> Expression: ...
892911
def asc(self) -> Expression: ...
893-
def between(self, lower: Expression, upper: Expression) -> Expression: ...
912+
def between(self, lower: _ExpressionLike, upper: _ExpressionLike) -> Expression: ...
894913
def cast(self, type: sqltypes.DuckDBPyType) -> Expression: ...
895914
def collate(self, collation: str) -> Expression: ...
896915
def desc(self) -> Expression: ...
897916
def get_name(self) -> str: ...
898-
def isin(self, *args: Expression) -> Expression: ...
899-
def isnotin(self, *args: Expression) -> Expression: ...
917+
def isin(self, *args: _ExpressionLike) -> Expression: ...
918+
def isnotin(self, *args: _ExpressionLike) -> Expression: ...
900919
def isnotnull(self) -> Expression: ...
901920
def isnull(self) -> Expression: ...
902921
def nulls_first(self) -> Expression: ...
903922
def nulls_last(self) -> Expression: ...
904-
def otherwise(self, value: Expression) -> Expression: ...
923+
def otherwise(self, value: _ExpressionLike) -> Expression: ...
905924
def show(self) -> None: ...
906-
def when(self, condition: Expression, value: Expression) -> Expression: ...
925+
def when(self, condition: _ExpressionLike, value: _ExpressionLike) -> Expression: ...
907926

908927
class FatalException(DatabaseError): ...
909928

@@ -1055,18 +1074,18 @@ class token_type:
10551074
@property
10561075
def value(self) -> int: ...
10571076

1058-
def CaseExpression(condition: Expression, value: Expression) -> Expression: ...
1059-
def CoalesceOperator(*args: Expression) -> Expression: ...
1077+
def CaseExpression(condition: _ExpressionLike, value: _ExpressionLike) -> Expression: ...
1078+
def CoalesceOperator(*args: _ExpressionLike) -> Expression: ...
10601079
def ColumnExpression(*args: str) -> Expression: ...
10611080
def ConstantExpression(value: typing.Any) -> Expression: ...
10621081
def DefaultExpression() -> Expression: ...
1063-
def FunctionExpression(function_name: str, *args: Expression) -> Expression: ...
1064-
def LambdaExpression(lhs: typing.Any, rhs: Expression) -> Expression: ...
1082+
def FunctionExpression(function_name: str, *args: _ExpressionLike) -> Expression: ...
1083+
def LambdaExpression(lhs: typing.Any, rhs: _ExpressionLike) -> Expression: ...
10651084
def SQLExpression(expression: str) -> Expression: ...
10661085
def StarExpression(*, exclude: Iterable[str | Expression] | None = None) -> Expression: ...
10671086
def aggregate(
10681087
df: pandas.DataFrame,
1069-
aggr_expr: str | Iterable[Expression | str],
1088+
aggr_expr: str | Iterable[_ExpressionLike],
10701089
group_expr: str = "",
10711090
*,
10721091
connection: DuckDBPyConnection | None = None,
@@ -1324,7 +1343,7 @@ def pl(
13241343
connection: DuckDBPyConnection | None = None,
13251344
) -> typing.Union[polars.DataFrame, polars.LazyFrame]: ...
13261345
def project(
1327-
df: pandas.DataFrame, *args: str | Expression, groups: str = "", connection: DuckDBPyConnection | None = None
1346+
df: pandas.DataFrame, *args: _ExpressionLike, groups: str = "", connection: DuckDBPyConnection | None = None
13281347
) -> DuckDBPyRelation: ...
13291348
def query(
13301349
query: Statement | str,

0 commit comments

Comments
 (0)