Skip to content

Commit ed14eab

Browse files
committed
Add tests for synchronous and asynchronous record batch stream retrieval
1 parent a2d0808 commit ed14eab

2 files changed

Lines changed: 19 additions & 5 deletions

File tree

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import pytest
2+
3+
4+
def test_record_batch_stream_next(ctx):
5+
stream = ctx.sql("SELECT 1 as a").execute_stream()
6+
batch = next(stream)
7+
assert batch.to_pyarrow().num_rows == 1
8+
with pytest.raises(StopIteration):
9+
next(stream)
10+
11+
12+
@pytest.mark.asyncio
13+
async def test_record_batch_stream_anext(ctx):
14+
stream = ctx.sql("SELECT 1 as a").execute_stream()
15+
batch = await stream.__anext__()
16+
assert batch.to_pyarrow().num_rows == 1
17+
with pytest.raises(StopAsyncIteration):
18+
await stream.__anext__()

src/record_batch.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,7 @@ impl PyRecordBatchStream {
6161
#[pymethods]
6262
impl PyRecordBatchStream {
6363
fn next(&mut self, py: Python) -> PyResult<PyRecordBatch> {
64-
let mut stream = wait_for_future(py, self.stream.lock())?;
65-
match wait_for_stream_next(py, &mut stream).map_err(PyDataFusionError::from)? {
66-
Some(batch) => Ok(batch.into()),
67-
None => Err(PyStopIteration::new_err("stream exhausted")),
68-
}
64+
wait_for_future(py, next_stream(self.stream.clone(), true))?
6965
}
7066

7167
fn __next__(&mut self, py: Python) -> PyResult<PyRecordBatch> {

0 commit comments

Comments
 (0)