Commit c5cd149

Enhance Arrow C stream export to support partitioned reading, reducing memory usage for large result sets

1 parent d05a410

4 files changed: 81 additions & 31 deletions

docs/source/user-guide/io/arrow.rst

Lines changed: 3 additions & 0 deletions

@@ -79,6 +79,9 @@ output incrementally:
     for batch in reader:
         ...  # process each batch without buffering the entire table
 
+DataFusion reads one partition at a time when exporting a C stream, so large
+result sets are not buffered entirely in memory.
+
 If the goal is simply to persist results, prefer engine-level writers such as
 ``df.write_parquet()``. These writers stream data from Rust directly to the
 destination and avoid Python-side memory growth.
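For reference, a minimal consumer of this export could look like the sketch below. It assumes pyarrow >= 14, where pa.RecordBatchReader.from_stream accepts any object implementing __arrow_c_stream__, and a hypothetical table name large_table registered on the context:

    import pyarrow as pa

    from datafusion import SessionContext

    ctx = SessionContext()
    df = ctx.sql("SELECT * FROM large_table")  # hypothetical registered table

    # from_stream drives __arrow_c_stream__ under the hood; with partitioned
    # export, batches arrive one partition at a time instead of all at once.
    reader = pa.RecordBatchReader.from_stream(df)
    for batch in reader:
        ...  # process each pyarrow.RecordBatch incrementally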

python/datafusion/dataframe.py

Lines changed: 6 additions & 4 deletions

@@ -1100,10 +1100,12 @@ def unnest_columns(self, *columns: str, preserve_nulls: bool = True) -> DataFram
     def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
         """Export an Arrow PyCapsule Stream.
 
-        This will execute and collect the DataFrame. We will attempt to respect the
-        requested schema, but only trivial transformations will be applied such as only
-        returning the fields listed in the requested schema if their data types match
-        those in the DataFrame.
+        This executes the query lazily and returns a capsule backed by a
+        partition-aware reader. It will respect the requested schema when
+        possible, but only trivial transformations are applied such as returning
+        only the fields listed in the requested schema if their data types match
+        those in the DataFrame. Batches are yielded one partition at a time so
+        results are not buffered entirely in memory.
 
         Args:
             requested_schema: Attempt to provide the DataFrame using this schema.
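To illustrate the projection behavior the docstring describes, here is a hedged sketch: pyarrow forwards the schema argument of pa.RecordBatchReader.from_stream to __arrow_c_stream__ as requested_schema, so selecting a subset of fields whose types match is the kind of trivial transformation that should succeed. The DataFrame df with int64 columns a and b is an assumption:

    import pyarrow as pa

    # Assumed: df exposes columns "a" (int64) and "b" (int64).
    requested = pa.schema([pa.field("a", pa.int64())])

    # The schema is passed through to __arrow_c_stream__ as requested_schema;
    # since "a" exists with a matching type, batches contain only that field.
    reader = pa.RecordBatchReader.from_stream(df, schema=requested)
    for batch in reader:
        assert batch.schema.names == ["a"]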

python/tests/test_record_batch_stream.py

Lines changed: 32 additions & 0 deletions

@@ -33,3 +33,35 @@ async def test_record_batch_stream_anext(ctx):
     assert batch.to_pyarrow().num_rows == 1
     with pytest.raises(StopAsyncIteration):
         await stream.__anext__()
+
+
+def test_arrow_c_stream_partitioned(tmp_path, ctx):
+    import gc
+
+    import pyarrow as pa
+    import pyarrow.parquet as pq
+
+    num_parts = 5
+    rows_per_part = 100_000
+    arr = pa.array(range(rows_per_part), pa.int64())
+    batch = pa.RecordBatch.from_arrays([arr], names=["a"])
+    table = pa.Table.from_batches([batch])
+    for i in range(num_parts):
+        pq.write_table(table, tmp_path / f"part{i}.parquet")
+
+    del arr, batch, table
+    gc.collect()
+
+    df = ctx.read_parquet(str(tmp_path))
+    capsule = df.__arrow_c_stream__()
+    reader = pa.RecordBatchReader._import_from_c_capsule(capsule)
+
+    pool = pa.default_memory_pool()
+    baseline = pool.bytes_allocated()
+    peak = baseline
+    for b in reader:
+        peak = max(peak, pool.bytes_allocated())
+        del b
+        gc.collect()
+
+    assert peak - baseline < rows_per_part * 8 * 2
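The final bound checks the memory claim directly: each partition holds 100_000 int64 values, roughly 800 kB, so the full five-partition result is about 4 MB, while the assertion requires peak allocation growth to stay under rows_per_part * 8 * 2 = 1.6 MB, i.e. under two partitions' worth of data.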

src/dataframe.rs

Lines changed: 40 additions & 27 deletions

@@ -943,12 +943,12 @@ impl PyDataFrame {
         py: Python<'py>,
         requested_schema: Option<Bound<'py, PyCapsule>>,
     ) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
-        // execute query lazily using a stream
+        // execute query lazily using a stream per partition
         let df = self.df.as_ref().clone();
-        let stream = spawn_and_wait(py, async move { df.execute_stream().await })?;
+        let streams = spawn_and_wait(py, async move { df.execute_stream_partitioned().await })?;
 
         // Determine the schema and handle optional projection
-        let stream_schema = stream.schema();
+        let stream_schema = streams[0].schema();
         let mut schema: Schema = stream_schema.as_ref().to_owned();
         let mut project = false;
 
@@ -963,8 +963,9 @@ impl PyDataFrame {
         }
 
         let schema_ref: SchemaRef = Arc::new(schema);
-        let reader: Box<dyn RecordBatchReader + Send> =
-            Box::new(ArrowStreamReader::new(stream, schema_ref, project));
+        let reader: Box<dyn RecordBatchReader + Send> = Box::new(
+            PartitionedArrowStreamReader::new(streams, schema_ref, project),
+        );
 
         let ffi_stream = FFI_ArrowArrayStream::new(reader);
         let stream_capsule_name = CString::new("arrow_array_stream").unwrap();
@@ -1049,48 +1050,60 @@ impl PyDataFrame {
     }
 }
 
-struct ArrowStreamReader {
-    stream: SendableRecordBatchStream,
+struct PartitionedArrowStreamReader {
+    streams: Vec<SendableRecordBatchStream>,
     schema: SchemaRef,
     project: bool,
+    current: usize,
 }
 
-impl ArrowStreamReader {
-    fn new(stream: SendableRecordBatchStream, schema: SchemaRef, project: bool) -> Self {
+impl PartitionedArrowStreamReader {
+    fn new(streams: Vec<SendableRecordBatchStream>, schema: SchemaRef, project: bool) -> Self {
         Self {
-            stream,
+            streams,
             schema,
             project,
+            current: 0,
         }
     }
 }
 
-impl RecordBatchReader for ArrowStreamReader {
+impl RecordBatchReader for PartitionedArrowStreamReader {
     fn schema(&self) -> SchemaRef {
         self.schema.clone()
     }
 }
 
-impl Iterator for ArrowStreamReader {
+impl Iterator for PartitionedArrowStreamReader {
     type Item = Result<RecordBatch, ArrowError>;
 
     fn next(&mut self) -> Option<Self::Item> {
-        let result = Python::with_gil(|py| wait_for_stream_next(py, &mut self.stream));
-
-        match result {
-            Ok(Some(batch)) => {
-                let batch = if self.project {
-                    match record_batch_into_schema(batch, self.schema.as_ref()) {
-                        Ok(b) => b,
-                        Err(e) => return Some(Err(e)),
-                    }
-                } else {
-                    batch
-                };
-                Some(Ok(batch))
+        loop {
+            if self.current >= self.streams.len() {
+                return None;
+            }
+
+            let stream = &mut self.streams[self.current];
+            let result = Python::with_gil(|py| wait_for_stream_next(py, stream));
+
+            match result {
+                Ok(Some(batch)) => {
+                    let batch = if self.project {
+                        match record_batch_into_schema(batch, self.schema.as_ref()) {
+                            Ok(b) => b,
+                            Err(e) => return Some(Err(e)),
+                        }
+                    } else {
+                        batch
+                    };
+                    return Some(Ok(batch));
+                }
+                Ok(None) => {
+                    self.current += 1;
+                    continue;
+                }
+                Err(e) => return Some(Err(ArrowError::from(e))),
             }
-            Ok(None) => None,
-            Err(e) => Some(Err(ArrowError::from(e))),
         }
     }
 }
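The reader drains partitions strictly in order: each next() call polls the current partition's stream under the GIL via wait_for_stream_next, and an exhausted partition (Ok(None)) advances current and loops to the next one rather than ending the stream. Since all partitions of a DataFusion plan share the plan's schema, taking it from streams[0] is sufficient.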
