Skip to content

Commit ce1cbcc

Browse files
committed
feat: enhance DataType and Expr to support objects implementing __arrow_c_schema__ without requiring PyArrow
1 parent 34a7e51 commit ce1cbcc

5 files changed

Lines changed: 80 additions & 29 deletions

File tree

docs/source/user-guide/common-operations/functions.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ Casting
109109

110110
Casting expressions to different data types using :py:func:`~datafusion.functions.arrow_cast`
111111

112+
DataFusion's :class:`~datafusion.types.DataType` can be constructed from any
113+
object implementing ``__arrow_c_schema__`` and passed to ``arrow_cast`` without
114+
requiring :mod:`pyarrow`.
115+
112116
.. ipython:: python
113117
114118
df.select(

python/datafusion/expr.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from __future__ import annotations
2424

25-
from typing import TYPE_CHECKING, Any, ClassVar, Optional
25+
from typing import TYPE_CHECKING, Any, Optional
2626

2727
import pyarrow as pa
2828

@@ -32,6 +32,7 @@
3232
from typing_extensions import deprecated # Python 3.12
3333

3434
from datafusion.common import DataTypeMap, NullTreatment, RexType
35+
from datafusion.types import DataType
3536

3637
from ._internal import expr as expr_internal
3738
from ._internal import functions as functions_internal
@@ -515,21 +516,34 @@ def fill_null(self, value: Any | Expr | None = None) -> Expr:
515516
value = Expr.literal(value)
516517
return Expr(functions_internal.nvl(self.expr, value.expr))
517518

518-
_to_pyarrow_types: ClassVar[dict[type, pa.DataType]] = {
519-
float: pa.float64(),
520-
int: pa.int64(),
521-
str: pa.string(),
522-
bool: pa.bool_(),
523-
}
519+
def cast(self, to: DataType | type[float | int | str | bool] | object) -> Expr:
    """Return a new expression that casts this one to another data type.

    Args:
        to: Target type. May be a :class:`datafusion.types.DataType`, one
            of the Python builtins :class:`float`, :class:`int`,
            :class:`str` or :class:`bool`, or any object implementing
            ``__arrow_c_schema__``.

    Raises:
        TypeError: If ``to`` is a class other than the four supported
            builtins, or if it cannot be converted through
            ``DataType.from_arrow_c_schema``.
    """
    expected = (
        "Expected DataType, builtin type (float, int, str, bool), "
        "or object implementing __arrow_c_schema__"
    )

    if isinstance(to, type):
        # Builtin Python classes are translated via their Arrow type name.
        arrow_names = {float: "Float64", int: "Int64", str: "Utf8", bool: "Boolean"}
        try:
            target = DataType.from_str(arrow_names[to])
        except KeyError as lookup_err:
            raise TypeError(expected) from lookup_err
    elif isinstance(to, DataType):
        # Already the internal representation; use it unchanged.
        target = to
    else:
        # Anything else is handed to the Arrow C schema conversion.
        try:
            target = DataType.from_arrow_c_schema(to)
        except Exception as convert_err:  # pragma: no cover - type check
            raise TypeError(expected) from convert_err

    return Expr(self.expr.cast(target))
535549

python/datafusion/types.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
1-
"""Internal Arrow type helpers with optional PyArrow conversion."""
1+
"""Arrow type helpers with optional PyArrow conversion.
2+
3+
This module exposes :class:`datafusion.common.DataType`, a lightweight
4+
representation of Arrow types that can be created from any Python object
5+
implementing ``__arrow_c_schema__`` without requiring :mod:`pyarrow`.
6+
"""
27

38
from __future__ import annotations
49

510
from typing import Any
611

712
try: # pragma: no cover - optional dependency
813
import pyarrow as pa
9-
except Exception: # pragma: no cover - optional dependency
10-
pa = None # type: ignore
14+
except ModuleNotFoundError: # pragma: no cover - optional dependency
15+
pa = None # type: ignore[assignment]
1116

12-
from datafusion.common import DataTypeMap
17+
from datafusion.common import DataType, DataTypeMap
1318

1419
_PYARROW_TYPE_FACTORIES = {
1520
"Null": lambda: pa.null() if pa else None,
@@ -31,18 +36,17 @@
3136

3237
def pyarrow_available() -> bool:
    """Tell whether the module-level ``pa`` name holds an imported :mod:`pyarrow`.

    Returns:
        ``False`` when ``pa`` is ``None``, ``True`` otherwise.
    """
    pyarrow_missing = pa is None
    return not pyarrow_missing
3640

3741

38-
def to_pyarrow(data_type: DataTypeMap) -> "pa.DataType":
42+
def to_pyarrow(data_type: DataTypeMap) -> pa.DataType:
3943
"""Convert a :class:`DataTypeMap` to a :mod:`pyarrow` data type.
4044
4145
Raises ``ModuleNotFoundError`` if :mod:`pyarrow` is not installed.
4246
"""
43-
4447
if pa is None: # pragma: no cover - optional dependency
45-
raise ModuleNotFoundError("pyarrow is not installed")
48+
msg = "pyarrow is not installed"
49+
raise ModuleNotFoundError(msg)
4650
name = str(data_type.arrow_type)
4751
factory = _PYARROW_TYPE_FACTORIES.get(name)
4852
if factory is None:
@@ -51,14 +55,14 @@ def to_pyarrow(data_type: DataTypeMap) -> "pa.DataType":
5155
return factory()
5256

5357

54-
def from_pyarrow(pa_type: "pa.DataType") -> DataTypeMap:
58+
def from_pyarrow(pa_type: pa.DataType) -> DataTypeMap:
    """Translate a :mod:`pyarrow` data type into a :class:`DataTypeMap`.

    The conversion passes ``str(pa_type)`` to
    ``DataTypeMap.py_map_from_arrow_type_str``.

    Args:
        pa_type: The :mod:`pyarrow` data type to convert.

    Returns:
        The :class:`DataTypeMap` produced from the type's string form.

    Raises:
        ModuleNotFoundError: If :mod:`pyarrow` is not installed.
    """
    if pa is not None:
        arrow_type_name = str(pa_type)
        return DataTypeMap.py_map_from_arrow_type_str(arrow_type_name)

    # Reached only when the optional import at module load failed.
    msg = "pyarrow is not installed"  # pragma: no cover - optional dependency
    raise ModuleNotFoundError(msg)
6367

6468

@@ -69,7 +73,16 @@ def ensure_pyarrow_type(value: DataTypeMap | Any) -> Any:
6973
it will be converted to the corresponding :mod:`pyarrow` data type.
7074
Otherwise ``value`` is returned unchanged.
7175
"""
72-
7376
if isinstance(value, DataTypeMap):
7477
return to_pyarrow(value) if pyarrow_available() else value
7578
return value
79+
80+
81+
__all__ = [
82+
"DataType",
83+
"DataTypeMap",
84+
"ensure_pyarrow_type",
85+
"from_pyarrow",
86+
"pyarrow_available",
87+
"to_pyarrow",
88+
]

src/common/data_type.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
use datafusion::arrow::array::Array;
1919
use datafusion::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
20+
use datafusion::arrow::pyarrow::PyArrowType;
2021
use datafusion::common::{DataFusionError, ScalarValue};
2122
use datafusion::logical_expr::sqlparser::ast::NullTreatment as DFNullTreatment;
2223
use pyo3::{exceptions::PyValueError, prelude::*};
@@ -672,6 +673,23 @@ impl PyDataType {
672673
}
673674
}
674675

676+
#[pymethods]
677+
impl PyDataType {
678+
#[staticmethod]
679+
pub fn from_arrow_c_schema(obj: PyArrowType<DataType>) -> Self {
680+
PyDataType { data_type: obj.0 }
681+
}
682+
683+
#[staticmethod]
684+
pub fn from_str(arrow_str_type: &str) -> PyResult<Self> {
685+
Self::py_map_from_arrow_type_str(arrow_str_type.to_string())
686+
}
687+
688+
fn __repr__(&self) -> String {
689+
format!("{:?}", self.data_type)
690+
}
691+
}
692+
675693
impl From<PyDataType> for DataType {
676694
fn from(data_type: PyDataType) -> DataType {
677695
data_type.data_type

src/expr.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ use std::sync::Arc;
2828
use window::PyWindowFrame;
2929

3030
use datafusion::arrow::datatypes::{DataType, Field};
31-
use datafusion::arrow::pyarrow::PyArrowType;
3231
use datafusion::functions::core::expr_ext::FieldAccessor;
3332
use datafusion::logical_expr::{
3433
col,
@@ -37,7 +36,7 @@ use datafusion::logical_expr::{
3736
};
3837

3938
use crate::arrow_ffi::scalar_to_pyarrow;
40-
use crate::common::data_type::{DataTypeMap, NullTreatment, PyScalarValue, RexType};
39+
use crate::common::data_type::{DataTypeMap, NullTreatment, PyDataType, PyScalarValue, RexType};
4140
use crate::errors::{py_runtime_err, py_type_err, py_unsupported_variant_err, PyDataFusionResult};
4241
use crate::expr::aggregate_expr::PyAggregateFunction;
4342
use crate::expr::binary_expr::PyBinaryExpr;
@@ -313,10 +312,13 @@ impl PyExpr {
313312
self.expr.clone().is_not_null().into()
314313
}
315314

316-
pub fn cast(&self, to: PyArrowType<DataType>) -> PyExpr {
315+
/// Builds a cast of this expression to the Arrow type wrapped by `to`.
pub fn cast(&self, to: PyDataType) -> PyExpr {
    // `Expr::cast_to()` is not used because it needs a DFSchema to validate
    // that the cast is supported; that validation is omitted for now.
    let operand = Box::new(self.expr.clone());
    let target_type = to.data_type;
    Expr::Cast(Cast::new(operand, target_type)).into()
}
322324

0 commit comments

Comments
 (0)