Skip to content

Commit e6a21f6

Browse files
committed
fix: refactor table provider handling to use table_provider_from_pycapsule utility
1 parent c873d2b commit e6a21f6

4 files changed

Lines changed: 41 additions & 51 deletions

File tree

src/catalog.rs

Lines changed: 5 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@
1818
use crate::dataset::Dataset;
1919
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
2020
use crate::table::PyTableProvider;
21-
use crate::utils::{validate_pycapsule, wait_for_future};
21+
use crate::utils::{table_provider_from_pycapsule, validate_pycapsule, wait_for_future};
2222
use async_trait::async_trait;
2323
use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider};
2424
use datafusion::common::DataFusionError;
@@ -28,7 +28,6 @@ use datafusion::{
2828
datasource::{TableProvider, TableType},
2929
};
3030
use datafusion_ffi::schema_provider::{FFI_SchemaProvider, ForeignSchemaProvider};
31-
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
3231
use pyo3::exceptions::PyKeyError;
3332
use pyo3::prelude::*;
3433
use pyo3::types::PyCapsule;
@@ -197,16 +196,8 @@ impl PySchema {
197196
}
198197

199198
fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> {
200-
let provider = if table_provider.hasattr("__datafusion_table_provider__")? {
201-
let capsule = table_provider
202-
.getattr("__datafusion_table_provider__")?
203-
.call0()?;
204-
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
205-
validate_pycapsule(capsule, "datafusion_table_provider")?;
206-
207-
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
208-
let provider: ForeignTableProvider = provider.into();
209-
Arc::new(provider) as Arc<dyn TableProvider + Send>
199+
let provider = if let Some(provider) = table_provider_from_pycapsule(&table_provider)? {
200+
provider
210201
} else {
211202
match table_provider.extract::<PyTable>() {
212203
Ok(py_table) => py_table.table,
@@ -308,15 +299,8 @@ impl RustWrappedPySchemaProvider {
308299
return Ok(None);
309300
}
310301

311-
if py_table.hasattr("__datafusion_table_provider__")? {
312-
let capsule = py_table.getattr("__datafusion_table_provider__")?.call0()?;
313-
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
314-
validate_pycapsule(capsule, "datafusion_table_provider")?;
315-
316-
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
317-
let provider: ForeignTableProvider = provider.into();
318-
319-
Ok(Some(Arc::new(provider) as Arc<dyn TableProvider + Send>))
302+
if let Some(provider) = table_provider_from_pycapsule(&py_table)? {
303+
Ok(Some(provider))
320304
} else {
321305
if let Ok(inner_table) = py_table.getattr("table") {
322306
if let Ok(inner_table) = inner_table.extract::<PyTable>() {

src/context.rs

Lines changed: 6 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,10 @@ use crate::udaf::PyAggregateUDF;
4646
use crate::udf::PyScalarUDF;
4747
use crate::udtf::PyTableFunction;
4848
use crate::udwf::PyWindowUDF;
49-
use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future};
49+
use crate::utils::{
50+
get_global_ctx, get_tokio_runtime, table_provider_from_pycapsule, validate_pycapsule,
51+
wait_for_future,
52+
};
5053
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
5154
use datafusion::arrow::pyarrow::PyArrowType;
5255
use datafusion::arrow::record_batch::RecordBatch;
@@ -72,7 +75,6 @@ use datafusion::prelude::{
7275
AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
7376
};
7477
use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider};
75-
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
7678
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
7779
use pyo3::IntoPyObjectExt;
7880
use tokio::task::JoinHandle;
@@ -608,16 +610,8 @@ impl PySessionContext {
608610
name: &str,
609611
table_provider: Bound<'_, PyAny>,
610612
) -> PyDataFusionResult<()> {
611-
let provider = if table_provider.hasattr("__datafusion_table_provider__")? {
612-
let capsule = table_provider
613-
.getattr("__datafusion_table_provider__")?
614-
.call0()?;
615-
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
616-
validate_pycapsule(capsule, "datafusion_table_provider")?;
617-
618-
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
619-
let provider: ForeignTableProvider = provider.into();
620-
Arc::new(provider) as Arc<dyn TableProvider + Send>
613+
let provider = if let Some(provider) = table_provider_from_pycapsule(&table_provider)? {
614+
provider
621615
} else if let Ok(py_table) = table_provider.extract::<PyTable>() {
622616
py_table.table()
623617
} else if let Ok(py_provider) = table_provider.extract::<PyTableProvider>() {

src/udtf.rs

Lines changed: 5 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -21,11 +21,10 @@ use std::sync::Arc;
2121
use crate::errors::{py_datafusion_err, to_datafusion_err};
2222
use crate::expr::PyExpr;
2323
use crate::table::PyTableProvider;
24-
use crate::utils::validate_pycapsule;
24+
use crate::utils::{table_provider_from_pycapsule, validate_pycapsule};
2525
use datafusion::catalog::{TableFunctionImpl, TableProvider};
2626
use datafusion::error::Result as DataFusionResult;
2727
use datafusion::logical_expr::Expr;
28-
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
2928
use datafusion_ffi::udtf::{FFI_TableFunction, ForeignTableFunction};
3029
use pyo3::exceptions::PyNotImplementedError;
3130
use pyo3::types::{PyCapsule, PyTuple};
@@ -99,20 +98,11 @@ fn call_python_table_function(
9998
let provider_obj = func.call1(py, py_args)?;
10099
let provider = provider_obj.bind(py);
101100

102-
if provider.hasattr("__datafusion_table_provider__")? {
103-
let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?;
104-
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
105-
validate_pycapsule(capsule, "datafusion_table_provider")?;
106-
107-
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
108-
let provider: ForeignTableProvider = provider.into();
109-
110-
Ok(Arc::new(provider) as Arc<dyn TableProvider + Send>)
111-
} else {
112-
Err(PyNotImplementedError::new_err(
101+
table_provider_from_pycapsule(provider)?.ok_or_else(|| {
102+
PyNotImplementedError::new_err(
113103
"__datafusion_table_provider__ does not exist on Table Provider object.",
114-
))
115-
}
104+
)
105+
})
116106
})
117107
.map_err(to_datafusion_err)
118108
}

src/utils.rs

Lines changed: 25 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -17,15 +17,21 @@
1717

1818
use crate::{
1919
common::data_type::PyScalarValue,
20-
errors::{PyDataFusionError, PyDataFusionResult},
20+
errors::{py_datafusion_err, PyDataFusionError, PyDataFusionResult},
2121
TokioRuntime,
2222
};
2323
use datafusion::{
24-
common::ScalarValue, execution::context::SessionContext, logical_expr::Volatility,
24+
common::ScalarValue, datasource::TableProvider, execution::context::SessionContext,
25+
logical_expr::Volatility,
2526
};
27+
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
2628
use pyo3::prelude::*;
2729
use pyo3::{exceptions::PyValueError, types::PyCapsule};
28-
use std::{future::Future, sync::OnceLock, time::Duration};
30+
use std::{
31+
future::Future,
32+
sync::{Arc, OnceLock},
33+
time::Duration,
34+
};
2935
use tokio::{runtime::Runtime, time::sleep};
3036
/// Utility to get the Tokio Runtime from Python
3137
#[inline]
@@ -116,6 +122,22 @@ pub(crate) fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyRe
116122
Ok(())
117123
}
118124

125+
pub(crate) fn table_provider_from_pycapsule(
126+
obj: &Bound<PyAny>,
127+
) -> PyResult<Option<Arc<dyn TableProvider + Send>>> {
128+
if obj.hasattr("__datafusion_table_provider__")? {
129+
let capsule = obj.getattr("__datafusion_table_provider__")?.call0()?;
130+
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
131+
validate_pycapsule(capsule, "datafusion_table_provider")?;
132+
133+
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
134+
let provider: ForeignTableProvider = provider.into();
135+
Ok(Some(Arc::new(provider) as Arc<dyn TableProvider + Send>))
136+
} else {
137+
Ok(None)
138+
}
139+
}
140+
119141
pub(crate) fn py_obj_to_scalar_value(py: Python, obj: PyObject) -> PyResult<ScalarValue> {
120142
// convert Python object to PyScalarValue to ScalarValue
121143

0 commit comments

Comments
 (0)