Skip to content

Commit 116047e

Browse files
committed
Refactor run function to aggregate batches and update command-line arguments
1 parent 8cd62bc commit 116047e

1 file changed

Lines changed: 22 additions & 4 deletions

File tree

benchmarks/collect_gil_bench.py

Lines changed: 22 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -21,14 +21,16 @@
2121
import time
2222

2323
import pyarrow as pa
24-
from datafusion import SessionContext
24+
from datafusion import SessionContext, col
25+
from datafusion import functions as f
2526

2627

2728
def run(
2829
n_batches: int = 8,
2930
batch_size: int = 1_000_000,
3031
n_partitions: int | None = None,
3132
) -> None:
33+
"""Aggregate column 'a' across partitions and report runtime."""
3234
ctx = SessionContext()
3335
batches = []
3436
for i in range(n_batches):
@@ -46,20 +48,36 @@ def run(
4648
df = ctx.create_dataframe(partitions)
4749

4850
start = time.perf_counter()
49-
df.collect()
51+
df.aggregate([], [f.sum(col("a"))]).collect()
5052
duration = time.perf_counter() - start
51-
print(f"{n_batches} batches collected in {duration:.3f}s")
53+
print(f"{n_batches} batches aggregated in {duration:.3f}s")
5254

5355

5456
if __name__ == "__main__":
5557
import argparse
5658

5759
parser = argparse.ArgumentParser()
60+
parser.add_argument(
61+
"--batches",
62+
type=int,
63+
default=8,
64+
help="number of input batches to generate",
65+
)
66+
parser.add_argument(
67+
"--batch-size",
68+
type=int,
69+
default=1_000_000,
70+
help="number of rows per batch",
71+
)
5872
parser.add_argument(
5973
"--partitions",
6074
type=int,
6175
default=None,
6276
help="number of partitions to create (defaults to one per batch)",
6377
)
6478
args = parser.parse_args()
65-
run(n_partitions=args.partitions)
79+
run(
80+
n_batches=args.batches,
81+
batch_size=args.batch_size,
82+
n_partitions=args.partitions,
83+
)

0 commit comments

Comments
 (0)