Skip to content

Commit fad17cc

Browse files
committed
Broaden AggregateUDF typing for PyCapsule support
Enhance AggregateUDF to accept PyCapsule providers. Validate that callable accumulators provide type metadata prior to UDAF construction. Add overload for udaf for static type checkers to recognize PyCapsule-backed aggregate functions.
1 parent ac2222e commit fad17cc

1 file changed

Lines changed: 21 additions & 5 deletions

File tree

python/datafusion/user_defined.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -293,11 +293,11 @@ class AggregateUDF:
293293
def __init__(
294294
self,
295295
name: str,
296-
accumulator: Callable[[], Accumulator],
297-
input_types: list[pa.DataType],
298-
return_type: pa.DataType,
299-
state_type: list[pa.DataType],
300-
volatility: Volatility | str,
296+
accumulator: Callable[[], Accumulator] | AggregateUDFExportable,
297+
input_types: list[pa.DataType] | None,
298+
return_type: pa.DataType | None,
299+
state_type: list[pa.DataType] | None,
300+
volatility: Volatility | str | None,
301301
) -> None:
302302
"""Instantiate a user-defined aggregate function (UDAF).
303303
@@ -307,6 +307,18 @@ def __init__(
307307
if hasattr(accumulator, "__datafusion_aggregate_udf__"):
308308
self._udaf = df_internal.AggregateUDF.from_pycapsule(accumulator)
309309
return
310+
if (
311+
input_types is None
312+
or return_type is None
313+
or state_type is None
314+
or volatility is None
315+
):
316+
msg = (
317+
"`input_types`, `return_type`, `state_type`, and `volatility` "
318+
"must be provided when `accumulator` is callable."
319+
)
320+
raise TypeError(msg)
321+
310322
self._udaf = df_internal.AggregateUDF(
311323
name,
312324
accumulator,
@@ -350,6 +362,10 @@ def udaf(
350362
name: Optional[str] = None,
351363
) -> AggregateUDF: ...
352364

365+
@overload
366+
@staticmethod
367+
def udaf(accum: AggregateUDFExportable) -> AggregateUDF: ...
368+
353369
@staticmethod
354370
def udaf(*args: Any, **kwargs: Any): # noqa: D417, C901
355371
"""Create a new User-Defined Aggregate Function (UDAF).

0 commit comments

Comments
 (0)