Skip to content

Commit 3f2204b

Browse files
committed
test: add tests for describe method output to PyList and Polars DataFrame
1 parent 632f614 commit 3f2204b

3 files changed

Lines changed: 40 additions & 10 deletions

File tree

python/tests/test_dataframe.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1902,6 +1902,21 @@ def test_describe_from_pydict(ctx):
19021902
}
19031903

19041904

1905+
def test_describe_to_pylist(df):
1906+
pylist = df.describe().to_pylist()
1907+
assert isinstance(pylist, list)
1908+
assert len(pylist) == 7
1909+
assert pylist[0]["describe"] == "count"
1910+
1911+
1912+
def test_describe_to_polars(df):
1913+
pl = pytest.importorskip("polars")
1914+
polars_df = df.describe().to_polars()
1915+
assert isinstance(polars_df, pl.DataFrame)
1916+
assert polars_df.shape == (7, 4)
1917+
assert set(polars_df.columns) == {"describe", "a", "b", "c"}
1918+
1919+
19051920
def test_describe_mixed_numeric_string():
19061921
ctx = SessionContext()
19071922
batch = pa.RecordBatch.from_arrays(

src/dataframe.rs

Lines changed: 8 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,7 @@ use crate::catalog::PyTable;
4747
use crate::errors::{py_datafusion_err, PyDataFusionError};
4848
use crate::expr::sort_expr::to_sort_expressions;
4949
use crate::physical_plan::PyExecutionPlan;
50-
use crate::record_batch::{poll_next_batch, PyRecordBatchStream};
50+
use crate::record_batch::{poll_next_batch, record_batches_to_pyarrow, PyRecordBatchStream};
5151
use crate::sql::logical::PyLogicalPlan;
5252
use crate::utils::{
5353
get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, spawn_future, validate_pycapsule,
@@ -323,11 +323,9 @@ impl PyDataFrame {
323323

324324
let table_uuid = uuid::Uuid::new_v4().to_string();
325325

326-
// Convert record batches to PyObject list
327-
let py_batches = batches
328-
.iter()
329-
.map(|rb| rb.to_pyarrow(py))
330-
.collect::<PyResult<Vec<PyObject>>>()?;
326+
// Convert record batches to PyArrow objects, keeping the underlying
327+
// Rust data alive for the lifetime of the Python values.
328+
let py_batches = record_batches_to_pyarrow(batches.clone(), py)?;
331329

332330
let py_schema = self.schema().into_pyobject(py)?;
333331

@@ -581,9 +579,9 @@ impl PyDataFrame {
581579
fn collect(&self, py: Python) -> PyResult<Vec<PyObject>> {
582580
let batches = wait_for_future(py, self.df.as_ref().clone().collect())?
583581
.map_err(PyDataFusionError::from)?;
584-
// cannot use PyResult<Vec<RecordBatch>> return type due to
585-
// https://github.com/PyO3/pyo3/issues/1813
586-
batches.into_iter().map(|rb| rb.to_pyarrow(py)).collect()
582+
// Convert to PyArrow, ensuring the Rust-backed memory outlives
583+
// the Python objects.
584+
record_batches_to_pyarrow(batches, py)
587585
}
588586

589587
/// Cache DataFrame.
@@ -600,7 +598,7 @@ impl PyDataFrame {
600598

601599
batches
602600
.into_iter()
603-
.map(|rbs| rbs.into_iter().map(|rb| rb.to_pyarrow(py)).collect())
601+
.map(|rbs| record_batches_to_pyarrow(rbs, py))
604602
.collect()
605603
}
606604

src/record_batch.rs

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,11 @@ pub struct PyRecordBatch {
3636
#[pymethods]
3737
impl PyRecordBatch {
3838
fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
39+
// Use `ToPyArrow` to create a Python `RecordBatch` that
40+
// retains references to the underlying Rust data via the
41+
// Arrow C Data Interface. The returned object holds `Arc`
42+
// references so the Rust `RecordBatch` can be safely dropped
43+
// once the Python object exists.
3944
self.batch.to_pyarrow(py)
4045
}
4146
}
@@ -46,6 +51,18 @@ impl From<RecordBatch> for PyRecordBatch {
4651
}
4752
}
4853

54+
/// Convert a collection of [`RecordBatch`]es to Python objects while ensuring
55+
/// the Rust-backed memory remains alive for the lifetime of the Python
56+
/// objects.
57+
pub(crate) fn record_batches_to_pyarrow(
58+
batches: Vec<RecordBatch>,
59+
py: Python,
60+
) -> PyResult<Vec<PyObject>> {
61+
// Each call to `to_pyarrow` clones `Arc` pointers to the underlying arrays,
62+
// tying their lifetime to the resulting Python objects.
63+
batches.into_iter().map(|rb| rb.to_pyarrow(py)).collect()
64+
}
65+
4966
#[pyclass(name = "RecordBatchStream", module = "datafusion", subclass)]
5067
pub struct PyRecordBatchStream {
5168
stream: Arc<Mutex<SendableRecordBatchStream>>,

0 commit comments

Comments
 (0)