Skip to content

Commit 7b867fe

Browse files
committed
Refactor run function to aggregate batches and update command-line arguments
1 parent: dc396c7 · commit: 7b867fe

1 file changed

Lines changed: 21 additions & 4 deletions

File tree

benchmarks/collect_gil_bench.py

Lines changed: 21 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,8 @@
2121
import time
2222

2323
import pyarrow as pa
24-
from datafusion import SessionContext
24+
from datafusion import SessionContext, col
25+
from datafusion import functions as f
2526

2627

2728
def run(
@@ -46,20 +47,36 @@ def run(
4647
df = ctx.create_dataframe(partitions)
4748

4849
start = time.perf_counter()
49-
df.collect()
50+
df.aggregate([], [f.sum(col("a"))]).collect()
5051
duration = time.perf_counter() - start
51-
print(f"{n_batches} batches collected in {duration:.3f}s")
52+
print(f"{n_batches} batches aggregated in {duration:.3f}s")
5253

5354

5455
if __name__ == "__main__":
5556
import argparse
5657

5758
parser = argparse.ArgumentParser()
59+
parser.add_argument(
60+
"--batches",
61+
type=int,
62+
default=8,
63+
help="number of input batches to generate",
64+
)
65+
parser.add_argument(
66+
"--batch-size",
67+
type=int,
68+
default=1_000_000,
69+
help="number of rows per batch",
70+
)
5871
parser.add_argument(
5972
"--partitions",
6073
type=int,
6174
default=None,
6275
help="number of partitions to create (defaults to one per batch)",
6376
)
6477
args = parser.parse_args()
65-
run(n_partitions=args.partitions)
78+
run(
79+
n_batches=args.batches,
80+
batch_size=args.batch_size,
81+
n_partitions=args.partitions,
82+
)

0 commit comments

Comments
 (0)