File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 2020import pyarrow as pa
2121import pyarrow .dataset as ds
2222import pytest
23- from datafusion import SessionContext , Table
23+ from datafusion import SessionContext , Table , udtf
2424
2525
2626# Note we take in `database` as a variable even though we don't use
@@ -232,3 +232,19 @@ def test_in_end_to_end_python_providers(ctx: SessionContext):
232232 assert len (batches ) == 1
233233 assert batches [0 ].column (0 ) == pa .array ([1 , 2 , 3 ])
234234 assert batches [0 ].column (1 ) == pa .array ([4 , 5 , 6 ])
235+
236+
237+ def test_register_python_function_as_udtf (ctx : SessionContext ):
238+ basic_table = Table (ctx .sql ("SELECT 3 AS value" ))
239+
240+ @udtf ("my_table_function" )
241+ def my_table_function_udtf () -> Table :
242+ return basic_table
243+
244+ ctx .register_udtf (my_table_function_udtf )
245+
246+ result = ctx .sql ("SELECT * FROM my_table_function()" ).collect ()
247+ assert len (result ) == 1
248+ assert len (result [0 ]) == 1
249+ assert len (result [0 ][0 ]) == 1
250+ assert result [0 ][0 ][0 ].as_py () == 3
Original file line number Diff line number Diff line change @@ -21,12 +21,11 @@ use std::sync::Arc;
2121use crate :: errors:: { py_datafusion_err, to_datafusion_err} ;
2222use crate :: expr:: PyExpr ;
2323use crate :: table:: PyTable ;
24- use crate :: utils:: { table_provider_from_pycapsule , validate_pycapsule} ;
24+ use crate :: utils:: validate_pycapsule;
2525use datafusion:: catalog:: { TableFunctionImpl , TableProvider } ;
2626use datafusion:: error:: Result as DataFusionResult ;
2727use datafusion:: logical_expr:: Expr ;
2828use datafusion_ffi:: udtf:: { FFI_TableFunction , ForeignTableFunction } ;
29- use pyo3:: exceptions:: PyNotImplementedError ;
3029use pyo3:: types:: { PyCapsule , PyTuple } ;
3130
3231/// Represents a user defined table function
@@ -98,11 +97,7 @@ fn call_python_table_function(
9897 let provider_obj = func. call1 ( py, py_args) ?;
9998 let provider = provider_obj. bind ( py) ;
10099
101- table_provider_from_pycapsule ( provider) ?. ok_or_else ( || {
102- PyNotImplementedError :: new_err (
103- "__datafusion_table_provider__ does not exist on Table Provider object." ,
104- )
105- } )
100+ Ok :: < Arc < dyn TableProvider > , PyErr > ( PyTable :: new ( provider) ?. table )
106101 } )
107102 . map_err ( to_datafusion_err)
108103}
You can’t perform that action at this time.
0 commit comments