@@ -20,7 +20,7 @@ use crate::{
2020 common:: data_type:: PyScalarValue ,
2121 dataframe:: PyDataFrame ,
2222 dataset:: Dataset ,
23- errors:: { PyDataFusionError , PyDataFusionResult } ,
23+ errors:: { py_datafusion_err , PyDataFusionError , PyDataFusionResult } ,
2424 table:: PyTableProvider ,
2525 TokioRuntime ,
2626} ;
@@ -140,37 +140,25 @@ pub(crate) fn table_provider_from_pycapsule(
140140 }
141141}
142142
143- pub ( crate ) fn extract_table_provider (
144- table_like : & Bound < PyAny > ,
143+ pub ( crate ) fn coerce_table_provider (
144+ obj : & Bound < PyAny > ,
145145) -> PyDataFusionResult < Arc < dyn TableProvider > > {
146- if let Ok ( py_table) = table_like. extract :: < PyTable > ( ) {
147- return Ok ( py_table. table ( ) ) ;
148- }
149-
150- if let Ok ( py_provider) = table_like. extract :: < PyTableProvider > ( ) {
151- return Ok ( py_provider. into_inner ( ) ) ;
152- }
153-
154- if table_like. extract :: < PyDataFrame > ( ) . is_ok ( ) {
155- return Err ( PyDataFusionError :: Common ( EXPECTED_PROVIDER_MSG . to_string ( ) ) ) ;
156- }
157-
158- match table_provider_from_pycapsule ( table_like) {
159- Ok ( Some ( provider) ) => Ok ( provider) ,
160- Ok ( None ) => {
161- let py = table_like. py ( ) ;
162- match Dataset :: new ( table_like, py) {
163- Ok ( dataset) => Ok ( Arc :: new ( dataset) as Arc < dyn TableProvider > ) ,
164- Err ( err) => {
165- if err. is_instance_of :: < PyValueError > ( py) {
166- Err ( PyDataFusionError :: Common ( EXPECTED_PROVIDER_MSG . to_string ( ) ) )
167- } else {
168- Err ( err. into ( ) )
169- }
170- }
171- }
172- }
173- Err ( err) => Err ( err. into ( ) ) ,
146+ if let Ok ( py_table) = obj. extract :: < PyTable > ( ) {
147+ Ok ( py_table. table ( ) )
148+ } else if let Ok ( py_provider) = obj. extract :: < PyTableProvider > ( ) {
149+ Ok ( py_provider. into_inner ( ) )
150+ } else if obj. is_instance_of :: < PyDataFrame > ( )
151+ || obj
152+ . getattr ( "df" )
153+ . is_ok_and ( |inner| inner. is_instance_of :: < PyDataFrame > ( ) )
154+ {
155+ Err ( PyDataFusionError :: Common ( EXPECTED_PROVIDER_MSG . to_string ( ) ) )
156+ } else if let Some ( provider) = table_provider_from_pycapsule ( obj) ? {
157+ Ok ( provider)
158+ } else {
159+ let py = obj. py ( ) ;
160+ let provider = Dataset :: new ( obj, py) ?;
161+ Ok ( Arc :: new ( provider) as Arc < dyn TableProvider > )
174162 }
175163}
176164
0 commit comments