Skip to content

Commit 6552d43

Browse files
committed
feat: refactor table provider registration and introduce pyany_to_table_provider utility
1 parent 6f34ba1 commit 6552d43

4 files changed

Lines changed: 41 additions & 81 deletions

File tree

src/catalog.rs

Lines changed: 3 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use crate::dataframe::PyTableProvider;
1918
use crate::dataset::Dataset;
2019
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
20+
use crate::table::pyany_to_table_provider;
2121
use crate::utils::{validate_pycapsule, wait_for_future};
2222
use async_trait::async_trait;
2323
use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider};
@@ -28,7 +28,6 @@ use datafusion::{
2828
datasource::{TableProvider, TableType},
2929
};
3030
use datafusion_ffi::schema_provider::{FFI_SchemaProvider, ForeignSchemaProvider};
31-
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
3231
use pyo3::exceptions::PyKeyError;
3332
use pyo3::prelude::*;
3433
use pyo3::types::PyCapsule;
@@ -197,29 +196,7 @@ impl PySchema {
197196
}
198197

199198
fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> {
200-
let provider = if table_provider.hasattr("__datafusion_table_provider__")? {
201-
let capsule = table_provider
202-
.getattr("__datafusion_table_provider__")?
203-
.call0()?;
204-
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
205-
validate_pycapsule(capsule, "datafusion_table_provider")?;
206-
207-
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
208-
let provider: ForeignTableProvider = provider.into();
209-
Arc::new(provider) as Arc<dyn TableProvider + Send>
210-
} else {
211-
match table_provider.extract::<PyTable>() {
212-
Ok(py_table) => py_table.table,
213-
Err(_) => match table_provider.extract::<PyTableProvider>() {
214-
Ok(py_provider) => py_provider.as_table().table(),
215-
Err(_) => {
216-
let py = table_provider.py();
217-
let provider = Dataset::new(&table_provider, py)?;
218-
Arc::new(provider) as Arc<dyn TableProvider + Send>
219-
}
220-
},
221-
}
222-
};
199+
let provider = pyany_to_table_provider(&table_provider)?;
223200

224201
let _ = self
225202
.schema
@@ -308,34 +285,7 @@ impl RustWrappedPySchemaProvider {
308285
return Ok(None);
309286
}
310287

311-
if py_table.hasattr("__datafusion_table_provider__")? {
312-
let capsule = py_table.getattr("__datafusion_table_provider__")?.call0()?;
313-
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
314-
validate_pycapsule(capsule, "datafusion_table_provider")?;
315-
316-
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
317-
let provider: ForeignTableProvider = provider.into();
318-
319-
Ok(Some(Arc::new(provider) as Arc<dyn TableProvider + Send>))
320-
} else {
321-
if let Ok(inner_table) = py_table.getattr("table") {
322-
if let Ok(inner_table) = inner_table.extract::<PyTable>() {
323-
return Ok(Some(inner_table.table));
324-
}
325-
}
326-
327-
if let Ok(py_provider) = py_table.extract::<PyTableProvider>() {
328-
return Ok(Some(py_provider.as_table().table()));
329-
}
330-
331-
match py_table.extract::<PyTable>() {
332-
Ok(py_table) => Ok(Some(py_table.table)),
333-
Err(_) => {
334-
let ds = Dataset::new(&py_table, py).map_err(py_datafusion_err)?;
335-
Ok(Some(Arc::new(ds) as Arc<dyn TableProvider + Send>))
336-
}
337-
}
338-
}
288+
pyany_to_table_provider(&py_table).map(Some)
339289
})
340290
}
341291
}

src/context.rs

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,15 @@ use pyo3::prelude::*;
3333

3434
use crate::catalog::{PyCatalog, PyTable, RustWrappedPyCatalogProvider};
3535
use crate::dataframe::PyDataFrame;
36-
use crate::dataframe::PyTableProvider;
3736
use crate::dataset::Dataset;
38-
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
37+
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
3938
use crate::expr::sort_expr::PySortExpr;
4039
use crate::physical_plan::PyExecutionPlan;
4140
use crate::record_batch::PyRecordBatchStream;
4241
use crate::sql::exceptions::py_value_err;
4342
use crate::sql::logical::PyLogicalPlan;
4443
use crate::store::StorageContexts;
44+
use crate::table::pyany_to_table_provider;
4545
use crate::udaf::PyAggregateUDF;
4646
use crate::udf::PyScalarUDF;
4747
use crate::udtf::PyTableFunction;
@@ -72,7 +72,6 @@ use datafusion::prelude::{
7272
AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
7373
};
7474
use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider};
75-
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
7675
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
7776
use pyo3::IntoPyObjectExt;
7877
use tokio::task::JoinHandle;
@@ -608,26 +607,9 @@ impl PySessionContext {
608607
name: &str,
609608
table_provider: Bound<'_, PyAny>,
610609
) -> PyDataFusionResult<()> {
611-
let provider = if table_provider.hasattr("__datafusion_table_provider__")? {
612-
let capsule = table_provider
613-
.getattr("__datafusion_table_provider__")?
614-
.call0()?;
615-
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
616-
validate_pycapsule(capsule, "datafusion_table_provider")?;
617-
618-
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
619-
let provider: ForeignTableProvider = provider.into();
620-
Arc::new(provider) as Arc<dyn TableProvider + Send>
621-
} else if let Ok(py_table) = table_provider.extract::<PyTable>() {
622-
py_table.table()
623-
} else if let Ok(py_provider) = table_provider.extract::<PyTableProvider>() {
624-
py_provider.as_table().table()
625-
} else {
626-
return Err(crate::errors::PyDataFusionError::Common(
627-
"Expected a Table or TableProvider.".to_string(),
628-
));
629-
};
630-
610+
let provider = pyany_to_table_provider(&table_provider).map_err(|_| {
611+
PyDataFusionError::Common("Expected a Table or TableProvider.".to_string())
612+
})?;
631613
self.ctx.register_table(name, provider)?;
632614
Ok(())
633615
}

src/dataframe.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,8 @@ impl PyDataFrame {
268268
}
269269
}
270270

271-
pub(crate) fn to_view_provider(&self) -> Arc<dyn TableProvider + Send> {
271+
#[allow(clippy::wrong_self_convention)]
272+
pub(crate) fn into_view_provider(&self) -> Arc<dyn TableProvider + Send> {
272273
self.df.as_ref().clone().into_view()
273274
}
274275

@@ -404,12 +405,12 @@ impl PyDataFrame {
404405
/// where objects are shared
405406
/// https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116
406407
/// - we have not decided on the table_provider approach yet
407-
#[allow(clippy::wrong_self_convention)]
408+
#[allow(clippy::wrong_self_convention)]
408409
pub fn into_view(&self) -> PyDataFusionResult<PyTableProvider> {
409410
// Call the underlying Rust DataFrame::into_view method.
410411
// Note that the Rust method consumes self; here we clone the inner Arc<DataFrame>
411-
// so that we dont invalidate this PyDataFrame.
412-
let table_provider = self.to_view_provider();
412+
// so that we don't invalidate this PyDataFrame.
413+
let table_provider = self.into_view_provider();
413414
Ok(PyTableProvider::new(table_provider))
414415
}
415416

src/table.rs

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ use pyo3::types::PyCapsule;
2626

2727
use crate::catalog::PyTable;
2828
use crate::dataframe::PyDataFrame;
29+
use crate::dataset::Dataset;
2930
use crate::errors::{py_datafusion_err, PyDataFusionResult};
3031
use crate::utils::{get_tokio_runtime, validate_pycapsule};
3132

@@ -67,7 +68,7 @@ impl PyTableProvider {
6768
/// This method simply delegates to `DataFrame.into_view`.
6869
#[staticmethod]
6970
pub fn from_dataframe(df: &PyDataFrame) -> PyDataFusionResult<Self> {
70-
let table_provider = df.to_view_provider();
71+
let table_provider = df.into_view_provider();
7172
Ok(Self::new(table_provider))
7273
}
7374

@@ -99,3 +100,29 @@ impl PyTableProvider {
99100
PyCapsule::new(py, provider, Some(name.clone()))
100101
}
101102
}
103+
104+
pub(crate) fn pyany_to_table_provider(
105+
table_provider: &Bound<'_, PyAny>,
106+
) -> PyResult<Arc<dyn TableProvider + Send>> {
107+
if table_provider.hasattr("__datafusion_table_provider__")? {
108+
let capsule = table_provider
109+
.getattr("__datafusion_table_provider__")?
110+
.call0()?;
111+
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
112+
validate_pycapsule(capsule, "datafusion_table_provider")?;
113+
114+
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
115+
let provider: ForeignTableProvider = provider.into();
116+
Ok(Arc::new(provider) as Arc<dyn TableProvider + Send>)
117+
} else if let Ok(py_table) = table_provider.extract::<PyTable>() {
118+
Ok(py_table.table())
119+
} else if let Ok(py_provider) = table_provider.extract::<PyTableProvider>() {
120+
Ok(py_provider.as_table().table())
121+
} else if let Ok(inner) = table_provider.getattr("table") {
122+
pyany_to_table_provider(&inner)
123+
} else {
124+
let py = table_provider.py();
125+
let provider = Dataset::new(table_provider, py)?;
126+
Ok(Arc::new(provider) as Arc<dyn TableProvider + Send>)
127+
}
128+
}

0 commit comments

Comments (0)