2121import time
2222
2323import pyarrow as pa
24- from datafusion import SessionContext
24+ from datafusion import SessionContext , col
25+ from datafusion import functions as f
2526
2627
2728def run (
2829 n_batches : int = 8 ,
2930 batch_size : int = 1_000_000 ,
3031 n_partitions : int | None = None ,
3132) -> None :
33+ """Aggregate column 'a' across partitions and report runtime."""
3234 ctx = SessionContext ()
3335 batches = []
3436 for i in range (n_batches ):
@@ -46,20 +48,36 @@ def run(
4648 df = ctx .create_dataframe (partitions )
4749
4850 start = time .perf_counter ()
49- df .collect ()
51+ df .aggregate ([], [ f . sum ( col ( "a" ))]). collect ()
5052 duration = time .perf_counter () - start
51- print (f"{ n_batches } batches collected in { duration :.3f} s" )
53+ print (f"{ n_batches } batches aggregated in { duration :.3f} s" )
5254
5355
5456if __name__ == "__main__" :
5557 import argparse
5658
5759 parser = argparse .ArgumentParser ()
60+ parser .add_argument (
61+ "--batches" ,
62+ type = int ,
63+ default = 8 ,
64+ help = "number of input batches to generate" ,
65+ )
66+ parser .add_argument (
67+ "--batch-size" ,
68+ type = int ,
69+ default = 1_000_000 ,
70+ help = "number of rows per batch" ,
71+ )
5872 parser .add_argument (
5973 "--partitions" ,
6074 type = int ,
6175 default = None ,
6276 help = "number of partitions to create (defaults to one per batch)" ,
6377 )
6478 args = parser .parse_args ()
65- run (n_partitions = args .partitions )
79+ run (
80+ n_batches = args .batches ,
81+ batch_size = args .batch_size ,
82+ n_partitions = args .partitions ,
83+ )
0 commit comments