Skip to content

Commit 8d8f219

Browse files
committed
Revert "Broaden AggregateUDF typing for PyCapsule support"
This reverts commit fad17cc.
1 parent fad17cc commit 8d8f219

1 file changed

Lines changed: 5 additions & 21 deletions

File tree

python/datafusion/user_defined.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -293,11 +293,11 @@ class AggregateUDF:
293293
def __init__(
294294
self,
295295
name: str,
296-
accumulator: Callable[[], Accumulator] | AggregateUDFExportable,
297-
input_types: list[pa.DataType] | None,
298-
return_type: pa.DataType | None,
299-
state_type: list[pa.DataType] | None,
300-
volatility: Volatility | str | None,
296+
accumulator: Callable[[], Accumulator],
297+
input_types: list[pa.DataType],
298+
return_type: pa.DataType,
299+
state_type: list[pa.DataType],
300+
volatility: Volatility | str,
301301
) -> None:
302302
"""Instantiate a user-defined aggregate function (UDAF).
303303
@@ -307,18 +307,6 @@ def __init__(
307307
if hasattr(accumulator, "__datafusion_aggregate_udf__"):
308308
self._udaf = df_internal.AggregateUDF.from_pycapsule(accumulator)
309309
return
310-
if (
311-
input_types is None
312-
or return_type is None
313-
or state_type is None
314-
or volatility is None
315-
):
316-
msg = (
317-
"`input_types`, `return_type`, `state_type`, and `volatility` "
318-
"must be provided when `accumulator` is callable."
319-
)
320-
raise TypeError(msg)
321-
322310
self._udaf = df_internal.AggregateUDF(
323311
name,
324312
accumulator,
@@ -362,10 +350,6 @@ def udaf(
362350
name: Optional[str] = None,
363351
) -> AggregateUDF: ...
364352

365-
@overload
366-
@staticmethod
367-
def udaf(accum: AggregateUDFExportable) -> AggregateUDF: ...
368-
369353
@staticmethod
370354
def udaf(*args: Any, **kwargs: Any): # noqa: D417, C901
371355
"""Create a new User-Defined Aggregate Function (UDAF).

0 commit comments

Comments
 (0)