Skip to content

Commit ac2222e

Browse files
committed
chore: add initial test for type checking in datafusion-ffi-example
1 parent daa26da commit ac2222e

1 file changed

Lines changed: 40 additions & 0 deletions

File tree

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import pyarrow as pa
2+
from datafusion import udaf, SessionContext
3+
from datafusion.user_defined import Accumulator # base class for aggregators
4+
5+
# Define a simple test accumulator for demonstration:
6+
class TestAccumulator(Accumulator):
7+
def __init__(self) -> None:
8+
self.total = 0
9+
10+
def state(self) -> list[pa.Scalar]:
11+
return [pa.scalar(self.total)]
12+
13+
def update(self, *values: pa.Array) -> None:
14+
# Sum up integer values from the first argument
15+
self.total += sum(value.as_py() for value in values[0])
16+
17+
def merge(self, states: list[pa.Array]) -> None:
18+
# Assumes the state is a list with one scalar integer per actor
19+
self.total += sum(state[0].as_py() for state in states)
20+
21+
def evaluate(self) -> pa.Scalar:
22+
return pa.scalar(self.total)
23+
24+
# Create the test UDAF using TestAccumulator.
25+
# Note: the overload taking (accum, input_types, return_type, state_type, volatility, name)
26+
test_udaf = udaf(
27+
TestAccumulator, # accumulator function or type producing an Accumulator object
28+
[pa.int64()], # input types (list of one int64)
29+
pa.int64(), # return type
30+
[pa.int64()], # state type (list of one int64)
31+
"immutable", # volatility indicator
32+
name="test_udaf"
33+
)
34+
35+
# Register UDAF into a session context (if needed)
36+
ctx = SessionContext()
37+
ctx.register_udaf(test_udaf)
38+
39+
# The code should type check without error:
40+
print("Type checking passed for test_udaf!")

0 commit comments

Comments
 (0)