
Commit 4b57c7a (parent 4db9962)

feat: refactor DatasetExec to utilize ArrowArrayStreamReader and improve projection handling

1 file changed: src/dataset_exec.rs (69 additions, 169 deletions)
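The core of the change is visible in the imports below: instead of driving a pyarrow Dataset through Python method calls (scanner, get_fragments, to_batches), the new code pulls record batches across the Arrow C stream interface. A minimal sketch of that pattern, not taken from this commit, assuming the arrow crate's pyarrow feature (as the new imports do) plus pyo3; collect_batches is a hypothetical helper name:

    use arrow::ffi_stream::ArrowArrayStreamReader;
    use arrow::pyarrow::FromPyArrow;
    use arrow::record_batch::RecordBatch;
    use pyo3::prelude::*;

    // Hypothetical helper: drain every batch from a Python object that
    // implements __arrow_c_stream__ (e.g. a pyarrow RecordBatchReader or
    // Dataset). As I understand arrow-rs's pyarrow integration,
    // from_pyarrow_bound negotiates the stream capsule once; iteration then
    // goes through the C callbacks rather than Python attribute calls.
    fn collect_batches(obj: &Bound<'_, PyAny>) -> PyResult<Vec<RecordBatch>> {
        let reader = ArrowArrayStreamReader::from_pyarrow_bound(obj)?;
        reader
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))
    }

Note the added `use arrow::array::RecordBatchReader;` in the diff: the `schema()` call on the FFI reader comes from that trait, so it must be in scope.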
@@ -15,138 +15,92 @@
 // specific language governing permissions and limitations
 // under the License.

-use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
-/// Implements a Datafusion physical ExecutionPlan that delegates to a PyArrow Dataset
-/// This actually performs the projection, filtering and scanning of a Dataset
-use pyo3::prelude::*;
-use pyo3::types::{PyDict, PyIterator, PyList};
-
-use std::any::Any;
-use std::sync::Arc;
-
-use futures::{stream, TryStreamExt};
-
-use datafusion::arrow::datatypes::SchemaRef;
-use datafusion::arrow::error::ArrowError;
-use datafusion::arrow::error::Result as ArrowResult;
-use datafusion::arrow::pyarrow::PyArrowType;
-use datafusion::arrow::record_batch::RecordBatch;
+use arrow::array::RecordBatchReader;
+use arrow::datatypes::SchemaRef;
+use arrow::error::{ArrowError, Result as ArrowResult};
+use arrow::ffi_stream::ArrowArrayStreamReader;
+use arrow::pyarrow::FromPyArrow;
+use arrow::record_batch::RecordBatch;
 use datafusion::error::{DataFusionError as InnerDataFusionError, Result as DFResult};
 use datafusion::execution::context::TaskContext;
-use datafusion::logical_expr::utils::conjunction;
-use datafusion::logical_expr::Expr;
 use datafusion::physical_expr::{EquivalenceProperties, LexOrdering};
+use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use datafusion::physical_plan::{
     DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, Partitioning,
     SendableRecordBatchStream, Statistics,
 };
+use futures::{stream, StreamExt};
+use pyo3::prelude::*;
+use std::any::Any;
+use std::sync::Arc;

 use crate::errors::PyDataFusionResult;
-use crate::pyarrow_filter_expression::PyArrowFilterExpression;

-struct PyArrowBatchesAdapter {
-    batches: Py<PyIterator>,
+/// Iterator over an ArrowArrayStreamReader with optional projection
+struct ArrowCStreamAdapter {
+    reader: ArrowArrayStreamReader,
+    projection: Option<Vec<usize>>,
 }

-impl Iterator for PyArrowBatchesAdapter {
+impl Iterator for ArrowCStreamAdapter {
     type Item = ArrowResult<RecordBatch>;

     fn next(&mut self) -> Option<Self::Item> {
-        Python::with_gil(|py| {
-            let mut batches = self.batches.clone_ref(py).into_bound(py);
-            Some(
-                batches
-                    .next()?
-                    .and_then(|batch| Ok(batch.extract::<PyArrowType<_>>()?.0))
-                    .map_err(|err| ArrowError::ExternalError(Box::new(err))),
-            )
+        self.reader.next().map(|batch_res| {
+            batch_res.and_then(|batch| {
+                if let Some(indices) = &self.projection {
+                    batch
+                        .project(indices)
+                        .map_err(|e| ArrowError::ExternalError(Box::new(e)))
+                } else {
+                    Ok(batch)
+                }
+            })
         })
     }
 }

-// Wraps a pyarrow.dataset.Dataset class and implements a Datafusion ExecutionPlan around it
+/// Execution plan that scans a Python object implementing ``__arrow_c_stream__``
 #[derive(Debug)]
 pub(crate) struct DatasetExec {
     dataset: PyObject,
     schema: SchemaRef,
-    fragments: Py<PyList>,
-    columns: Option<Vec<String>>,
-    filter_expr: Option<PyObject>,
+    projection: Option<Vec<usize>>,
     projected_statistics: Statistics,
     plan_properties: datafusion::physical_plan::PlanProperties,
 }

 impl DatasetExec {
     pub fn new(
-        py: Python,
         dataset: &Bound<'_, PyAny>,
         projection: Option<Vec<usize>>,
-        filters: &[Expr],
     ) -> PyDataFusionResult<Self> {
-        let columns: Option<PyDataFusionResult<Vec<String>>> = projection.map(|p| {
-            p.iter()
-                .map(|index| {
-                    let name: String = dataset
-                        .getattr("schema")?
-                        .call_method1("field", (*index,))?
-                        .getattr("name")?
-                        .extract()?;
-                    Ok(name)
-                })
-                .collect()
-        });
-        let columns: Option<Vec<String>> = columns.transpose()?;
-        let filter_expr: Option<PyObject> = conjunction(filters.to_owned())
-            .map(|filters| {
-                PyArrowFilterExpression::try_from(&filters)
-                    .map(|filter_expr| filter_expr.inner().clone_ref(py))
-            })
-            .transpose()?;
-
-        let kwargs = PyDict::new(py);
-
-        kwargs.set_item("columns", columns.clone())?;
-        kwargs.set_item(
-            "filter",
-            filter_expr.as_ref().map(|expr| expr.clone_ref(py)),
-        )?;
-
-        let scanner = dataset.call_method("scanner", (), Some(&kwargs))?;
-
-        let schema = Arc::new(
-            scanner
-                .getattr("projected_schema")?
-                .extract::<PyArrowType<_>>()?
-                .0,
-        );
-
-        let builtins = Python::import(py, "builtins")?;
-        let pylist = builtins.getattr("list")?;
-
-        // Get the fragments or partitions of the dataset
-        let fragments_iterator: Bound<'_, PyAny> = dataset.call_method1(
-            "get_fragments",
-            (filter_expr.as_ref().map(|expr| expr.clone_ref(py)),),
-        )?;
-
-        let fragments_iter = pylist.call1((fragments_iterator,))?;
-        let fragments = fragments_iter.downcast::<PyList>().map_err(PyErr::from)?;
+        let reader = ArrowArrayStreamReader::from_pyarrow_bound(dataset)?;
+        let base_schema = reader.schema().as_ref().clone();
+        drop(reader);
+
+        let projected_schema = if let Some(ref proj) = projection {
+            base_schema
+                .project(proj)
+                .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?
+        } else {
+            base_schema
+        };
+        let schema: SchemaRef = Arc::new(projected_schema);

         let projected_statistics = Statistics::new_unknown(&schema);
         let plan_properties = datafusion::physical_plan::PlanProperties::new(
             EquivalenceProperties::new(schema.clone()),
-            Partitioning::UnknownPartitioning(fragments.len()),
+            Partitioning::UnknownPartitioning(1),
             EmissionType::Final,
             Boundedness::Bounded,
         );

         Ok(DatasetExec {
             dataset: dataset.clone().unbind(),
             schema,
-            fragments: fragments.clone().unbind(),
-            columns,
-            filter_expr,
+            projection,
             projected_statistics,
             plan_properties,
         })
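One design note on the new constructor: it opens the stream once just to learn the producer's schema, drops that reader, and derives the declared schema with Schema::project. The per-batch adapter above applies RecordBatch::project with the same indices, so what the plan advertises matches what it emits. A standalone sketch of that invariant (projection_demo is a hypothetical function, not part of the commit):

    use std::sync::Arc;

    use arrow::array::{ArrayRef, Int32Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;

    // Hypothetical demo: the same index list drives both projections, so the
    // schema the plan declares and the batches the adapter emits line up.
    fn projection_demo() -> arrow::error::Result<()> {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]);
        let batch = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
                Arc::new(Int32Array::from(vec![4, 5, 6])) as ArrayRef,
            ],
        )?;

        let projection = [1usize]; // keep only column "b"
        let projected_schema = schema.project(&projection)?;
        let projected_batch = batch.project(&projection)?;
        assert_eq!(projected_batch.schema().as_ref(), &projected_schema);
        Ok(())
    }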
@@ -155,22 +109,18 @@ impl DatasetExec {

 impl ExecutionPlan for DatasetExec {
     fn name(&self) -> &str {
-        // [ExecutionPlan::name] docs recommends forwarding to `static_name`
         Self::static_name()
     }

-    /// Return a reference to Any that can be used for downcasting
     fn as_any(&self) -> &dyn Any {
         self
     }

-    /// Get the schema for this execution plan
     fn schema(&self) -> SchemaRef {
         self.schema.clone()
     }

     fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
-        // this is a leaf node and has no children
         vec![]
     }

@@ -184,56 +134,22 @@ impl ExecutionPlan for DatasetExec {
     fn execute(
         &self,
         partition: usize,
-        context: Arc<TaskContext>,
+        _context: Arc<TaskContext>,
     ) -> DFResult<SendableRecordBatchStream> {
-        let batch_size = context.session_config().batch_size();
+        if partition != 0 {
+            return Err(InnerDataFusionError::Plan("invalid partition".to_string()));
+        }
         Python::with_gil(|py| {
             let dataset = self.dataset.bind(py);
-            let fragments = self.fragments.bind(py);
-            let fragment = fragments
-                .get_item(partition)
-                .map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
-
-            // We need to pass the dataset schema to unify the fragment and dataset schema per PyArrow docs
-            let dataset_schema = dataset
-                .getattr("schema")
-                .map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
-            let kwargs = PyDict::new(py);
-            kwargs
-                .set_item("columns", self.columns.clone())
-                .map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
-            kwargs
-                .set_item(
-                    "filter",
-                    self.filter_expr.as_ref().map(|expr| expr.clone_ref(py)),
-                )
+            let reader = ArrowArrayStreamReader::from_pyarrow_bound(dataset)
                 .map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
-            kwargs
-                .set_item("batch_size", batch_size)
-                .map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
-            let scanner = fragment
-                .call_method("scanner", (dataset_schema,), Some(&kwargs))
-                .map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
-            let schema: SchemaRef = Arc::new(
-                scanner
-                    .getattr("projected_schema")
-                    .and_then(|schema| Ok(schema.extract::<PyArrowType<_>>()?.0))
-                    .map_err(|err| InnerDataFusionError::External(Box::new(err)))?,
-            );
-            let record_batches: Bound<'_, PyIterator> = scanner
-                .call_method0("to_batches")
-                .map_err(|err| InnerDataFusionError::External(Box::new(err)))?
-                .try_iter()
-                .map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
-
-            let record_batches = PyArrowBatchesAdapter {
-                batches: record_batches.into(),
+            let adapter = ArrowCStreamAdapter {
+                reader,
+                projection: self.projection.clone(),
             };
-
-            let record_batch_stream = stream::iter(record_batches);
-            let record_batch_stream: SendableRecordBatchStream = Box::pin(
-                RecordBatchStreamAdapter::new(schema, record_batch_stream.map_err(|e| e.into())),
-            );
+            let stream = stream::iter(adapter).map(|r| r.map_err(|e| e.into()));
+            let record_batch_stream: SendableRecordBatchStream =
+                Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream));
             Ok(record_batch_stream)
         })
     }
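Because the plan now declares UnknownPartitioning(1), execute rejects any partition other than 0. The body above also shows the standard bridge from a blocking Iterator to the async SendableRecordBatchStream DataFusion expects: futures' stream::iter wraps the iterator, and RecordBatchStreamAdapter pins the declared schema to it. A self-contained sketch of the same bridge (iterator_to_stream is a hypothetical helper name):

    use datafusion::arrow::datatypes::SchemaRef;
    use datafusion::arrow::error::Result as ArrowResult;
    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
    use datafusion::physical_plan::SendableRecordBatchStream;
    use futures::{stream, StreamExt};

    // Hypothetical helper showing the iterator-to-stream bridge used above:
    // stream::iter yields one batch per poll, map_err converts ArrowError
    // into DataFusionError, and the adapter carries the schema so DataFusion
    // can validate the stream's output type.
    fn iterator_to_stream(
        schema: SchemaRef,
        batches: impl Iterator<Item = ArrowResult<RecordBatch>> + Send + 'static,
    ) -> SendableRecordBatchStream {
        let stream = stream::iter(batches).map(|r| r.map_err(Into::into));
        Box::pin(RecordBatchStreamAdapter::new(schema, stream))
    }

Since the iterator is drained as the stream is polled, the synchronous FFI reader stays hidden behind DataFusion's async interface without any extra task or channel.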
@@ -248,7 +164,6 @@ impl ExecutionPlan for DatasetExec {
 }

 impl ExecutionPlanProperties for DatasetExec {
-    /// Get the output partitioning of this plan
     fn output_partitioning(&self) -> &Partitioning {
         self.plan_properties.output_partitioning()
     }
@@ -272,37 +187,22 @@ impl ExecutionPlanProperties for DatasetExec {

 impl DisplayAs for DatasetExec {
     fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
-        Python::with_gil(|py| {
-            let number_of_fragments = self.fragments.bind(py).len();
-            match t {
-                DisplayFormatType::Default
-                | DisplayFormatType::Verbose
-                | DisplayFormatType::TreeRender => {
-                    let projected_columns: Vec<String> = self
-                        .schema
-                        .fields()
-                        .iter()
-                        .map(|x| x.name().to_owned())
-                        .collect();
-                    if let Some(filter_expr) = &self.filter_expr {
-                        let filter_expr = filter_expr.bind(py).str().or(Err(std::fmt::Error))?;
-                        write!(
-                            f,
-                            "DatasetExec: number_of_fragments={}, filter_expr={}, projection=[{}]",
-                            number_of_fragments,
-                            filter_expr,
-                            projected_columns.join(", "),
-                        )
-                    } else {
-                        write!(
-                            f,
-                            "DatasetExec: number_of_fragments={}, projection=[{}]",
-                            number_of_fragments,
-                            projected_columns.join(", "),
-                        )
-                    }
-                }
+        match t {
+            DisplayFormatType::Default
+            | DisplayFormatType::Verbose
+            | DisplayFormatType::TreeRender => {
+                let projected_columns: Vec<String> = self
+                    .schema
+                    .fields()
+                    .iter()
+                    .map(|x| x.name().to_owned())
+                    .collect();
+                write!(
+                    f,
+                    "DatasetExec: projection=[{}]",
+                    projected_columns.join(", ")
+                )
             }
-        })
+        }
     }
 }
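Design note: with fragments, columns, and filter_expr gone from the struct, fmt_as no longer needs Python::with_gil, and the plan renders from the projected schema alone; a plan keeping columns a and b now prints as DatasetExec: projection=[a, b]. The trade-off visible throughout the diff is that filter pushdown and per-fragment partitioning are no longer expressed by this operator: the C stream is consumed as a single partition.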
