Skip to content

Commit 05b5bcf

Browse files
fixed various issues after PR review
1 parent e397c5c commit 05b5bcf

File tree

3 files changed

+51
-19
lines changed

3 files changed

+51
-19
lines changed

_duckdb-stubs/__init__.pyi

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@ if typing.TYPE_CHECKING:
4040
ColumnsTypes,
4141
ProfilerFormat,
4242
ParquetCompression,
43+
ArrowUDF,
4344
)
44-
from ._enums import ExplainTypeLiteral, CSVLineTerminatorLiteral, RenderModeLiteral
45+
from ._enums import ExplainTypeLiteral, RenderModeLiteral
4546
from duckdb import sqltypes, func
4647

4748
__all__: lst[str] = [
@@ -210,14 +211,28 @@ class DuckDBPyConnection:
210211
def checkpoint(self) -> DuckDBPyConnection: ...
211212
def close(self) -> None: ...
212213
def commit(self) -> DuckDBPyConnection: ...
214+
@typing.overload
213215
def create_function(
214216
self,
215217
name: str,
216218
function: Callable[..., PythonLiteral],
217219
parameters: lst[IntoPyType] | None = None,
218220
return_type: IntoPyType | None = None,
219221
*,
220-
type: func.PythonUDFType = ...,
222+
type: func.PythonUDFType = func.PythonUDFType.NATIVE,
223+
null_handling: func.FunctionNullHandling = ...,
224+
exception_handling: PythonExceptionHandling = ...,
225+
side_effects: bool = False,
226+
) -> DuckDBPyConnection: ...
227+
@typing.overload
228+
def create_function(
229+
self,
230+
name: str,
231+
function: ArrowUDF,
232+
parameters: lst[IntoPyType] | None = None,
233+
return_type: IntoPyType | None = None,
234+
*,
235+
type: func.PythonUDFType = func.PythonUDFType.ARROW,
221236
null_handling: func.FunctionNullHandling = ...,
222237
exception_handling: PythonExceptionHandling = ...,
223238
side_effects: bool = False,
@@ -273,7 +288,7 @@ class DuckDBPyConnection:
273288
normalize_names: bool | None = None,
274289
null_padding: bool | None = None,
275290
names: lst[str] | None = None,
276-
lineterminator: CSVLineTerminator | CSVLineTerminatorLiteral | None = None,
291+
lineterminator: CSVLineTerminator | None = None,
277292
columns: ColumnsTypes | None = None,
278293
auto_type_candidates: lst[StrIntoPyType] | None = None,
279294
max_line_size: int | None = None,
@@ -374,7 +389,7 @@ class DuckDBPyConnection:
374389
normalize_names: bool | None = None,
375390
null_padding: bool | None = None,
376391
names: lst[str] | None = None,
377-
lineterminator: CSVLineTerminator | CSVLineTerminatorLiteral | None = None,
392+
lineterminator: CSVLineTerminator | None = None,
378393
columns: ColumnsTypes | None = None,
379394
auto_type_candidates: lst[StrIntoPyType] | None = None,
380395
max_line_size: int | None = None,
@@ -596,7 +611,7 @@ class DuckDBPyRelation:
596611
self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = ""
597612
) -> DuckDBPyRelation: ...
598613
def map(
599-
self, map_function: Callable[..., PythonLiteral], *, schema: dict[str, sqltypes.DuckDBPyType] | None = None
614+
self, map_function: Callable[..., typing.Any], *, schema: dict[str, sqltypes.DuckDBPyType] | None = None
600615
) -> DuckDBPyRelation: ...
601616
def max(
602617
self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = ""
@@ -900,13 +915,27 @@ def connect(
900915
read_only: bool = False,
901916
config: dict[str, str | bool | int | float | lst[str]] | None = None,
902917
) -> DuckDBPyConnection: ...
918+
@typing.overload
903919
def create_function(
904920
name: str,
905921
function: Callable[..., PythonLiteral],
906922
parameters: lst[IntoPyType] | None = None,
907923
return_type: IntoPyType | None = None,
908924
*,
909-
type: func.PythonUDFType = ...,
925+
type: func.PythonUDFType = func.PythonUDFType.NATIVE,
926+
null_handling: func.FunctionNullHandling = ...,
927+
exception_handling: PythonExceptionHandling = ...,
928+
side_effects: bool = False,
929+
connection: DuckDBPyConnection | None = None,
930+
) -> DuckDBPyConnection: ...
931+
@typing.overload
932+
def create_function(
933+
name: str,
934+
function: ArrowUDF,
935+
parameters: lst[IntoPyType] | None = None,
936+
return_type: IntoPyType | None = None,
937+
*,
938+
type: func.PythonUDFType = func.PythonUDFType.ARROW,
910939
null_handling: func.FunctionNullHandling = ...,
911940
exception_handling: PythonExceptionHandling = ...,
912941
side_effects: bool = False,
@@ -1011,7 +1040,7 @@ def from_csv_auto(
10111040
normalize_names: bool | None = None,
10121041
null_padding: bool | None = None,
10131042
names: lst[str] | None = None,
1014-
lineterminator: CSVLineTerminator | CSVLineTerminatorLiteral | None = None,
1043+
lineterminator: CSVLineTerminator | None = None,
10151044
columns: ColumnsTypes | None = None,
10161045
auto_type_candidates: lst[StrIntoPyType] | None = None,
10171046
max_line_size: int | None = None,
@@ -1160,7 +1189,7 @@ def read_csv(
11601189
normalize_names: bool | None = None,
11611190
null_padding: bool | None = None,
11621191
names: lst[str] | None = None,
1163-
lineterminator: CSVLineTerminator | CSVLineTerminatorLiteral | None = None,
1192+
lineterminator: CSVLineTerminator | None = None,
11641193
columns: ColumnsTypes | None = None,
11651194
auto_type_candidates: lst[StrIntoPyType] | None = None,
11661195
max_line_size: int | None = None,

_duckdb-stubs/_enums.pyi

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ class CSVLineTerminator(CppEnum):
2828
dict[str, CSVLineTerminator]
2929
] # value = {'LINE_FEED': <CSVLineTerminator.LINE_FEED: 0>, 'CARRIAGE_RETURN_LINE_FEED': <CSVLineTerminator.CARRIAGE_RETURN_LINE_FEED: 1>} # noqa: E501
3030

31-
CSVLineTerminatorLiteral: TypeAlias = Literal["\\r\\n", "\\n"]
32-
3331
class ExpectedResultType(CppEnum):
3432
CHANGED_ROWS: ClassVar[ExpectedResultType] # value = <ExpectedResultType.CHANGED_ROWS: 1>
3533
NOTHING: ClassVar[ExpectedResultType] # value = <ExpectedResultType.NOTHING: 2>
@@ -38,32 +36,32 @@ class ExpectedResultType(CppEnum):
3836
dict[str, ExpectedResultType]
3937
] # value = {'QUERY_RESULT': <ExpectedResultType.QUERY_RESULT: 0>, 'CHANGED_ROWS': <ExpectedResultType.CHANGED_ROWS: 1>, 'NOTHING': <ExpectedResultType.NOTHING: 2>} # noqa: E501
4038

41-
class ExplainType:
39+
class ExplainType(CppEnum):
4240
ANALYZE: ClassVar[ExplainType] # value = <ExplainType.ANALYZE: 1>
4341
STANDARD: ClassVar[ExplainType] # value = <ExplainType.STANDARD: 0>
4442
__members__: ClassVar[
4543
dict[str, ExplainType]
4644
] # value = {'STANDARD': <ExplainType.STANDARD: 0>, 'ANALYZE': <ExplainType.ANALYZE: 1>}
4745

48-
ExplainTypeLiteral: TypeAlias = Literal["analyze", "standard"]
46+
ExplainTypeLiteral: TypeAlias = Literal["analyze", "standard", "ANALYZE", "STANDARD"]
4947

50-
class PythonExceptionHandling:
48+
class PythonExceptionHandling(CppEnum):
5149
DEFAULT: ClassVar[PythonExceptionHandling] # value = <PythonExceptionHandling.DEFAULT: 0>
5250
RETURN_NULL: ClassVar[PythonExceptionHandling] # value = <PythonExceptionHandling.RETURN_NULL: 1>
5351
__members__: ClassVar[
5452
dict[str, PythonExceptionHandling]
5553
] # value = {'DEFAULT': <PythonExceptionHandling.DEFAULT: 0>, 'RETURN_NULL': <PythonExceptionHandling.RETURN_NULL: 1>} # noqa: E501
5654

57-
class RenderMode:
55+
class RenderMode(CppEnum):
5856
COLUMNS: ClassVar[RenderMode] # value = <RenderMode.COLUMNS: 1>
5957
ROWS: ClassVar[RenderMode] # value = <RenderMode.ROWS: 0>
6058
__members__: ClassVar[
6159
dict[str, RenderMode]
6260
] # value = {'ROWS': <RenderMode.ROWS: 0>, 'COLUMNS': <RenderMode.COLUMNS: 1>}
6361

64-
RenderModeLiteral: TypeAlias = Literal["columns", "rows"]
62+
RenderModeLiteral: TypeAlias = Literal["columns", "rows", "COLUMNS", "ROWS"]
6563

66-
class StatementType:
64+
class StatementType(CppEnum):
6765
ALTER: ClassVar[StatementType] # value = <StatementType.ALTER: 8>
6866
ANALYZE: ClassVar[StatementType] # value = <StatementType.ANALYZE: 11>
6967
ATTACH: ClassVar[StatementType] # value = <StatementType.ATTACH: 25>
@@ -98,7 +96,7 @@ class StatementType:
9896
dict[str, StatementType]
9997
] # value = {'INVALID': <StatementType.INVALID: 0>, 'SELECT': <StatementType.SELECT: 1>, 'INSERT': <StatementType.INSERT: 2>, 'UPDATE': <StatementType.UPDATE: 3>, 'CREATE': <StatementType.CREATE: 4>, 'DELETE': <StatementType.DELETE: 5>, 'PREPARE': <StatementType.PREPARE: 6>, 'EXECUTE': <StatementType.EXECUTE: 7>, 'ALTER': <StatementType.ALTER: 8>, 'TRANSACTION': <StatementType.TRANSACTION: 9>, 'COPY': <StatementType.COPY: 10>, 'ANALYZE': <StatementType.ANALYZE: 11>, 'VARIABLE_SET': <StatementType.VARIABLE_SET: 12>, 'CREATE_FUNC': <StatementType.CREATE_FUNC: 13>, 'EXPLAIN': <StatementType.EXPLAIN: 14>, 'DROP': <StatementType.DROP: 15>, 'EXPORT': <StatementType.EXPORT: 16>, 'PRAGMA': <StatementType.PRAGMA: 17>, 'VACUUM': <StatementType.VACUUM: 18>, 'CALL': <StatementType.CALL: 19>, 'SET': <StatementType.SET: 20>, 'LOAD': <StatementType.LOAD: 21>, 'RELATION': <StatementType.RELATION: 22>, 'EXTENSION': <StatementType.EXTENSION: 23>, 'LOGICAL_PLAN': <StatementType.LOGICAL_PLAN: 24>, 'ATTACH': <StatementType.ATTACH: 25>, 'DETACH': <StatementType.DETACH: 26>, 'MULTI': <StatementType.MULTI: 27>, 'COPY_DATABASE': <StatementType.COPY_DATABASE: 28>, 'MERGE_INTO': <StatementType.MERGE_INTO: 30>} # noqa: E501
10098

101-
class token_type:
99+
class token_type(CppEnum):
102100
__members__: ClassVar[
103101
dict[str, token_type]
104102
] # value = {'identifier': <token_type.identifier: 0>, 'numeric_const': <token_type.numeric_const: 1>, 'string_const': <token_type.string_const: 2>, 'operator': <token_type.operator: 3>, 'keyword': <token_type.keyword: 4>, 'comment': <token_type.comment: 5>} # noqa: E501

_duckdb-stubs/_typing.pyi

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@ from typing import TypeAlias, TYPE_CHECKING, Protocol, Any, TypeVar, Generic, Li
44
from datetime import date, datetime, time, timedelta
55
from decimal import Decimal
66
from uuid import UUID
7-
from collections.abc import Mapping, Iterator, Sequence
7+
from collections.abc import Mapping, Iterator, Sequence, Callable
88

99
if TYPE_CHECKING:
10+
import pyarrow as pa
1011
from ._expression import Expression
1112
from ._sqltypes import DuckDBPyType
1213

@@ -101,6 +102,7 @@ Builtins: TypeAlias = Literal[
101102
"date",
102103
"double",
103104
"float",
105+
"geometry",
104106
"hugeint",
105107
"integer",
106108
"interval",
@@ -189,7 +191,7 @@ ParquetFieldsOptions: TypeAlias = _Auto | ParquetFieldIdsType
189191
"""Types accepted for the `field_ids` parameter in parquet writing methods."""
190192

191193
CsvEncoding: TypeAlias = Literal["utf-8", "utf-16", "latin-1"] | str
192-
"""Encdoding options.
194+
"""Encoding options.
193195
194196
All availables options not in the literal values can be seen here:
195197
https://duckdb.org/docs/stable/core_extensions/encodings
@@ -213,3 +215,6 @@ JoinType: TypeAlias = Literal["inner", "left", "right", "outer", "semi", "anti"]
213215

214216
ProfilerFormat: TypeAlias = Literal["json", "query_tree", "query_tree_optimizer", "no_output", "html", "graphviz"]
215217
"""Formats available in `get_profiling_information` method/function."""
218+
# TODO: this should be a `Protocol` just like `NPArrayLike`.
219+
ArrowUDF: TypeAlias = Callable[..., pa.Table | pa.Array | pa.ChunkedArray]
220+
"""Type accepted for Python UDFs that return Arrow data."""

0 commit comments

Comments
 (0)