@@ -34,6 +34,7 @@ use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
3434use datafusion:: datasource:: TableProvider ;
3535use datafusion:: error:: DataFusionError ;
3636use datafusion:: execution:: SendableRecordBatchStream ;
37+ use datafusion:: execution:: session_state:: SessionState ;
3738use datafusion:: parquet:: basic:: { BrotliLevel , Compression , GzipLevel , ZstdLevel } ;
3839use datafusion:: prelude:: * ;
3940use datafusion_ffi:: table_provider:: FFI_TableProvider ;
@@ -360,6 +361,8 @@ impl PyDataFrame {
360361/// converted via `record_batch_into_schema` to apply schema changes per batch.
361362struct PartitionedDataFrameStreamReader {
362363 streams : Vec < SendableRecordBatchStream > ,
364+ // Hold a reference to the session state to keep the context alive
365+ _state : Arc < SessionState > ,
363366 schema : SchemaRef ,
364367 projection : Option < SchemaRef > ,
365368 current : usize ,
@@ -951,6 +954,7 @@ impl PyDataFrame {
951954 requested_schema : Option < Bound < ' py , PyCapsule > > ,
952955 ) -> PyDataFusionResult < Bound < ' py , PyCapsule > > {
953956 let df = self . df . as_ref ( ) . clone ( ) ;
957+ let state = df. session_state ( ) . clone ( ) ;
954958 let streams = spawn_future ( py, async move { df. execute_stream_partitioned ( ) . await } ) ?;
955959
956960 let mut schema: Schema = self . df . schema ( ) . to_owned ( ) . into ( ) ;
@@ -969,6 +973,7 @@ impl PyDataFrame {
969973 let schema_ref = Arc :: new ( schema. clone ( ) ) ;
970974
971975 let reader = PartitionedDataFrameStreamReader {
976+ _state : state,
972977 streams,
973978 schema : schema_ref,
974979 projection,
@@ -985,14 +990,21 @@ impl PyDataFrame {
985990
986991 fn execute_stream ( & self , py : Python ) -> PyDataFusionResult < PyRecordBatchStream > {
987992 let df = self . df . as_ref ( ) . clone ( ) ;
993+ let state = df. session_state ( ) . clone ( ) ;
988994 let stream = spawn_future ( py, async move { df. execute_stream ( ) . await } ) ?;
989- Ok ( PyRecordBatchStream :: new ( stream) )
995+ Ok ( PyRecordBatchStream :: new ( stream, state ) )
990996 }
991997
992998 fn execute_stream_partitioned ( & self , py : Python ) -> PyResult < Vec < PyRecordBatchStream > > {
993999 let df = self . df . as_ref ( ) . clone ( ) ;
1000+ let state = df. session_state ( ) . clone ( ) ;
9941001 let streams = spawn_future ( py, async move { df. execute_stream_partitioned ( ) . await } ) ?;
995- Ok ( streams. into_iter ( ) . map ( PyRecordBatchStream :: new) . collect ( ) )
1002+ Ok (
1003+ streams
1004+ . into_iter ( )
1005+ . map ( |stream| PyRecordBatchStream :: new ( stream, state. clone ( ) ) )
1006+ . collect ( ) ,
1007+ )
9961008 }
9971009
9981010 /// Convert to pandas dataframe with pyarrow
0 commit comments