Skip to content

Commit 4db9962

Browse files
committed
feat: enhance Dataset to support any Python object implementing __arrow_c_stream__ interface
1 parent 2bed9b0 commit 4db9962

1 file changed

Lines changed: 19 additions & 29 deletions

File tree

src/dataset.rs

Lines changed: 19 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -17,47 +17,46 @@
1717

1818
use datafusion::catalog::Session;
1919
use pyo3::exceptions::PyValueError;
20-
/// Implements a Datafusion TableProvider that delegates to a PyArrow Dataset
21-
/// This allows us to use PyArrow Datasets as Datafusion tables while pushing down projections and filters
20+
/// Implements a Datafusion TableProvider that delegates to any Python object
21+
/// implementing the ``__arrow_c_stream__`` interface. This allows DataFusion to
22+
/// scan objects from libraries such as PyArrow, nanoarrow, Polars, or DuckDB
23+
/// without requiring a concrete ``pyarrow.dataset.Dataset`` instance.
2224
use pyo3::prelude::*;
23-
use pyo3::types::PyType;
2425

2526
use std::any::Any;
2627
use std::sync::Arc;
2728

2829
use async_trait::async_trait;
2930

31+
use arrow::array::RecordBatchReader;
32+
use arrow::ffi_stream::ArrowArrayStreamReader;
33+
use arrow::pyarrow::FromPyArrow;
3034
use datafusion::arrow::datatypes::SchemaRef;
31-
use datafusion::arrow::pyarrow::PyArrowType;
3235
use datafusion::datasource::{TableProvider, TableType};
3336
use datafusion::error::{DataFusionError, Result as DFResult};
3437
use datafusion::logical_expr::Expr;
3538
use datafusion::logical_expr::TableProviderFilterPushDown;
3639
use datafusion::physical_plan::ExecutionPlan;
3740

3841
use crate::dataset_exec::DatasetExec;
39-
use crate::pyarrow_filter_expression::PyArrowFilterExpression;
4042

41-
// Wraps a pyarrow.dataset.Dataset class and implements a Datafusion TableProvider around it
43+
// Wraps a Python object implementing ``__arrow_c_stream__`` and exposes it as a
44+
// DataFusion TableProvider.
4245
#[derive(Debug)]
4346
pub(crate) struct Dataset {
4447
dataset: PyObject,
4548
}
4649

4750
impl Dataset {
48-
// Creates a Python PyArrow.Dataset
49-
pub fn new(dataset: &Bound<'_, PyAny>, py: Python) -> PyResult<Self> {
50-
// Ensure that we were passed an instance of pyarrow.dataset.Dataset
51-
let ds = PyModule::import(py, "pyarrow.dataset")?;
52-
let ds_attr = ds.getattr("Dataset")?;
53-
let ds_type = ds_attr.downcast::<PyType>()?;
54-
if dataset.is_instance(ds_type)? {
51+
// Creates a Python Dataset wrapper
52+
pub fn new(dataset: &Bound<'_, PyAny>, _py: Python) -> PyResult<Self> {
53+
if dataset.hasattr("__arrow_c_stream__")? {
5554
Ok(Dataset {
5655
dataset: dataset.clone().unbind(),
5756
})
5857
} else {
5958
Err(PyValueError::new_err(
60-
"dataset argument must be a pyarrow.dataset.Dataset object",
59+
"dataset argument must implement __arrow_c_stream__",
6160
))
6261
}
6362
}
@@ -75,15 +74,9 @@ impl TableProvider for Dataset {
7574
fn schema(&self) -> SchemaRef {
7675
Python::with_gil(|py| {
7776
let dataset = self.dataset.bind(py);
78-
// This can panic but since we checked that self.dataset is a pyarrow.dataset.Dataset it should never
79-
Arc::new(
80-
dataset
81-
.getattr("schema")
82-
.unwrap()
83-
.extract::<PyArrowType<_>>()
84-
.unwrap()
85-
.0,
86-
)
77+
let reader = ArrowArrayStreamReader::from_pyarrow_bound(dataset)
78+
.expect("dataset must implement __arrow_c_stream__");
79+
Arc::new(reader.schema().as_ref().clone())
8780
})
8881
}
8982

@@ -100,7 +93,7 @@ impl TableProvider for Dataset {
10093
&self,
10194
_ctx: &dyn Session,
10295
projection: Option<&Vec<usize>>,
103-
filters: &[Expr],
96+
_filters: &[Expr],
10497
// limit can be used to reduce the amount scanned
10598
// from the datasource as a performance optimization.
10699
// If set, it contains the amount of rows needed by the `LogicalPlan`,
@@ -109,7 +102,7 @@ impl TableProvider for Dataset {
109102
) -> DFResult<Arc<dyn ExecutionPlan>> {
110103
Python::with_gil(|py| {
111104
let plan: Arc<dyn ExecutionPlan> = Arc::new(
112-
DatasetExec::new(py, self.dataset.bind(py), projection.cloned(), filters)
105+
DatasetExec::new(self.dataset.bind(py), projection.cloned())
113106
.map_err(|err| DataFusionError::External(Box::new(err)))?,
114107
);
115108
Ok(plan)
@@ -124,10 +117,7 @@ impl TableProvider for Dataset {
124117
) -> DFResult<Vec<TableProviderFilterPushDown>> {
125118
filter
126119
.iter()
127-
.map(|&f| match PyArrowFilterExpression::try_from(f) {
128-
Ok(_) => Ok(TableProviderFilterPushDown::Exact),
129-
_ => Ok(TableProviderFilterPushDown::Unsupported),
130-
})
120+
.map(|_| Ok(TableProviderFilterPushDown::Unsupported))
131121
.collect()
132122
}
133123
}

0 commit comments

Comments (0)