1515// specific language governing permissions and limitations
1616// under the License.
1717
18- use datafusion:: physical_plan:: execution_plan:: { Boundedness , EmissionType } ;
19- /// Implements a Datafusion physical ExecutionPlan that delegates to a PyArrow Dataset
20- /// This actually performs the projection, filtering and scanning of a Dataset
21- use pyo3:: prelude:: * ;
22- use pyo3:: types:: { PyDict , PyIterator , PyList } ;
23-
24- use std:: any:: Any ;
25- use std:: sync:: Arc ;
26-
27- use futures:: { stream, TryStreamExt } ;
28-
29- use datafusion:: arrow:: datatypes:: SchemaRef ;
30- use datafusion:: arrow:: error:: ArrowError ;
31- use datafusion:: arrow:: error:: Result as ArrowResult ;
32- use datafusion:: arrow:: pyarrow:: PyArrowType ;
33- use datafusion:: arrow:: record_batch:: RecordBatch ;
18+ use arrow:: array:: RecordBatchReader ;
19+ use arrow:: datatypes:: SchemaRef ;
20+ use arrow:: error:: { ArrowError , Result as ArrowResult } ;
21+ use arrow:: ffi_stream:: ArrowArrayStreamReader ;
22+ use arrow:: pyarrow:: FromPyArrow ;
23+ use arrow:: record_batch:: RecordBatch ;
3424use datafusion:: error:: { DataFusionError as InnerDataFusionError , Result as DFResult } ;
3525use datafusion:: execution:: context:: TaskContext ;
36- use datafusion:: logical_expr:: utils:: conjunction;
37- use datafusion:: logical_expr:: Expr ;
3826use datafusion:: physical_expr:: { EquivalenceProperties , LexOrdering } ;
27+ use datafusion:: physical_plan:: execution_plan:: { Boundedness , EmissionType } ;
3928use datafusion:: physical_plan:: stream:: RecordBatchStreamAdapter ;
4029use datafusion:: physical_plan:: {
4130 DisplayAs , DisplayFormatType , ExecutionPlan , ExecutionPlanProperties , Partitioning ,
4231 SendableRecordBatchStream , Statistics ,
4332} ;
33+ use futures:: { stream, StreamExt } ;
34+ use pyo3:: prelude:: * ;
35+ use std:: any:: Any ;
36+ use std:: sync:: Arc ;
4437
4538use crate :: errors:: PyDataFusionResult ;
46- use crate :: pyarrow_filter_expression:: PyArrowFilterExpression ;
4739
48- struct PyArrowBatchesAdapter {
49- batches : Py < PyIterator > ,
40+ /// Iterator over an ArrowArrayStreamReader with optional projection
41+ struct ArrowCStreamAdapter {
42+ reader : ArrowArrayStreamReader ,
43+ projection : Option < Vec < usize > > ,
5044}
5145
52- impl Iterator for PyArrowBatchesAdapter {
46+ impl Iterator for ArrowCStreamAdapter {
5347 type Item = ArrowResult < RecordBatch > ;
5448
5549 fn next ( & mut self ) -> Option < Self :: Item > {
56- Python :: with_gil ( |py| {
57- let mut batches = self . batches . clone_ref ( py) . into_bound ( py) ;
58- Some (
59- batches
60- . next ( ) ?
61- . and_then ( |batch| Ok ( batch. extract :: < PyArrowType < _ > > ( ) ?. 0 ) )
62- . map_err ( |err| ArrowError :: ExternalError ( Box :: new ( err) ) ) ,
63- )
50+ self . reader . next ( ) . map ( |batch_res| {
51+ batch_res. and_then ( |batch| {
52+ if let Some ( indices) = & self . projection {
53+ batch
54+ . project ( indices)
55+ . map_err ( |e| ArrowError :: ExternalError ( Box :: new ( e) ) )
56+ } else {
57+ Ok ( batch)
58+ }
59+ } )
6460 } )
6561 }
6662}
6763
68- // Wraps a pyarrow.dataset.Dataset class and implements a Datafusion ExecutionPlan around it
64+ /// Execution plan that scans a Python object implementing ``__arrow_c_stream__``
6965#[ derive( Debug ) ]
7066pub ( crate ) struct DatasetExec {
7167 dataset : PyObject ,
7268 schema : SchemaRef ,
73- fragments : Py < PyList > ,
74- columns : Option < Vec < String > > ,
75- filter_expr : Option < PyObject > ,
69+ projection : Option < Vec < usize > > ,
7670 projected_statistics : Statistics ,
7771 plan_properties : datafusion:: physical_plan:: PlanProperties ,
7872}
7973
8074impl DatasetExec {
8175 pub fn new (
82- py : Python ,
8376 dataset : & Bound < ' _ , PyAny > ,
8477 projection : Option < Vec < usize > > ,
85- filters : & [ Expr ] ,
8678 ) -> PyDataFusionResult < Self > {
87- let columns: Option < PyDataFusionResult < Vec < String > > > = projection. map ( |p| {
88- p. iter ( )
89- . map ( |index| {
90- let name: String = dataset
91- . getattr ( "schema" ) ?
92- . call_method1 ( "field" , ( * index, ) ) ?
93- . getattr ( "name" ) ?
94- . extract ( ) ?;
95- Ok ( name)
96- } )
97- . collect ( )
98- } ) ;
99- let columns: Option < Vec < String > > = columns. transpose ( ) ?;
100- let filter_expr: Option < PyObject > = conjunction ( filters. to_owned ( ) )
101- . map ( |filters| {
102- PyArrowFilterExpression :: try_from ( & filters)
103- . map ( |filter_expr| filter_expr. inner ( ) . clone_ref ( py) )
104- } )
105- . transpose ( ) ?;
106-
107- let kwargs = PyDict :: new ( py) ;
108-
109- kwargs. set_item ( "columns" , columns. clone ( ) ) ?;
110- kwargs. set_item (
111- "filter" ,
112- filter_expr. as_ref ( ) . map ( |expr| expr. clone_ref ( py) ) ,
113- ) ?;
114-
115- let scanner = dataset. call_method ( "scanner" , ( ) , Some ( & kwargs) ) ?;
116-
117- let schema = Arc :: new (
118- scanner
119- . getattr ( "projected_schema" ) ?
120- . extract :: < PyArrowType < _ > > ( ) ?
121- . 0 ,
122- ) ;
123-
124- let builtins = Python :: import ( py, "builtins" ) ?;
125- let pylist = builtins. getattr ( "list" ) ?;
126-
127- // Get the fragments or partitions of the dataset
128- let fragments_iterator: Bound < ' _ , PyAny > = dataset. call_method1 (
129- "get_fragments" ,
130- ( filter_expr. as_ref ( ) . map ( |expr| expr. clone_ref ( py) ) , ) ,
131- ) ?;
132-
133- let fragments_iter = pylist. call1 ( ( fragments_iterator, ) ) ?;
134- let fragments = fragments_iter. downcast :: < PyList > ( ) . map_err ( PyErr :: from) ?;
79+ let reader = ArrowArrayStreamReader :: from_pyarrow_bound ( dataset) ?;
80+ let base_schema = reader. schema ( ) . as_ref ( ) . clone ( ) ;
81+ drop ( reader) ;
82+
83+ let projected_schema = if let Some ( ref proj) = projection {
84+ base_schema
85+ . project ( proj)
86+ . map_err ( |e| pyo3:: exceptions:: PyValueError :: new_err ( e. to_string ( ) ) ) ?
87+ } else {
88+ base_schema
89+ } ;
90+ let schema: SchemaRef = Arc :: new ( projected_schema) ;
13591
13692 let projected_statistics = Statistics :: new_unknown ( & schema) ;
13793 let plan_properties = datafusion:: physical_plan:: PlanProperties :: new (
13894 EquivalenceProperties :: new ( schema. clone ( ) ) ,
139- Partitioning :: UnknownPartitioning ( fragments . len ( ) ) ,
95+ Partitioning :: UnknownPartitioning ( 1 ) ,
14096 EmissionType :: Final ,
14197 Boundedness :: Bounded ,
14298 ) ;
14399
144100 Ok ( DatasetExec {
145101 dataset : dataset. clone ( ) . unbind ( ) ,
146102 schema,
147- fragments : fragments. clone ( ) . unbind ( ) ,
148- columns,
149- filter_expr,
103+ projection,
150104 projected_statistics,
151105 plan_properties,
152106 } )
@@ -155,22 +109,18 @@ impl DatasetExec {
155109
156110impl ExecutionPlan for DatasetExec {
157111 fn name ( & self ) -> & str {
158- // [ExecutionPlan::name] docs recommends forwarding to `static_name`
159112 Self :: static_name ( )
160113 }
161114
162- /// Return a reference to Any that can be used for downcasting
163115 fn as_any ( & self ) -> & dyn Any {
164116 self
165117 }
166118
167- /// Get the schema for this execution plan
168119 fn schema ( & self ) -> SchemaRef {
169120 self . schema . clone ( )
170121 }
171122
172123 fn children ( & self ) -> Vec < & Arc < dyn ExecutionPlan > > {
173- // this is a leaf node and has no children
174124 vec ! [ ]
175125 }
176126
@@ -184,56 +134,22 @@ impl ExecutionPlan for DatasetExec {
184134 fn execute (
185135 & self ,
186136 partition : usize ,
187- context : Arc < TaskContext > ,
137+ _context : Arc < TaskContext > ,
188138 ) -> DFResult < SendableRecordBatchStream > {
189- let batch_size = context. session_config ( ) . batch_size ( ) ;
139+ if partition != 0 {
140+ return Err ( InnerDataFusionError :: Plan ( "invalid partition" . to_string ( ) ) ) ;
141+ }
190142 Python :: with_gil ( |py| {
191143 let dataset = self . dataset . bind ( py) ;
192- let fragments = self . fragments . bind ( py) ;
193- let fragment = fragments
194- . get_item ( partition)
195- . map_err ( |err| InnerDataFusionError :: External ( Box :: new ( err) ) ) ?;
196-
197- // We need to pass the dataset schema to unify the fragment and dataset schema per PyArrow docs
198- let dataset_schema = dataset
199- . getattr ( "schema" )
200- . map_err ( |err| InnerDataFusionError :: External ( Box :: new ( err) ) ) ?;
201- let kwargs = PyDict :: new ( py) ;
202- kwargs
203- . set_item ( "columns" , self . columns . clone ( ) )
204- . map_err ( |err| InnerDataFusionError :: External ( Box :: new ( err) ) ) ?;
205- kwargs
206- . set_item (
207- "filter" ,
208- self . filter_expr . as_ref ( ) . map ( |expr| expr. clone_ref ( py) ) ,
209- )
144+ let reader = ArrowArrayStreamReader :: from_pyarrow_bound ( dataset)
210145 . map_err ( |err| InnerDataFusionError :: External ( Box :: new ( err) ) ) ?;
211- kwargs
212- . set_item ( "batch_size" , batch_size)
213- . map_err ( |err| InnerDataFusionError :: External ( Box :: new ( err) ) ) ?;
214- let scanner = fragment
215- . call_method ( "scanner" , ( dataset_schema, ) , Some ( & kwargs) )
216- . map_err ( |err| InnerDataFusionError :: External ( Box :: new ( err) ) ) ?;
217- let schema: SchemaRef = Arc :: new (
218- scanner
219- . getattr ( "projected_schema" )
220- . and_then ( |schema| Ok ( schema. extract :: < PyArrowType < _ > > ( ) ?. 0 ) )
221- . map_err ( |err| InnerDataFusionError :: External ( Box :: new ( err) ) ) ?,
222- ) ;
223- let record_batches: Bound < ' _ , PyIterator > = scanner
224- . call_method0 ( "to_batches" )
225- . map_err ( |err| InnerDataFusionError :: External ( Box :: new ( err) ) ) ?
226- . try_iter ( )
227- . map_err ( |err| InnerDataFusionError :: External ( Box :: new ( err) ) ) ?;
228-
229- let record_batches = PyArrowBatchesAdapter {
230- batches : record_batches. into ( ) ,
146+ let adapter = ArrowCStreamAdapter {
147+ reader,
148+ projection : self . projection . clone ( ) ,
231149 } ;
232-
233- let record_batch_stream = stream:: iter ( record_batches) ;
234- let record_batch_stream: SendableRecordBatchStream = Box :: pin (
235- RecordBatchStreamAdapter :: new ( schema, record_batch_stream. map_err ( |e| e. into ( ) ) ) ,
236- ) ;
150+ let stream = stream:: iter ( adapter) . map ( |r| r. map_err ( |e| e. into ( ) ) ) ;
151+ let record_batch_stream: SendableRecordBatchStream =
152+ Box :: pin ( RecordBatchStreamAdapter :: new ( self . schema . clone ( ) , stream) ) ;
237153 Ok ( record_batch_stream)
238154 } )
239155 }
@@ -248,7 +164,6 @@ impl ExecutionPlan for DatasetExec {
248164}
249165
250166impl ExecutionPlanProperties for DatasetExec {
251- /// Get the output partitioning of this plan
252167 fn output_partitioning ( & self ) -> & Partitioning {
253168 self . plan_properties . output_partitioning ( )
254169 }
@@ -272,37 +187,22 @@ impl ExecutionPlanProperties for DatasetExec {
272187
273188impl DisplayAs for DatasetExec {
274189 fn fmt_as ( & self , t : DisplayFormatType , f : & mut std:: fmt:: Formatter ) -> std:: fmt:: Result {
275- Python :: with_gil ( |py| {
276- let number_of_fragments = self . fragments . bind ( py) . len ( ) ;
277- match t {
278- DisplayFormatType :: Default
279- | DisplayFormatType :: Verbose
280- | DisplayFormatType :: TreeRender => {
281- let projected_columns: Vec < String > = self
282- . schema
283- . fields ( )
284- . iter ( )
285- . map ( |x| x. name ( ) . to_owned ( ) )
286- . collect ( ) ;
287- if let Some ( filter_expr) = & self . filter_expr {
288- let filter_expr = filter_expr. bind ( py) . str ( ) . or ( Err ( std:: fmt:: Error ) ) ?;
289- write ! (
290- f,
291- "DatasetExec: number_of_fragments={}, filter_expr={}, projection=[{}]" ,
292- number_of_fragments,
293- filter_expr,
294- projected_columns. join( ", " ) ,
295- )
296- } else {
297- write ! (
298- f,
299- "DatasetExec: number_of_fragments={}, projection=[{}]" ,
300- number_of_fragments,
301- projected_columns. join( ", " ) ,
302- )
303- }
304- }
190+ match t {
191+ DisplayFormatType :: Default
192+ | DisplayFormatType :: Verbose
193+ | DisplayFormatType :: TreeRender => {
194+ let projected_columns: Vec < String > = self
195+ . schema
196+ . fields ( )
197+ . iter ( )
198+ . map ( |x| x. name ( ) . to_owned ( ) )
199+ . collect ( ) ;
200+ write ! (
201+ f,
202+ "DatasetExec: projection=[{}]" ,
203+ projected_columns. join( ", " )
204+ )
305205 }
306- } )
206+ }
307207 }
308208}
0 commit comments