|
22 | 22 |
|
23 | 23 | from __future__ import annotations |
24 | 24 |
|
25 | | -from typing import TYPE_CHECKING, Any, ClassVar, Optional |
| 25 | +from typing import TYPE_CHECKING, Any, Optional |
26 | 26 |
|
27 | 27 | import pyarrow as pa |
28 | 28 |
|
|
32 | 32 | from typing_extensions import deprecated # Python 3.12 |
33 | 33 |
|
34 | 34 | from datafusion.common import DataTypeMap, NullTreatment, RexType |
| 35 | +from datafusion.types import DataType |
35 | 36 |
|
36 | 37 | from ._internal import expr as expr_internal |
37 | 38 | from ._internal import functions as functions_internal |
@@ -515,21 +516,34 @@ def fill_null(self, value: Any | Expr | None = None) -> Expr: |
515 | 516 | value = Expr.literal(value) |
516 | 517 | return Expr(functions_internal.nvl(self.expr, value.expr)) |
517 | 518 |
|
518 | | - _to_pyarrow_types: ClassVar[dict[type, pa.DataType]] = { |
519 | | - float: pa.float64(), |
520 | | - int: pa.int64(), |
521 | | - str: pa.string(), |
522 | | - bool: pa.bool_(), |
523 | | - } |
| 519 | + def cast(self, to: DataType | type[float | int | str | bool] | object) -> Expr: |
| 520 | + """Cast to a new data type. |
524 | 521 |
|
525 | | - def cast(self, to: pa.DataType[Any] | type[float | int | str | bool]) -> Expr: |
526 | | - """Cast to a new data type.""" |
527 | | - if not isinstance(to, pa.DataType): |
| 522 | + Args: |
| 523 | + to: Target type as :class:`datafusion.types.DataType`, a Python |
| 524 | + builtin (:class:`float`, :class:`int`, :class:`str`, |
| 525 | + :class:`bool`), or any object implementing |
| 526 | + ``__arrow_c_schema__``. |
| 527 | + """ |
| 528 | + if isinstance(to, type): |
| 529 | + mapping = {float: "Float64", int: "Int64", str: "Utf8", bool: "Boolean"} |
528 | 530 | try: |
529 | | - to = self._to_pyarrow_types[to] |
| 531 | + to = DataType.from_str(mapping[to]) |
530 | 532 | except KeyError as err: |
531 | | - error_msg = "Expected instance of pyarrow.DataType or builtins.type" |
532 | | - raise TypeError(error_msg) from err |
| 533 | + msg = ( |
| 534 | + "Expected DataType, builtin type (float, int, str, bool), " |
| 535 | + "or object implementing __arrow_c_schema__" |
| 536 | + ) |
| 537 | + raise TypeError(msg) from err |
| 538 | + elif not isinstance(to, DataType): |
| 539 | + try: |
| 540 | + to = DataType.from_arrow_c_schema(to) |
| 541 | + except Exception as err: # pragma: no cover - type check |
| 542 | + msg = ( |
| 543 | + "Expected DataType, builtin type (float, int, str, bool), " |
| 544 | + "or object implementing __arrow_c_schema__" |
| 545 | + ) |
| 546 | + raise TypeError(msg) from err |
533 | 547 |
|
534 | 548 | return Expr(self.expr.cast(to)) |
535 | 549 |
|
|
0 commit comments