Commit 34a7e51

feat: enhance RecordBatch and DataFrame methods for improved PyArrow compatibility
1 parent 1171d44 · commit 34a7e51

3 files changed: 74 additions & 37 deletions

python/datafusion/dataframe.py

Lines changed: 34 additions & 22 deletions
@@ -332,8 +332,8 @@ def _repr_html_(self) -> str:
 
     @staticmethod
     def default_str_repr(
-        batches: list[pa.RecordBatch],
-        schema: pa.Schema,
+        batches: list[RecordBatch],
+        schema: "pa.Schema",
         has_more: bool,
         table_uuid: str | None = None,
     ) -> str:
@@ -342,7 +342,13 @@ def default_str_repr(
         This method is used by the default formatter and implemented in Rust for
         performance reasons.
         """
-        return DataFrameInternal.default_str_repr(batches, schema, has_more, table_uuid)
+        import pyarrow as pa
+
+        py_batches = [b.to_pyarrow() for b in batches]
+        schema = pa.schema(schema)
+        return DataFrameInternal.default_str_repr(
+            py_batches, schema, has_more, table_uuid
+        )
 
     def describe(self) -> DataFrame:
         """Return the statistics for this DataFrame.
@@ -589,17 +595,17 @@ def tail(self, n: int = 5) -> DataFrame:
         """
         return DataFrame(self.df.limit(n, max(0, self.count() - n)))
 
-    def collect(self) -> list[pa.RecordBatch]:
+    def collect(self) -> list[RecordBatch]:
         """Execute this :py:class:`DataFrame` and collect results into memory.
 
-        Prior to calling ``collect``, modifying a DataFrme simply updates a plan
+        Prior to calling ``collect``, modifying a DataFrame simply updates a plan
         (no actual computation is performed). Calling ``collect`` triggers the
         computation.
 
         Returns:
-            List of :py:class:`pyarrow.RecordBatch` collected from the DataFrame.
+            List of :py:class:`RecordBatch` collected from the DataFrame.
         """
-        return self.df.collect()
+        return [RecordBatch(rb) for rb in self.df.collect()]
 
     def cache(self) -> DataFrame:
         """Cache the DataFrame as a memory table.
@@ -609,17 +615,19 @@ def cache(self) -> DataFrame:
         """
         return DataFrame(self.df.cache())
 
-    def collect_partitioned(self) -> list[list[pa.RecordBatch]]:
+    def collect_partitioned(self) -> list[list[RecordBatch]]:
         """Execute this DataFrame and collect all partitioned results.
 
-        This operation returns :py:class:`pyarrow.RecordBatch` maintaining the input
+        This operation returns :py:class:`RecordBatch` maintaining the input
         partitioning.
 
         Returns:
             List of list of :py:class:`RecordBatch` collected from the
             DataFrame.
         """
-        return self.df.collect_partitioned()
+        return [
+            [RecordBatch(rb) for rb in rbs] for rbs in self.df.collect_partitioned()
+        ]
 
     def show(self, num: int = 20) -> None:
         """Execute the DataFrame and print the result to the console.
@@ -1047,13 +1055,15 @@ def execute_stream_partitioned(self) -> list[RecordBatchStream]:
         streams = self.df.execute_stream_partitioned()
         return [RecordBatchStream(rbs) for rbs in streams]
 
-    def to_pandas(self) -> pd.DataFrame:
-        """Execute the :py:class:`DataFrame` and convert it into a Pandas DataFrame.
+    def to_pandas(self) -> "pd.DataFrame":
+        """Execute the :py:class:`DataFrame` and convert it into a Pandas DataFrame."""
 
-        Returns:
-            Pandas DataFrame.
-        """
-        return self.df.to_pandas()
+        import pandas as pd
+        import pyarrow as pa
+
+        batches = [rb.to_pyarrow() for rb in self.collect()]
+        table = pa.Table.from_batches(batches)
+        return table.to_pandas()
 
     def to_pylist(self) -> list[dict[str, Any]]:
         """Execute the :py:class:`DataFrame` and convert it into a list of dictionaries.
@@ -1071,13 +1081,15 @@ def to_pydict(self) -> dict[str, list[Any]]:
         """
         return self.df.to_pydict()
 
-    def to_polars(self) -> pl.DataFrame:
-        """Execute the :py:class:`DataFrame` and convert it into a Polars DataFrame.
+    def to_polars(self) -> "pl.DataFrame":
+        """Execute the :py:class:`DataFrame` and convert it into a Polars DataFrame."""
 
-        Returns:
-            Polars DataFrame.
-        """
-        return self.df.to_polars()
+        import polars as pl
+        import pyarrow as pa
+
+        batches = [rb.to_pyarrow() for rb in self.collect()]
+        table = pa.Table.from_batches(batches)
+        return pl.from_arrow(table)
 
     def count(self) -> int:
         """Return the total number of rows in this :py:class:`DataFrame`.

python/datafusion/record_batch.py

Lines changed: 31 additions & 4 deletions
@@ -23,7 +23,7 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 if TYPE_CHECKING:
     import pyarrow as pa
@@ -33,25 +33,52 @@
 
 
 class RecordBatch:
-    """This class is essentially a wrapper for :py:class:`pa.RecordBatch`."""
+    """Wrapper around project-defined ``RecordBatch`` with optional PyArrow support."""
 
     def __init__(self, record_batch: df_internal.RecordBatch) -> None:
         """This constructor is generally not called by the end user.
 
         See the :py:class:`RecordBatchStream` iterator for generating this class.
         """
         self.record_batch = record_batch
+        self._pyarrow_rb: pa.RecordBatch | None = None
 
     def to_pyarrow(self) -> pa.RecordBatch:
-        """Convert to :py:class:`pa.RecordBatch`."""
-        return self.record_batch.to_pyarrow()
+        """Convert to :py:class:`pa.RecordBatch`.
+
+        Requires :mod:`pyarrow` to be installed.
+        """
+        if self._pyarrow_rb is None:
+            self._pyarrow_rb = self.record_batch.to_pyarrow()
+        return self._pyarrow_rb
 
     def __arrow_c_array__(
         self, requested_schema: object | None = None
     ) -> tuple[object, object]:
         """Arrow C Data Interface export."""
         return self.record_batch.__arrow_c_array__(requested_schema)
 
+    # ------------------------------------------------------------------
+    # PyArrow compatibility helpers
+    # ------------------------------------------------------------------
+    def __getattr__(self, item: str) -> Any:  # pragma: no cover - simple delegation
+        """Delegate attribute access to the PyArrow ``RecordBatch``."""
+        return getattr(self.to_pyarrow(), item)
+
+    def __getitem__(self, key) -> Any:  # pragma: no cover - simple delegation
+        """Delegate item access to the PyArrow ``RecordBatch``."""
+        return self.to_pyarrow()[key]
+
+    def __len__(self) -> int:  # pragma: no cover - simple delegation
+        """Delegate ``len`` to the PyArrow ``RecordBatch``."""
+        return len(self.to_pyarrow())
+
+    def __eq__(self, other: object) -> bool:  # pragma: no cover - simple delegation
+        """Compare equality using the underlying PyArrow ``RecordBatch``."""
+        if not isinstance(other, RecordBatch):
+            return NotImplemented
+        return self.to_pyarrow().equals(other.to_pyarrow())
+
 
 class RecordBatchStream:
     """This class represents a stream of record batches.

src/dataframe.rs

Lines changed: 9 additions & 11 deletions
@@ -47,7 +47,7 @@ use crate::catalog::PyTable;
 use crate::errors::{py_datafusion_err, PyDataFusionError};
 use crate::expr::sort_expr::to_sort_expressions;
 use crate::physical_plan::PyExecutionPlan;
-use crate::record_batch::{poll_next_batch, PyRecordBatchStream};
+use crate::record_batch::{poll_next_batch, PyRecordBatch, PyRecordBatchStream};
 use crate::sql::logical::PyLogicalPlan;
 use crate::utils::{
     get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, spawn_future, validate_pycapsule,
@@ -582,12 +582,10 @@ impl PyDataFrame {
     /// Executes the plan, returning a list of `RecordBatch`es.
     /// Unless some order is specified in the plan, there is no
     /// guarantee of the order of the result.
-    fn collect(&self, py: Python) -> PyResult<Vec<PyObject>> {
+    fn collect(&self, py: Python) -> PyDataFusionResult<Vec<PyRecordBatch>> {
         let batches = wait_for_future(py, self.df.as_ref().clone().collect())?
             .map_err(PyDataFusionError::from)?;
-        // cannot use PyResult<Vec<RecordBatch>> return type due to
-        // https://github.com/PyO3/pyo3/issues/1813
-        batches.into_iter().map(|rb| rb.to_pyarrow(py)).collect()
+        Ok(batches.into_iter().map(PyRecordBatch::from).collect())
     }
 
     /// Cache DataFrame.
@@ -596,16 +594,16 @@ impl PyDataFrame {
         Ok(Self::new(df))
     }
 
-    /// Executes this DataFrame and collects all results into a vector of vector of RecordBatch
-    /// maintaining the input partitioning.
-    fn collect_partitioned(&self, py: Python) -> PyResult<Vec<Vec<PyObject>>> {
+    /// Executes this DataFrame and collects all results into a vector of vector of
+    /// `RecordBatch`, maintaining the input partitioning.
+    fn collect_partitioned(&self, py: Python) -> PyDataFusionResult<Vec<Vec<PyRecordBatch>>> {
         let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())?
             .map_err(PyDataFusionError::from)?;
 
-        batches
+        Ok(batches
             .into_iter()
-            .map(|rbs| rbs.into_iter().map(|rb| rb.to_pyarrow(py)).collect())
-            .collect()
+            .map(|rbs| rbs.into_iter().map(PyRecordBatch::from).collect())
+            .collect())
     }
 
     /// Print the result, 20 lines by default
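On the Rust side, collect and collect_partitioned now return PyRecordBatch values directly via From<RecordBatch> (the old workaround comment for PyO3/pyo3#1813 is gone) instead of eagerly converting every batch with to_pyarrow(py), and both switch to the crate's PyDataFusionResult error type. From Python, the observable shape is (a sketch, reusing df from the earlier examples):

    parts = df.collect_partitioned()  # list[list[RecordBatch]], input partitioning kept
    total_rows = sum(len(batch) for part in parts for batch in part)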
