Skip to content

Commit 1734218

Browse files
committed
Refactor next_stream function to use blocking lock for synchronous stream access
1 parent 3db9231 commit 1734218

1 file changed

Lines changed: 9 additions & 5 deletions

File tree

src/record_batch.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ use crate::utils::{wait_for_future, wait_for_stream_next};
2222
use datafusion::arrow::pyarrow::ToPyArrow;
2323
use datafusion::arrow::record_batch::RecordBatch;
2424
use datafusion::physical_plan::SendableRecordBatchStream;
25-
use futures::StreamExt;
2625
use pyo3::exceptions::{PyStopAsyncIteration, PyStopIteration};
2726
use pyo3::prelude::*;
2827
use pyo3::{pyclass, pymethods, PyObject, PyResult, Python};
@@ -91,10 +90,15 @@ async fn next_stream(
9190
stream: Arc<Mutex<SendableRecordBatchStream>>,
9291
sync: bool,
9392
) -> PyResult<PyRecordBatch> {
94-
let mut stream = stream.lock().await;
95-
match stream.next().await {
96-
Some(Ok(batch)) => Ok(batch.into()),
97-
Some(Err(e)) => Err(PyDataFusionError::from(e))?,
93+
let result = tokio::task::spawn_blocking(move || {
94+
let mut stream = stream.blocking_lock();
95+
Python::with_gil(|py| wait_for_stream_next(py, &mut stream))
96+
})
97+
.await
98+
.map_err(|e| PyDataFusionError::Common(e.to_string()))?;
99+
100+
match result.map_err(PyDataFusionError::from)? {
101+
Some(batch) => Ok(batch.into()),
98102
None => {
99103
// Depending on whether the iteration is sync or not, we raise either a
100104
// StopIteration or a StopAsyncIteration

0 commit comments

Comments
 (0)