@@ -45,16 +45,16 @@ def test_ffi_aggregate_register():
4545
4646 result = ctx .sql ("select my_custom_sum(a) from test_table group by b" ).collect ()
4747
48- assert result
48+ assert len ( result ) == 2
4949 assert result [0 ].num_columns == 1
5050
51- # Normalizing table registration in _normalize_table_provider feeds the Rust layer
52- # an actual TableProvider, so collect() emits the grouped rows in a single record batch
53- # instead of two separate batches.
54- aggregates = pa .concat_arrays ([batch .column (0 ) for batch in result ])
51+ result = [r .column (0 ) for r in result ]
52+ expected = [
53+ pa .array ([3 ], type = pa .int64 ()),
54+ pa .array ([3 ], type = pa .int64 ()),
55+ ]
5556
56- assert len (aggregates ) == 2
57- assert aggregates .to_pylist () == [3 , 3 ]
57+ assert result == expected
5858
5959
6060def test_ffi_aggregate_call_directly ():
@@ -65,13 +65,13 @@ def test_ffi_aggregate_call_directly():
6565 ctx .table ("test_table" ).aggregate ([col ("b" )], [my_udaf (col ("a" ))]).collect ()
6666 )
6767
68- # Normalizing table registration in _normalize_table_provider feeds the Rust layer
69- # an actual TableProvider, so collect() emits the grouped rows in a single record batch
70- # instead of two separate batches.
71- assert result
68+ assert len (result ) == 2
7269 assert result [0 ].num_columns == 2
7370
74- aggregates = pa .concat_arrays ([batch .column (1 ) for batch in result ])
71+ result = [r .column (1 ) for r in result ]
72+ expected = [
73+ pa .array ([3 ], type = pa .int64 ()),
74+ pa .array ([3 ], type = pa .int64 ()),
75+ ]
7576
76- assert len (aggregates ) == 2
77- assert aggregates .to_pylist () == [3 , 3 ]
77+ assert result == expected
0 commit comments