Skip to content

Commit c21234c

Browse files
committed
Refactor DataFrame stream handling and add DataFrameStreamReader for synchronous iteration
1 parent 9ad8d2e commit c21234c

1 file changed

Lines changed: 60 additions & 17 deletions

File tree

src/dataframe.rs

Lines changed: 60 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -19,13 +19,13 @@ use std::collections::HashMap;
1919
use std::ffi::CString;
2020
use std::sync::Arc;
2121

22-
use arrow::array::{new_null_array, RecordBatch, RecordBatchIterator, RecordBatchReader};
22+
use arrow::array::{new_null_array, RecordBatch, RecordBatchReader};
2323
use arrow::compute::can_cast_types;
2424
use arrow::error::ArrowError;
2525
use arrow::ffi::FFI_ArrowSchema;
2626
use arrow::ffi_stream::FFI_ArrowArrayStream;
2727
use arrow::pyarrow::FromPyArrow;
28-
use datafusion::arrow::datatypes::Schema;
28+
use datafusion::arrow::datatypes::{Schema, SchemaRef};
2929
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
3030
use datafusion::arrow::util::pretty;
3131
use datafusion::common::UnnestOptions;
@@ -42,7 +42,7 @@ use pyo3::exceptions::PyValueError;
4242
use pyo3::prelude::*;
4343
use pyo3::pybacked::PyBackedStr;
4444
use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
45-
use tokio::task::JoinHandle;
45+
use tokio::{runtime::Handle, task::JoinHandle};
4646

4747
use crate::catalog::PyTable;
4848
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError};
@@ -51,7 +51,8 @@ use crate::physical_plan::PyExecutionPlan;
5151
use crate::record_batch::PyRecordBatchStream;
5252
use crate::sql::logical::PyLogicalPlan;
5353
use crate::utils::{
54-
get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, validate_pycapsule, wait_for_future,
54+
get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, spawn_stream, validate_pycapsule,
55+
wait_for_future,
5556
};
5657
use crate::{
5758
errors::PyDataFusionResult,
@@ -353,6 +354,47 @@ impl PyDataFrame {
353354
Ok(html_str)
354355
}
355356
}
357+
/// Synchronous adapter over a [`SendableRecordBatchStream`], used for
358+
/// the `__arrow_c_stream__` implementation.
359+
///
360+
/// Each call to `next` drives the async stream with `runtime.block_on`, so
361+
/// one stream item is pulled per call. When `projection` is `Some`, every
362+
/// batch is passed through `record_batch_into_schema` before it is yielded;
363+
/// a failed conversion is yielded as an `Err` item.
364+
struct DataFrameStreamReader {
365+
stream: SendableRecordBatchStream, // async source of record batches
366+
runtime: Handle, // Tokio runtime handle used to block on `stream`
367+
schema: SchemaRef, // schema returned by `RecordBatchReader::schema`
368+
projection: Option<SchemaRef>, // when set, schema each batch is converted into
369+
}
370+
371+
// Pulls one item from the wrapped stream per `next` call by blocking on it
// with the stored runtime handle.
// NOTE(review): tokio documents `Handle::block_on` as panicking when called
// from within an asynchronous execution context — confirm that consumers only
// call `next` from threads that are not driving the runtime.
// NOTE(review): `self.stream.next()` presumably relies on a `StreamExt` import
// that is outside this hunk — verify.
impl Iterator for DataFrameStreamReader {
372+
type Item = Result<RecordBatch, ArrowError>;
373+
374+
fn next(&mut self) -> Option<Self::Item> {
375+
match self.runtime.block_on(self.stream.next()) { // block until the stream produces its next item
376+
Some(Ok(batch)) => {
377+
let batch = if let Some(ref schema) = self.projection { // a projection is set: convert the batch to it
378+
match record_batch_into_schema(batch, schema.as_ref()) {
379+
Ok(b) => b,
380+
Err(e) => return Some(Err(e)), // yield the conversion error as this item
381+
}
382+
} else {
383+
batch // no projection: pass the batch through unchanged
384+
};
385+
Some(Ok(batch))
386+
}
387+
Some(Err(e)) => Some(Err(ArrowError::ExternalError(Box::new(e)))), // box the stream's error into an ArrowError
388+
None => None, // stream is exhausted
389+
}
390+
}
391+
}
392+
393+
// Lets the reader be boxed as a `dyn RecordBatchReader`; the batches themselves
// come from the `Iterator` implementation.
impl RecordBatchReader for DataFrameStreamReader {
394+
fn schema(&self) -> SchemaRef {
395+
self.schema.clone() // returns a clone of the `schema` field stored at construction
396+
}
397+
}
356398

357399
#[pymethods]
358400
impl PyDataFrame {
@@ -879,8 +921,12 @@ impl PyDataFrame {
879921
py: Python<'py>,
880922
requested_schema: Option<Bound<'py, PyCapsule>>,
881923
) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
882-
let mut batches = wait_for_future(py, self.df.as_ref().clone().collect())??;
924+
let rt = &get_tokio_runtime().0;
925+
let df = self.df.as_ref().clone();
926+
let stream = spawn_stream(py, async move { df.execute_stream().await })?;
927+
883928
let mut schema: Schema = self.df.schema().to_owned().into();
929+
let mut projection: Option<SchemaRef> = None;
884930

885931
if let Some(schema_capsule) = requested_schema {
886932
validate_pycapsule(&schema_capsule, "arrow_schema")?;
@@ -889,16 +935,17 @@ impl PyDataFrame {
889935
let desired_schema = Schema::try_from(schema_ptr)?;
890936

891937
schema = project_schema(schema, desired_schema)?;
892-
893-
batches = batches
894-
.into_iter()
895-
.map(|record_batch| record_batch_into_schema(record_batch, &schema))
896-
.collect::<Result<Vec<RecordBatch>, ArrowError>>()?;
938+
projection = Some(Arc::new(schema.clone()));
897939
}
898940

899-
let batches_wrapped = batches.into_iter().map(Ok);
941+
let schema_ref = projection.clone().unwrap_or_else(|| Arc::new(schema));
900942

901-
let reader = RecordBatchIterator::new(batches_wrapped, Arc::new(schema));
943+
let reader = DataFrameStreamReader {
944+
stream,
945+
runtime: rt.handle().clone(),
946+
schema: schema_ref,
947+
projection,
948+
};
902949
let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
903950

904951
let ffi_stream = FFI_ArrowArrayStream::new(reader);
@@ -907,12 +954,8 @@ impl PyDataFrame {
907954
}
908955

909956
/// Starts execution of a clone of this DataFrame and returns the resulting
/// record-batch stream wrapped in a `PyRecordBatchStream`.
fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {
910-
// create a Tokio runtime to run the async code
911-
let rt = &get_tokio_runtime().0;
912957
let df = self.df.as_ref().clone();
913-
let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
914-
rt.spawn(async move { df.execute_stream().await });
915-
let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
958+
let stream = spawn_stream(py, async move { df.execute_stream().await })?; // the cloned `df` is moved into the future handed to `spawn_stream`
916959
Ok(PyRecordBatchStream::new(stream))
917960
}
918961

0 commit comments

Comments
 (0)