@@ -19,13 +19,13 @@ use std::collections::HashMap;
1919use std:: ffi:: CString ;
2020use std:: sync:: Arc ;
2121
22- use arrow:: array:: { new_null_array, RecordBatch , RecordBatchReader } ;
22+ use arrow:: array:: { new_null_array, RecordBatch , RecordBatchIterator , RecordBatchReader } ;
2323use arrow:: compute:: can_cast_types;
2424use arrow:: error:: ArrowError ;
2525use arrow:: ffi:: FFI_ArrowSchema ;
2626use arrow:: ffi_stream:: FFI_ArrowArrayStream ;
2727use arrow:: pyarrow:: FromPyArrow ;
28- use datafusion:: arrow:: datatypes:: { Schema , SchemaRef } ;
28+ use datafusion:: arrow:: datatypes:: Schema ;
2929use datafusion:: arrow:: pyarrow:: { PyArrowType , ToPyArrow } ;
3030use datafusion:: arrow:: util:: pretty;
3131use datafusion:: common:: UnnestOptions ;
@@ -879,17 +879,8 @@ impl PyDataFrame {
879879 py : Python < ' py > ,
880880 requested_schema : Option < Bound < ' py , PyCapsule > > ,
881881 ) -> PyDataFusionResult < Bound < ' py , PyCapsule > > {
882- // execute query lazily using a stream
883- let rt = & get_tokio_runtime ( ) . 0 ;
884- let df = self . df . as_ref ( ) . clone ( ) ;
885- let fut: JoinHandle < datafusion:: common:: Result < SendableRecordBatchStream > > =
886- rt. spawn ( async move { df. execute_stream ( ) . await } ) ;
887- let stream = wait_for_future ( py, async { fut. await . map_err ( to_datafusion_err) } ) ???;
888-
889- // Determine the schema and handle optional projection
890- let stream_schema = stream. schema ( ) ;
891- let mut schema: Schema = stream_schema. as_ref ( ) . to_owned ( ) . into ( ) ;
892- let mut project = false ;
882+ let mut batches = wait_for_future ( py, self . df . as_ref ( ) . clone ( ) . collect ( ) ) ??;
883+ let mut schema: Schema = self . df . schema ( ) . to_owned ( ) . into ( ) ;
893884
894885 if let Some ( schema_capsule) = requested_schema {
895886 validate_pycapsule ( & schema_capsule, "arrow_schema" ) ?;
@@ -898,12 +889,17 @@ impl PyDataFrame {
898889 let desired_schema = Schema :: try_from ( schema_ptr) ?;
899890
900891 schema = project_schema ( schema, desired_schema) ?;
901- project = schema != * stream_schema. as_ref ( ) ;
892+
893+ batches = batches
894+ . into_iter ( )
895+ . map ( |record_batch| record_batch_into_schema ( record_batch, & schema) )
896+ . collect :: < Result < Vec < RecordBatch > , ArrowError > > ( ) ?;
902897 }
903898
904- let schema_ref: SchemaRef = Arc :: new ( schema) ;
905- let reader: Box < dyn RecordBatchReader + Send > =
906- Box :: new ( ArrowStreamReader :: new ( stream, schema_ref, project) ) ;
899+ let batches_wrapped = batches. into_iter ( ) . map ( Ok ) ;
900+
901+ let reader = RecordBatchIterator :: new ( batches_wrapped, Arc :: new ( schema) ) ;
902+ let reader: Box < dyn RecordBatchReader + Send > = Box :: new ( reader) ;
907903
908904 let ffi_stream = FFI_ArrowArrayStream :: new ( reader) ;
909905 let stream_capsule_name = CString :: new ( "arrow_array_stream" ) . unwrap ( ) ;
@@ -998,51 +994,6 @@ impl PyDataFrame {
998994 }
999995}
1000996
1001- struct ArrowStreamReader {
1002- stream : SendableRecordBatchStream ,
1003- schema : SchemaRef ,
1004- project : bool ,
1005- }
1006-
1007- impl ArrowStreamReader {
1008- fn new ( stream : SendableRecordBatchStream , schema : SchemaRef , project : bool ) -> Self {
1009- Self {
1010- stream,
1011- schema,
1012- project,
1013- }
1014- }
1015- }
1016-
1017- impl RecordBatchReader for ArrowStreamReader {
1018- fn schema ( & self ) -> SchemaRef {
1019- self . schema . clone ( )
1020- }
1021- }
1022-
1023- impl Iterator for ArrowStreamReader {
1024- type Item = Result < RecordBatch , ArrowError > ;
1025-
1026- fn next ( & mut self ) -> Option < Self :: Item > {
1027- let rt = & get_tokio_runtime ( ) . 0 ;
1028- match rt. block_on ( self . stream . next ( ) ) {
1029- Some ( Ok ( batch) ) => {
1030- let batch = if self . project {
1031- match record_batch_into_schema ( batch, self . schema . as_ref ( ) ) {
1032- Ok ( b) => b,
1033- Err ( e) => return Some ( Err ( e) ) ,
1034- }
1035- } else {
1036- batch
1037- } ;
1038- Some ( Ok ( batch) )
1039- }
1040- Some ( Err ( e) ) => Some ( Err ( ArrowError :: from ( e) ) ) ,
1041- None => None ,
1042- }
1043- }
1044- }
1045-
1046997/// Print DataFrame
1047998fn print_dataframe ( py : Python , df : DataFrame ) -> PyDataFusionResult < ( ) > {
1048999 // Get string representation of record batches
0 commit comments