@@ -293,11 +293,11 @@ class AggregateUDF:
293293 def __init__ (
294294 self ,
295295 name : str ,
296- accumulator : Callable [[], Accumulator ] | AggregateUDFExportable ,
297- input_types : list [pa .DataType ] | None ,
298- return_type : pa .DataType | None ,
299- state_type : list [pa .DataType ] | None ,
300- volatility : Volatility | str | None ,
296+ accumulator : Callable [[], Accumulator ],
297+ input_types : list [pa .DataType ],
298+ return_type : pa .DataType ,
299+ state_type : list [pa .DataType ],
300+ volatility : Volatility | str ,
301301 ) -> None :
302302 """Instantiate a user-defined aggregate function (UDAF).
303303
@@ -307,18 +307,6 @@ def __init__(
307307 if hasattr (accumulator , "__datafusion_aggregate_udf__" ):
308308 self ._udaf = df_internal .AggregateUDF .from_pycapsule (accumulator )
309309 return
310- if (
311- input_types is None
312- or return_type is None
313- or state_type is None
314- or volatility is None
315- ):
316- msg = (
317- "`input_types`, `return_type`, `state_type`, and `volatility` "
318- "must be provided when `accumulator` is callable."
319- )
320- raise TypeError (msg )
321-
322310 self ._udaf = df_internal .AggregateUDF (
323311 name ,
324312 accumulator ,
@@ -362,10 +350,6 @@ def udaf(
362350 name : Optional [str ] = None ,
363351 ) -> AggregateUDF : ...
364352
365- @overload
366- @staticmethod
367- def udaf (accum : AggregateUDFExportable ) -> AggregateUDF : ...
368-
369353 @staticmethod
370354 def udaf (* args : Any , ** kwargs : Any ): # noqa: D417, C901
371355 """Create a new User-Defined Aggregate Function (UDAF).
0 commit comments