Skip to content

Commit c633435

Browse files
committed
Refactor tests to use fail_collect fixture for DataFrame collect method
1 parent 462ddf4 commit c633435

3 files changed

Lines changed: 13 additions & 17 deletions

File tree

python/tests/conftest.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717

1818
import pyarrow as pa
1919
import pytest
20-
from datafusion import SessionContext
20+
from datafusion import DataFrame, SessionContext
2121
from pyarrow.csv import write_csv
2222

2323

@@ -49,3 +49,12 @@ def database(ctx, tmp_path):
4949
delimiter=",",
5050
schema_infer_max_records=10,
5151
)
52+
53+
54+
@pytest.fixture
55+
def fail_collect(monkeypatch):
56+
def _fail_collect(self, *args, **kwargs): # pragma: no cover - failure path
57+
msg = "collect should not be called"
58+
raise AssertionError(msg)
59+
60+
monkeypatch.setattr(DataFrame, "collect", _fail_collect)

python/tests/test_dataframe.py

Lines changed: 1 addition & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1582,21 +1582,14 @@ def test_empty_to_arrow_table(df):
15821582
assert set(pyarrow_table.column_names) == {"a", "b", "c"}
15831583

15841584

1585-
def test_arrow_c_stream_to_table(monkeypatch):
1585+
def test_arrow_c_stream_to_table(fail_collect):
15861586
ctx = SessionContext()
15871587

15881588
# Create a DataFrame with two separate record batches
15891589
batch1 = pa.record_batch([pa.array([1])], names=["a"])
15901590
batch2 = pa.record_batch([pa.array([2])], names=["a"])
15911591
df = ctx.create_dataframe([[batch1], [batch2]])
15921592

1593-
# Fail if the DataFrame is pre-collected
1594-
def fail_collect(self): # pragma: no cover - failure path
1595-
msg = "collect should not be called"
1596-
raise AssertionError(msg)
1597-
1598-
monkeypatch.setattr(DataFrame, "collect", fail_collect)
1599-
16001593
table = pa.Table.from_batches(df)
16011594
expected = pa.Table.from_batches([batch1, batch2])
16021595

python/tests/test_io.py

Lines changed: 2 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@
1818

1919
import pyarrow as pa
2020
import pytest
21-
from datafusion import DataFrame, column
21+
from datafusion import column
2222
from datafusion._testing import range_table
2323
from datafusion.io import read_avro, read_csv, read_json, read_parquet
2424

@@ -123,15 +123,9 @@ def test_arrow_c_stream_large_dataset(ctx):
123123
assert current_rss - start_rss < 50 * 1024 * 1024
124124

125125

126-
def test_table_from_batches_stream(ctx, monkeypatch):
126+
def test_table_from_batches_stream(ctx, fail_collect):
127127
df = range_table(ctx, 0, 10)
128128

129-
def fail_collect(self): # pragma: no cover - failure path
130-
msg = "collect should not be called"
131-
raise AssertionError(msg)
132-
133-
monkeypatch.setattr(DataFrame, "collect", fail_collect)
134-
135129
table = pa.Table.from_batches(df)
136130
assert table.shape == (10, 1)
137131
assert table.column_names == ["value"]

0 commit comments

Comments
 (0)