Skip to content

Commit e621b64

Browse files
committed
UNPICK implement Arrow streaming
This reverts commit a5efa67.
1 parent a5efa67 commit e621b64

3 files changed

Lines changed: 84 additions & 18 deletions

File tree

docs/source/user-guide/io/arrow.rst

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,10 @@ Exporting from DataFusion
5959

6060
DataFusion DataFrames implement the ``__arrow_c_stream__`` PyCapsule interface, so any
6161
Python library that accepts these can import a DataFusion DataFrame directly.
62-
63-
.. warning::
64-
It is important to note that this will cause the DataFrame execution to happen, which may be
65-
a time consuming task. That is, you will cause a
66-
:py:func:`datafusion.dataframe.DataFrame.collect` operation call to occur.
62+
The exported stream yields record batches lazily using DataFusion's
63+
``execute_stream`` mechanism, allowing consumers to process results incrementally
64+
without buffering the entire dataset in memory. This streaming behavior helps
65+
avoid out-of-memory failures when working with large queries.
6766

6867

6968
.. ipython:: python

python/tests/test_dataframe.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import re
2121
import threading
2222
import time
23+
import tracemalloc
2324
from typing import Any
2425

2526
import pyarrow as pa
@@ -1567,6 +1568,23 @@ async def test_execute_stream_partitioned_async(df):
15671568
assert not remaining_batches
15681569

15691570

1571+
def test_arrow_c_stream_streaming(large_df):
    """Verify __arrow_c_stream__ yields batches lazily with bounded memory.

    A truly streaming export should produce multiple record batches without
    materializing the full result set, so peak traced memory while draining
    the reader must stay well below the dataset size.
    """
    # Repartition so the plan produces more than one output batch.
    df = large_df.repartition(4)

    # Consume the DataFrame through the public Arrow PyCapsule consumer.
    # RecordBatchReader.from_stream() invokes __arrow_c_stream__() itself,
    # replacing the previous ctypes PyCapsule_GetPointer dance, which relied
    # on the private _import_from_c and mutated the global
    # ctypes.pythonapi function signatures as a side effect.
    reader = pa.RecordBatchReader.from_stream(df)

    # Track allocations only while the stream is being drained.
    tracemalloc.start()
    batch_count = sum(1 for _ in reader)
    _current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()

    # More than one batch proves incremental delivery; the memory bound
    # proves batches were not buffered all at once.
    assert batch_count > 1
    assert peak < 50 * MB
15701588
def test_empty_to_arrow_table(df):
15711589
# Convert empty datafusion dataframe to pyarrow Table
15721590
pyarrow_table = df.limit(0).to_arrow_table()

src/dataframe.rs

Lines changed: 62 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ use std::collections::HashMap;
1919
use std::ffi::CString;
2020
use std::sync::Arc;
2121

22-
use arrow::array::{new_null_array, RecordBatch, RecordBatchIterator, RecordBatchReader};
22+
use arrow::array::{new_null_array, RecordBatch, RecordBatchReader};
2323
use arrow::compute::can_cast_types;
2424
use arrow::error::ArrowError;
2525
use arrow::ffi::FFI_ArrowSchema;
2626
use arrow::ffi_stream::FFI_ArrowArrayStream;
2727
use arrow::pyarrow::FromPyArrow;
28-
use datafusion::arrow::datatypes::Schema;
28+
use datafusion::arrow::datatypes::{Schema, SchemaRef};
2929
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
3030
use datafusion::arrow::util::pretty;
3131
use datafusion::common::UnnestOptions;
@@ -879,8 +879,17 @@ impl PyDataFrame {
879879
py: Python<'py>,
880880
requested_schema: Option<Bound<'py, PyCapsule>>,
881881
) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
882-
let mut batches = wait_for_future(py, self.df.as_ref().clone().collect())??;
883-
let mut schema: Schema = self.df.schema().to_owned().into();
882+
// execute query lazily using a stream
883+
let rt = &get_tokio_runtime().0;
884+
let df = self.df.as_ref().clone();
885+
let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
886+
rt.spawn(async move { df.execute_stream().await });
887+
let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
888+
889+
// Determine the schema and handle optional projection
890+
let stream_schema = stream.schema();
891+
let mut schema: Schema = stream_schema.as_ref().to_owned().into();
892+
let mut project = false;
884893

885894
if let Some(schema_capsule) = requested_schema {
886895
validate_pycapsule(&schema_capsule, "arrow_schema")?;
@@ -889,17 +898,12 @@ impl PyDataFrame {
889898
let desired_schema = Schema::try_from(schema_ptr)?;
890899

891900
schema = project_schema(schema, desired_schema)?;
892-
893-
batches = batches
894-
.into_iter()
895-
.map(|record_batch| record_batch_into_schema(record_batch, &schema))
896-
.collect::<Result<Vec<RecordBatch>, ArrowError>>()?;
901+
project = schema != *stream_schema.as_ref();
897902
}
898903

899-
let batches_wrapped = batches.into_iter().map(Ok);
900-
901-
let reader = RecordBatchIterator::new(batches_wrapped, Arc::new(schema));
902-
let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
904+
let schema_ref: SchemaRef = Arc::new(schema);
905+
let reader: Box<dyn RecordBatchReader + Send> =
906+
Box::new(ArrowStreamReader::new(stream, schema_ref, project));
903907

904908
let ffi_stream = FFI_ArrowArrayStream::new(reader);
905909
let stream_capsule_name = CString::new("arrow_array_stream").unwrap();
@@ -994,6 +998,51 @@ impl PyDataFrame {
994998
}
995999
}
9961000

1001+
struct ArrowStreamReader {
1002+
stream: SendableRecordBatchStream,
1003+
schema: SchemaRef,
1004+
project: bool,
1005+
}
1006+
1007+
impl ArrowStreamReader {
1008+
fn new(stream: SendableRecordBatchStream, schema: SchemaRef, project: bool) -> Self {
1009+
Self {
1010+
stream,
1011+
schema,
1012+
project,
1013+
}
1014+
}
1015+
}
1016+
1017+
impl RecordBatchReader for ArrowStreamReader {
1018+
fn schema(&self) -> SchemaRef {
1019+
self.schema.clone()
1020+
}
1021+
}
1022+
1023+
impl Iterator for ArrowStreamReader {
1024+
type Item = Result<RecordBatch, ArrowError>;
1025+
1026+
fn next(&mut self) -> Option<Self::Item> {
1027+
let rt = &get_tokio_runtime().0;
1028+
match rt.block_on(self.stream.next()) {
1029+
Some(Ok(batch)) => {
1030+
let batch = if self.project {
1031+
match record_batch_into_schema(batch, self.schema.as_ref()) {
1032+
Ok(b) => b,
1033+
Err(e) => return Some(Err(e)),
1034+
}
1035+
} else {
1036+
batch
1037+
};
1038+
Some(Ok(batch))
1039+
}
1040+
Some(Err(e)) => Some(Err(ArrowError::from(e))),
1041+
None => None,
1042+
}
1043+
}
1044+
}
1045+
9971046
/// Print DataFrame
9981047
fn print_dataframe(py: Python, df: DataFrame) -> PyDataFusionResult<()> {
9991048
// Get string representation of record batches

0 commit comments

Comments
 (0)