3434 from _typeshed import CapsuleType as _PyCapsule
3535
3636 _R = TypeVar ("_R" , bound = pa .DataType )
37- from collections .abc import Callable
37+ from collections .abc import Callable , Sequence
3838
3939
4040class Volatility (Enum ):
@@ -81,6 +81,27 @@ def __str__(self) -> str:
8181 return self .name .lower ()
8282
8383
84+ def data_type_or_field_to_field (value : pa .DataType | pa .Field , name : str ) -> pa .Field :
85+ """Helper function to return a Field from either a Field or DataType."""
86+ if isinstance (value , pa .Field ):
87+ return value
88+ return pa .field (name , type = value )
89+
90+
91+ def data_types_or_fields_to_field_list (
92+ inputs : Sequence [pa .Field | pa .DataType ] | pa .Field | pa .DataType ,
93+ ) -> list [pa .Field ]:
94+ """Helper function to return a list of Fields."""
95+ if isinstance (inputs , pa .DataType ):
96+ return [pa .field ("value" , type = inputs )]
97+ if isinstance (inputs , pa .Field ):
98+ return [inputs ]
99+
100+ return [
101+ data_type_or_field_to_field (v , f"value_{ idx } " ) for (idx , v ) in enumerate (inputs )
102+ ]
103+
104+
84105class ScalarUDFExportable (Protocol ):
85106 """Type hint for object that has __datafusion_scalar_udf__ PyCapsule."""
86107
@@ -103,8 +124,8 @@ def __init__(
103124 self ,
104125 name : str ,
105126 func : Callable [..., _R ],
106- input_types : pa . DataType | list [pa .DataType ],
107- return_type : _R ,
127+ input_fields : list [pa .Field ],
128+ return_field : _R ,
108129 volatility : Volatility | str ,
109130 ) -> None :
110131 """Instantiate a scalar user-defined function (UDF).
@@ -114,10 +135,10 @@ def __init__(
114135 if hasattr (func , "__datafusion_scalar_udf__" ):
115136 self ._udf = df_internal .ScalarUDF .from_pycapsule (func )
116137 return
117- if isinstance (input_types , pa .DataType ):
118- input_types = [input_types ]
138+ if isinstance (input_fields , pa .DataType ):
139+ input_fields = [input_fields ]
119140 self ._udf = df_internal .ScalarUDF (
120- name , func , input_types , return_type , str (volatility )
141+ name , func , input_fields , return_field , str (volatility )
121142 )
122143
123144 def __repr__ (self ) -> str :
@@ -136,8 +157,8 @@ def __call__(self, *args: Expr) -> Expr:
136157 @overload
137158 @staticmethod
138159 def udf (
139- input_types : list [pa .DataType ] ,
140- return_type : _R ,
160+ input_fields : Sequence [pa .DataType | pa . Field ] | pa . DataType | pa . Field ,
161+ return_field : pa . DataType | pa . Field ,
141162 volatility : Volatility | str ,
142163 name : str | None = None ,
143164 ) -> Callable [..., ScalarUDF ]: ...
@@ -146,8 +167,8 @@ def udf(
146167 @staticmethod
147168 def udf (
148169 func : Callable [..., _R ],
149- input_types : list [pa .DataType ] ,
150- return_type : _R ,
170+ input_fields : Sequence [pa .DataType | pa . Field ] | pa . DataType | pa . Field ,
171+ return_field : pa . DataType | pa . Field ,
151172 volatility : Volatility | str ,
152173 name : str | None = None ,
153174 ) -> ScalarUDF : ...
@@ -163,20 +184,24 @@ def udf(*args: Any, **kwargs: Any): # noqa: D417
163184 This class can be used both as either a function or a decorator.
164185
165186 Usage:
166- - As a function: ``udf(func, input_types, return_type , volatility, name)``.
167- - As a decorator: ``@udf(input_types, return_type , volatility, name)``.
187+ - As a function: ``udf(func, input_fields, return_field , volatility, name)``.
188+ - As a decorator: ``@udf(input_fields, return_field , volatility, name)``.
168189 When used a decorator, do **not** pass ``func`` explicitly.
169190
191+ In lieu of passing a PyArrow Field, you can pass a DataType for simplicity.
192+ When you do so, it will be assumed that the nullability of the inputs and
193+ output are True and that they have no metadata.
194+
170195 Args:
171196 func (Callable, optional): Only needed when calling as a function.
172197 Skip this argument when using `udf` as a decorator. If you have a Rust
173198 backed ScalarUDF within a PyCapsule, you can pass this parameter
174199 and ignore the rest. They will be determined directly from the
175200 underlying function. See the online documentation for more information.
176- input_types (list[pa.DataType]): The data types of the arguments
177- to ``func``. This list must be of the same length as the number of
178- arguments.
179- return_type (_R): The data type of the return value from the function.
201+ input_fields (list[pa.Field | pa. DataType]): The data types or Fields
202+ of the arguments to ``func``. This list must be of the same length
203+ as the number of arguments.
204+ return_field (_R): The field of the return value from the function.
180205 volatility (Volatility | str): See `Volatility` for allowed values.
181206 name (Optional[str]): A descriptive name for the function.
182207
@@ -196,12 +221,12 @@ def double_func(x):
196221 @udf([pa.int32()], pa.int32(), "volatile", "double_it")
197222 def double_udf(x):
198223 return x * 2
199- """
224+ """ # noqa: W505 E501
200225
201226 def _function (
202227 func : Callable [..., _R ],
203- input_types : list [pa .DataType ] ,
204- return_type : _R ,
228+ input_fields : Sequence [pa .DataType | pa . Field ] | pa . DataType | pa . Field ,
229+ return_field : pa . DataType | pa . Field ,
205230 volatility : Volatility | str ,
206231 name : str | None = None ,
207232 ) -> ScalarUDF :
@@ -213,23 +238,25 @@ def _function(
213238 name = func .__qualname__ .lower ()
214239 else :
215240 name = func .__class__ .__name__ .lower ()
241+ input_fields = data_types_or_fields_to_field_list (input_fields )
242+ return_field = data_type_or_field_to_field (return_field , "value" )
216243 return ScalarUDF (
217244 name = name ,
218245 func = func ,
219- input_types = input_types ,
220- return_type = return_type ,
246+ input_fields = input_fields ,
247+ return_field = return_field ,
221248 volatility = volatility ,
222249 )
223250
224251 def _decorator (
225- input_types : list [pa .DataType ] ,
226- return_type : _R ,
252+ input_fields : Sequence [pa .DataType | pa . Field ] | pa . DataType | pa . Field ,
253+ return_field : _R ,
227254 volatility : Volatility | str ,
228255 name : str | None = None ,
229256 ) -> Callable :
230257 def decorator (func : Callable ) -> Callable :
231258 udf_caller = ScalarUDF .udf (
232- func , input_types , return_type , volatility , name
259+ func , input_fields , return_field , volatility , name
233260 )
234261
235262 @functools .wraps (func )
@@ -260,8 +287,8 @@ def from_pycapsule(func: ScalarUDFExportable) -> ScalarUDF:
260287 return ScalarUDF (
261288 name = name ,
262289 func = func ,
263- input_types = None ,
264- return_type = None ,
290+ input_fields = None ,
291+ return_field = None ,
265292 volatility = None ,
266293 )
267294
0 commit comments