1818use std:: collections:: HashMap ;
1919use std:: ffi:: CString ;
2020use std:: sync:: Arc ;
21+ use tokio:: sync:: Mutex ;
2122
2223use arrow:: array:: { new_null_array, RecordBatch , RecordBatchIterator , RecordBatchReader } ;
2324use arrow:: compute:: can_cast_types;
@@ -33,8 +34,8 @@ use datafusion::config::{CsvOptions, ParquetColumnOptions, ParquetOptions, Table
3334use datafusion:: dataframe:: { DataFrame , DataFrameWriteOptions } ;
3435use datafusion:: datasource:: TableProvider ;
3536use datafusion:: error:: DataFusionError ;
36- use datafusion:: execution:: SendableRecordBatchStream ;
3737use datafusion:: execution:: session_state:: SessionState ;
38+ use datafusion:: execution:: SendableRecordBatchStream ;
3839use datafusion:: parquet:: basic:: { BrotliLevel , Compression , GzipLevel , ZstdLevel } ;
3940use datafusion:: prelude:: * ;
4041use datafusion_ffi:: table_provider:: FFI_TableProvider ;
@@ -360,7 +361,7 @@ impl PyDataFrame {
360361/// their original partition order. When a `projection` is set, each batch is
361362/// converted via `record_batch_into_schema` to apply schema changes per batch.
362363struct PartitionedDataFrameStreamReader {
363- streams : Vec < SendableRecordBatchStream > ,
364+ streams : Vec < Arc < Mutex < SendableRecordBatchStream > > > ,
364365 // Hold a reference to the session state to keep the context alive
365366 _state : Arc < SessionState > ,
366367 schema : SchemaRef ,
@@ -373,12 +374,17 @@ impl Iterator for PartitionedDataFrameStreamReader {
373374
374375 fn next ( & mut self ) -> Option < Self :: Item > {
375376 while self . current < self . streams . len ( ) {
376- let stream = & mut self . streams [ self . current ] ;
377- let fut = poll_next_batch ( stream) ;
378- let result = Python :: with_gil ( |py| wait_for_future ( py, fut) ) ;
377+ let stream = self . streams [ self . current ] . clone ( ) ;
378+
379+ let result = Python :: with_gil ( |py| {
380+ spawn_future ( py, async move {
381+ let mut s = stream. lock ( ) . await ;
382+ poll_next_batch ( & mut s) . await
383+ } )
384+ } ) ;
379385
380386 match result {
381- Ok ( Ok ( Some ( batch) ) ) => {
387+ Ok ( Some ( batch) ) => {
382388 let batch = if let Some ( ref schema) = self . projection {
383389 match record_batch_into_schema ( batch, schema. as_ref ( ) ) {
384390 Ok ( b) => b,
@@ -389,13 +395,10 @@ impl Iterator for PartitionedDataFrameStreamReader {
389395 } ;
390396 return Some ( Ok ( batch) ) ;
391397 }
392- Ok ( Ok ( None ) ) => {
398+ Ok ( None ) => {
393399 self . current += 1 ;
394400 continue ;
395401 }
396- Ok ( Err ( e) ) => {
397- return Some ( Err ( ArrowError :: ExternalError ( Box :: new ( e) ) ) ) ;
398- }
399402 Err ( e) => {
400403 return Some ( Err ( ArrowError :: ExternalError ( Box :: new ( e) ) ) ) ;
401404 }
@@ -956,6 +959,10 @@ impl PyDataFrame {
956959 let df = self . df . as_ref ( ) . clone ( ) ;
957960 let state = df. session_state ( ) . clone ( ) ;
958961 let streams = spawn_future ( py, async move { df. execute_stream_partitioned ( ) . await } ) ?;
962+ let streams = streams
963+ . into_iter ( )
964+ . map ( |s| Arc :: new ( Mutex :: new ( s) ) )
965+ . collect ( ) ;
959966
960967 let mut schema: Schema = self . df . schema ( ) . to_owned ( ) . into ( ) ;
961968 let mut projection: Option < SchemaRef > = None ;
@@ -999,12 +1006,10 @@ impl PyDataFrame {
9991006 let df = self . df . as_ref ( ) . clone ( ) ;
10001007 let state = df. session_state ( ) . clone ( ) ;
10011008 let streams = spawn_future ( py, async move { df. execute_stream_partitioned ( ) . await } ) ?;
1002- Ok (
1003- streams
1004- . into_iter ( )
1005- . map ( |stream| PyRecordBatchStream :: new ( stream, state. clone ( ) ) )
1006- . collect ( ) ,
1007- )
1009+ Ok ( streams
1010+ . into_iter ( )
1011+ . map ( |stream| PyRecordBatchStream :: new ( stream, state. clone ( ) ) )
1012+ . collect ( ) )
10081013 }
10091014
10101015 /// Convert to pandas dataframe with pyarrow
0 commit comments