1818use crate :: dataset:: Dataset ;
1919use crate :: errors:: { py_datafusion_err, to_datafusion_err, PyDataFusionError , PyDataFusionResult } ;
2020use crate :: table:: PyTableProvider ;
21- use crate :: utils:: { validate_pycapsule, wait_for_future} ;
21+ use crate :: utils:: { table_provider_from_pycapsule , validate_pycapsule, wait_for_future} ;
2222use async_trait:: async_trait;
2323use datafusion:: catalog:: { MemoryCatalogProvider , MemorySchemaProvider } ;
2424use datafusion:: common:: DataFusionError ;
@@ -28,7 +28,6 @@ use datafusion::{
2828 datasource:: { TableProvider , TableType } ,
2929} ;
3030use datafusion_ffi:: schema_provider:: { FFI_SchemaProvider , ForeignSchemaProvider } ;
31- use datafusion_ffi:: table_provider:: { FFI_TableProvider , ForeignTableProvider } ;
3231use pyo3:: exceptions:: PyKeyError ;
3332use pyo3:: prelude:: * ;
3433use pyo3:: types:: PyCapsule ;
@@ -197,16 +196,8 @@ impl PySchema {
197196 }
198197
199198 fn register_table ( & self , name : & str , table_provider : Bound < ' _ , PyAny > ) -> PyResult < ( ) > {
200- let provider = if table_provider. hasattr ( "__datafusion_table_provider__" ) ? {
201- let capsule = table_provider
202- . getattr ( "__datafusion_table_provider__" ) ?
203- . call0 ( ) ?;
204- let capsule = capsule. downcast :: < PyCapsule > ( ) . map_err ( py_datafusion_err) ?;
205- validate_pycapsule ( capsule, "datafusion_table_provider" ) ?;
206-
207- let provider = unsafe { capsule. reference :: < FFI_TableProvider > ( ) } ;
208- let provider: ForeignTableProvider = provider. into ( ) ;
209- Arc :: new ( provider) as Arc < dyn TableProvider + Send >
199+ let provider = if let Some ( provider) = table_provider_from_pycapsule ( & table_provider) ? {
200+ provider
210201 } else {
211202 match table_provider. extract :: < PyTable > ( ) {
212203 Ok ( py_table) => py_table. table ,
@@ -308,15 +299,8 @@ impl RustWrappedPySchemaProvider {
308299 return Ok ( None ) ;
309300 }
310301
311- if py_table. hasattr ( "__datafusion_table_provider__" ) ? {
312- let capsule = py_table. getattr ( "__datafusion_table_provider__" ) ?. call0 ( ) ?;
313- let capsule = capsule. downcast :: < PyCapsule > ( ) . map_err ( py_datafusion_err) ?;
314- validate_pycapsule ( capsule, "datafusion_table_provider" ) ?;
315-
316- let provider = unsafe { capsule. reference :: < FFI_TableProvider > ( ) } ;
317- let provider: ForeignTableProvider = provider. into ( ) ;
318-
319- Ok ( Some ( Arc :: new ( provider) as Arc < dyn TableProvider + Send > ) )
302+ if let Some ( provider) = table_provider_from_pycapsule ( & py_table) ? {
303+ Ok ( Some ( provider) )
320304 } else {
321305 if let Ok ( inner_table) = py_table. getattr ( "table" ) {
322306 if let Ok ( inner_table) = inner_table. extract :: < PyTable > ( ) {
0 commit comments