@@ -19,13 +19,13 @@ use std::collections::HashMap;
1919use std:: ffi:: CString ;
2020use std:: sync:: Arc ;
2121
22- use arrow:: array:: { new_null_array, RecordBatch , RecordBatchReader } ;
22+ use arrow:: array:: { new_null_array, RecordBatch , RecordBatchIterator , RecordBatchReader } ;
2323use arrow:: compute:: can_cast_types;
2424use arrow:: error:: ArrowError ;
2525use arrow:: ffi:: FFI_ArrowSchema ;
2626use arrow:: ffi_stream:: FFI_ArrowArrayStream ;
2727use arrow:: pyarrow:: FromPyArrow ;
28- use datafusion:: arrow:: datatypes:: { Schema , SchemaRef } ;
28+ use datafusion:: arrow:: datatypes:: Schema ;
2929use datafusion:: arrow:: pyarrow:: { PyArrowType , ToPyArrow } ;
3030use datafusion:: arrow:: util:: pretty;
3131use datafusion:: common:: UnnestOptions ;
@@ -42,7 +42,7 @@ use pyo3::exceptions::PyValueError;
4242use pyo3:: prelude:: * ;
4343use pyo3:: pybacked:: PyBackedStr ;
4444use pyo3:: types:: { PyCapsule , PyList , PyTuple , PyTupleMethods } ;
45- use tokio:: { runtime :: Handle , task:: JoinHandle } ;
45+ use tokio:: task:: JoinHandle ;
4646
4747use crate :: catalog:: PyTable ;
4848use crate :: errors:: { py_datafusion_err, to_datafusion_err, PyDataFusionError } ;
@@ -51,8 +51,7 @@ use crate::physical_plan::PyExecutionPlan;
5151use crate :: record_batch:: PyRecordBatchStream ;
5252use crate :: sql:: logical:: PyLogicalPlan ;
5353use crate :: utils:: {
54- get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, spawn_stream, validate_pycapsule,
55- wait_for_future,
54+ get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, validate_pycapsule, wait_for_future,
5655} ;
5756use crate :: {
5857 errors:: PyDataFusionResult ,
@@ -354,47 +353,6 @@ impl PyDataFrame {
354353 Ok ( html_str)
355354 }
356355}
357- /// Synchronous wrapper around a [`SendableRecordBatchStream`] used for
358- /// the `__arrow_c_stream__` implementation.
359- ///
360- /// It uses `runtime.block_on` to consume the underlying async stream,
361- /// providing synchronous iteration. When a `projection` is set, each
362- /// batch is converted via `record_batch_into_schema` to apply schema
363- /// changes per batch.
364- struct DataFrameStreamReader {
365- stream : SendableRecordBatchStream ,
366- runtime : Handle ,
367- schema : SchemaRef ,
368- projection : Option < SchemaRef > ,
369- }
370-
371- impl Iterator for DataFrameStreamReader {
372- type Item = Result < RecordBatch , ArrowError > ;
373-
374- fn next ( & mut self ) -> Option < Self :: Item > {
375- match self . runtime . block_on ( self . stream . next ( ) ) {
376- Some ( Ok ( batch) ) => {
377- let batch = if let Some ( ref schema) = self . projection {
378- match record_batch_into_schema ( batch, schema. as_ref ( ) ) {
379- Ok ( b) => b,
380- Err ( e) => return Some ( Err ( e) ) ,
381- }
382- } else {
383- batch
384- } ;
385- Some ( Ok ( batch) )
386- }
387- Some ( Err ( e) ) => Some ( Err ( ArrowError :: ExternalError ( Box :: new ( e) ) ) ) ,
388- None => None ,
389- }
390- }
391- }
392-
393- impl RecordBatchReader for DataFrameStreamReader {
394- fn schema ( & self ) -> SchemaRef {
395- self . schema . clone ( )
396- }
397- }
398356
399357#[ pymethods]
400358impl PyDataFrame {
@@ -921,12 +879,8 @@ impl PyDataFrame {
921879 py : Python < ' py > ,
922880 requested_schema : Option < Bound < ' py , PyCapsule > > ,
923881 ) -> PyDataFusionResult < Bound < ' py , PyCapsule > > {
924- let rt = & get_tokio_runtime ( ) . 0 ;
925- let df = self . df . as_ref ( ) . clone ( ) ;
926- let stream = spawn_stream ( py, async move { df. execute_stream ( ) . await } ) ?;
927-
882+ let mut batches = wait_for_future ( py, self . df . as_ref ( ) . clone ( ) . collect ( ) ) ??;
928883 let mut schema: Schema = self . df . schema ( ) . to_owned ( ) . into ( ) ;
929- let mut projection: Option < SchemaRef > = None ;
930884
931885 if let Some ( schema_capsule) = requested_schema {
932886 validate_pycapsule ( & schema_capsule, "arrow_schema" ) ?;
@@ -935,17 +889,16 @@ impl PyDataFrame {
935889 let desired_schema = Schema :: try_from ( schema_ptr) ?;
936890
937891 schema = project_schema ( schema, desired_schema) ?;
938- projection = Some ( Arc :: new ( schema. clone ( ) ) ) ;
892+
893+ batches = batches
894+ . into_iter ( )
895+ . map ( |record_batch| record_batch_into_schema ( record_batch, & schema) )
896+ . collect :: < Result < Vec < RecordBatch > , ArrowError > > ( ) ?;
939897 }
940898
941- let schema_ref = projection . clone ( ) . unwrap_or_else ( || Arc :: new ( schema ) ) ;
899+ let batches_wrapped = batches . into_iter ( ) . map ( Ok ) ;
942900
943- let reader = DataFrameStreamReader {
944- stream,
945- runtime : rt. handle ( ) . clone ( ) ,
946- schema : schema_ref,
947- projection,
948- } ;
901+ let reader = RecordBatchIterator :: new ( batches_wrapped, Arc :: new ( schema) ) ;
949902 let reader: Box < dyn RecordBatchReader + Send > = Box :: new ( reader) ;
950903
951904 let ffi_stream = FFI_ArrowArrayStream :: new ( reader) ;
@@ -954,8 +907,12 @@ impl PyDataFrame {
954907 }
955908
956909 fn execute_stream ( & self , py : Python ) -> PyDataFusionResult < PyRecordBatchStream > {
910+ // create a Tokio runtime to run the async code
911+ let rt = & get_tokio_runtime ( ) . 0 ;
957912 let df = self . df . as_ref ( ) . clone ( ) ;
958- let stream = spawn_stream ( py, async move { df. execute_stream ( ) . await } ) ?;
913+ let fut: JoinHandle < datafusion:: common:: Result < SendableRecordBatchStream > > =
914+ rt. spawn ( async move { df. execute_stream ( ) . await } ) ;
915+ let stream = wait_for_future ( py, async { fut. await . map_err ( to_datafusion_err) } ) ???;
959916 Ok ( PyRecordBatchStream :: new ( stream) )
960917 }
961918
0 commit comments