
Commit a5efa67

UNPICK revert Arrow streaming
1 parent 48fa874 commit a5efa67

3 files changed: 18 additions & 84 deletions


docs/source/user-guide/io/arrow.rst

Lines changed: 5 additions & 4 deletions
@@ -59,10 +59,11 @@ Exporting from DataFusion
 
 DataFusion DataFrames implement ``__arrow_c_stream__`` PyCapsule interface, so any
 Python library that accepts these can import a DataFusion DataFrame directly.
-The exported stream yields record batches lazily using DataFusion's
-``execute_stream`` mechanism, allowing consumers to process results incrementally
-without buffering the entire dataset in memory. This streaming behavior helps
-avoid out-of-memory failures when working with large queries.
+
+.. warning::
+    It is important to note that this will cause the DataFrame execution to happen, which may be
+    a time consuming task. That is, you will cause a
+    :py:func:`datafusion.dataframe.DataFrame.collect` operation call to occur.
 
 
 .. ipython:: python
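
For context, a minimal sketch of what the new warning means for consumers. The table name `t` and the data are illustrative, not part of the commit, and `pa.RecordBatchReader.from_stream` assumes pyarrow >= 14:

import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()
# Illustrative data; any registered table or SQL result behaves the same.
df = ctx.from_pydict({"a": [1, 2, 3]}, name="t")

# from_stream() calls df.__arrow_c_stream__() under the hood. After this
# commit, that call collects the full result set before returning the
# reader, so it may be slow and memory-hungry for large queries.
reader = pa.RecordBatchReader.from_stream(df)
for batch in reader:
    print(batch.num_rows)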

python/tests/test_dataframe.py

Lines changed: 0 additions & 18 deletions
@@ -20,7 +20,6 @@
 import re
 import threading
 import time
-import tracemalloc
 from typing import Any
 
 import pyarrow as pa
@@ -1568,23 +1567,6 @@ async def test_execute_stream_partitioned_async(df):
     assert not remaining_batches
 
 
-def test_arrow_c_stream_streaming(large_df):
-    df = large_df.repartition(4)
-    capsule = df.__arrow_c_stream__()
-    ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
-    ctypes.pythonapi.PyCapsule_GetPointer.argtypes = [ctypes.py_object, ctypes.c_char_p]
-    ptr = ctypes.pythonapi.PyCapsule_GetPointer(capsule, b"arrow_array_stream")
-    reader = pa.RecordBatchReader._import_from_c(ptr)
-
-    tracemalloc.start()
-    batch_count = sum(1 for _ in reader)
-    current, peak = tracemalloc.get_traced_memory()
-    tracemalloc.stop()
-
-    assert batch_count > 1
-    assert peak < 50 * MB
-
-
 def test_empty_to_arrow_table(df):
     # Convert empty datafusion dataframe to pyarrow Table
     pyarrow_table = df.limit(0).to_arrow_table()
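
The deleted test asserted bounded peak memory while iterating the exported stream, an invariant that no longer holds once the export collects eagerly. Incremental consumption remains available through `DataFrame.execute_stream()`; a hedged sketch with illustrative data, assuming the datafusion-python `RecordBatchStream` iterator API:

from datafusion import SessionContext

ctx = SessionContext()
# Illustrative data; a real memory test would need a much larger input.
df = ctx.from_pydict({"a": list(range(10_000))}, name="big")

# execute_stream() yields record batches one at a time instead of
# materializing the whole result, so memory stays roughly proportional
# to a single batch rather than to the full result set.
for batch in df.execute_stream():
    print(batch.to_pyarrow().num_rows)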

src/dataframe.rs

Lines changed: 13 additions & 62 deletions
@@ -19,13 +19,13 @@ use std::collections::HashMap;
 use std::ffi::CString;
 use std::sync::Arc;
 
-use arrow::array::{new_null_array, RecordBatch, RecordBatchReader};
+use arrow::array::{new_null_array, RecordBatch, RecordBatchIterator, RecordBatchReader};
 use arrow::compute::can_cast_types;
 use arrow::error::ArrowError;
 use arrow::ffi::FFI_ArrowSchema;
 use arrow::ffi_stream::FFI_ArrowArrayStream;
 use arrow::pyarrow::FromPyArrow;
-use datafusion::arrow::datatypes::{Schema, SchemaRef};
+use datafusion::arrow::datatypes::Schema;
 use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
 use datafusion::arrow::util::pretty;
 use datafusion::common::UnnestOptions;
@@ -879,17 +879,8 @@ impl PyDataFrame {
         py: Python<'py>,
         requested_schema: Option<Bound<'py, PyCapsule>>,
     ) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
-        // execute query lazily using a stream
-        let rt = &get_tokio_runtime().0;
-        let df = self.df.as_ref().clone();
-        let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
-            rt.spawn(async move { df.execute_stream().await });
-        let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
-
-        // Determine the schema and handle optional projection
-        let stream_schema = stream.schema();
-        let mut schema: Schema = stream_schema.as_ref().to_owned().into();
-        let mut project = false;
+        let mut batches = wait_for_future(py, self.df.as_ref().clone().collect())??;
+        let mut schema: Schema = self.df.schema().to_owned().into();
 
         if let Some(schema_capsule) = requested_schema {
             validate_pycapsule(&schema_capsule, "arrow_schema")?;
@@ -898,12 +889,17 @@
             let desired_schema = Schema::try_from(schema_ptr)?;
 
             schema = project_schema(schema, desired_schema)?;
-            project = schema != *stream_schema.as_ref();
+
+            batches = batches
+                .into_iter()
+                .map(|record_batch| record_batch_into_schema(record_batch, &schema))
+                .collect::<Result<Vec<RecordBatch>, ArrowError>>()?;
         }
 
-        let schema_ref: SchemaRef = Arc::new(schema);
-        let reader: Box<dyn RecordBatchReader + Send> =
-            Box::new(ArrowStreamReader::new(stream, schema_ref, project));
+        let batches_wrapped = batches.into_iter().map(Ok);
+
+        let reader = RecordBatchIterator::new(batches_wrapped, Arc::new(schema));
+        let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
 
         let ffi_stream = FFI_ArrowArrayStream::new(reader);
         let stream_capsule_name = CString::new("arrow_array_stream").unwrap();
@@ -998,51 +994,6 @@ impl PyDataFrame {
     }
 }
 
-struct ArrowStreamReader {
-    stream: SendableRecordBatchStream,
-    schema: SchemaRef,
-    project: bool,
-}
-
-impl ArrowStreamReader {
-    fn new(stream: SendableRecordBatchStream, schema: SchemaRef, project: bool) -> Self {
-        Self {
-            stream,
-            schema,
-            project,
-        }
-    }
-}
-
-impl RecordBatchReader for ArrowStreamReader {
-    fn schema(&self) -> SchemaRef {
-        self.schema.clone()
-    }
-}
-
-impl Iterator for ArrowStreamReader {
-    type Item = Result<RecordBatch, ArrowError>;
-
-    fn next(&mut self) -> Option<Self::Item> {
-        let rt = &get_tokio_runtime().0;
-        match rt.block_on(self.stream.next()) {
-            Some(Ok(batch)) => {
-                let batch = if self.project {
-                    match record_batch_into_schema(batch, self.schema.as_ref()) {
-                        Ok(b) => b,
-                        Err(e) => return Some(Err(e)),
-                    }
-                } else {
-                    batch
-                };
-                Some(Ok(batch))
-            }
-            Some(Err(e)) => Some(Err(ArrowError::from(e))),
-            None => None,
-        }
-    }
-}
-
 /// Print DataFrame
 fn print_dataframe(py: Python, df: DataFrame) -> PyDataFusionResult<()> {
     // Get string representation of record batches
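
On the consumer side, the `requested_schema` branch above is reached when a caller passes a schema through the capsule interface; after this change the projection happens once, over the already-collected batches, rather than per streamed batch. A hedged sketch under the same assumptions as above (illustrative table and column names, pyarrow >= 14; whether a given projection is accepted depends on `project_schema`):

import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()
df = ctx.from_pydict({"a": [1, 2], "b": ["x", "y"]}, name="t2")  # illustrative

# Passing schema= hands a requested-schema capsule to __arrow_c_stream__,
# which routes through project_schema / record_batch_into_schema in the
# Rust code above.
requested = pa.schema([pa.field("a", pa.int64())])
reader = pa.RecordBatchReader.from_stream(df, schema=requested)
print(reader.schema)  # expected: only column "a"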
