Skip to content

Commit 362821a

Browse files
committed
Refactor stream handling to improve Python signal checking and add utility for polling record batch streams
1 parent 7209920 commit 362821a

3 files changed

Lines changed: 31 additions & 37 deletions

File tree

src/dataframe.rs

Lines changed: 6 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,16 @@ use pyo3::prelude::*;
4343
use pyo3::pybacked::PyBackedStr;
4444
use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
4545
use rayon::prelude::*;
46-
use tokio::time::{sleep, Duration};
4746

4847
use crate::catalog::PyTable;
49-
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError};
48+
use crate::errors::{py_datafusion_err, PyDataFusionError};
5049
use crate::expr::sort_expr::to_sort_expressions;
5150
use crate::physical_plan::PyExecutionPlan;
5251
use crate::record_batch::PyRecordBatchStream;
5352
use crate::sql::logical::PyLogicalPlan;
5453
use crate::utils::{
5554
get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, spawn_and_wait, validate_pycapsule,
56-
wait_for_future,
55+
wait_for_future, wait_for_stream_next,
5756
};
5857
use crate::{
5958
errors::PyDataFusionResult,
@@ -1018,30 +1017,10 @@ impl Iterator for ArrowStreamReader {
10181017
type Item = Result<RecordBatch, ArrowError>;
10191018

10201019
fn next(&mut self) -> Option<Self::Item> {
1021-
const INTERVAL_CHECK_SIGNALS: Duration = Duration::from_millis(1_000);
1022-
let rt = &get_tokio_runtime().0;
1023-
let fut = self.stream.next();
1024-
1025-
let result = Python::with_gil(|py| {
1026-
py.allow_threads(|| {
1027-
rt.block_on(async {
1028-
tokio::pin!(fut);
1029-
loop {
1030-
tokio::select! {
1031-
res = &mut fut => break res,
1032-
_ = sleep(INTERVAL_CHECK_SIGNALS) => {
1033-
if let Err(err) = Python::with_gil(|py| py.check_signals()) {
1034-
break Some(Err(to_datafusion_err(err)));
1035-
}
1036-
}
1037-
}
1038-
}
1039-
})
1040-
})
1041-
});
1020+
let result = Python::with_gil(|py| wait_for_stream_next(py, &mut self.stream));
10421021

10431022
match result {
1044-
Some(Ok(batch)) => {
1023+
Ok(Some(batch)) => {
10451024
let batch = if self.project {
10461025
match record_batch_into_schema(batch, self.schema.as_ref()) {
10471026
Ok(b) => b,
@@ -1052,8 +1031,8 @@ impl Iterator for ArrowStreamReader {
10521031
};
10531032
Some(Ok(batch))
10541033
}
1055-
Some(Err(e)) => Some(Err(ArrowError::from(e))),
1056-
None => None,
1034+
Ok(None) => None,
1035+
Err(e) => Some(Err(ArrowError::from(e))),
10571036
}
10581037
}
10591038
}

src/record_batch.rs

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
use std::sync::Arc;
1919

2020
use crate::errors::PyDataFusionError;
21-
use crate::utils::wait_for_future;
21+
use crate::utils::{wait_for_future, wait_for_stream_next};
2222
use datafusion::arrow::pyarrow::ToPyArrow;
2323
use datafusion::arrow::record_batch::RecordBatch;
2424
use datafusion::physical_plan::SendableRecordBatchStream;
@@ -59,17 +59,14 @@ impl PyRecordBatchStream {
5959
}
6060
}
6161

62-
pub(crate) async fn pull_next_batch(
63-
stream: &mut SendableRecordBatchStream,
64-
) -> Option<datafusion::common::Result<RecordBatch>> {
65-
stream.next().await
66-
}
67-
6862
#[pymethods]
6963
impl PyRecordBatchStream {
7064
fn next(&mut self, py: Python) -> PyResult<PyRecordBatch> {
71-
let stream = self.stream.clone();
72-
wait_for_future(py, next_stream(stream, true))?
65+
let mut stream = wait_for_future(py, self.stream.lock())?;
66+
match wait_for_stream_next(py, &mut stream).map_err(PyDataFusionError::from)? {
67+
Some(batch) => Ok(batch.into()),
68+
None => Err(PyStopIteration::new_err("stream exhausted")),
69+
}
7370
}
7471

7572
fn __next__(&mut self, py: Python) -> PyResult<PyRecordBatch> {
@@ -95,7 +92,7 @@ async fn next_stream(
9592
sync: bool,
9693
) -> PyResult<PyRecordBatch> {
9794
let mut stream = stream.lock().await;
98-
match pull_next_batch(&mut stream).await {
95+
match stream.next().await {
9996
Some(Ok(batch)) => Ok(batch.into()),
10097
Some(Err(e)) => Err(PyDataFusionError::from(e))?,
10198
None => {

src/utils.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@ use crate::{
2020
errors::{to_datafusion_err, PyDataFusionError, PyDataFusionResult},
2121
TokioRuntime,
2222
};
23+
use datafusion::{arrow::record_batch::RecordBatch, physical_plan::SendableRecordBatchStream};
2324
use datafusion::{
2425
common::ScalarValue, execution::context::SessionContext, logical_expr::Volatility,
2526
};
27+
use futures::StreamExt;
2628
use pyo3::prelude::*;
2729
use pyo3::{exceptions::PyValueError, types::PyCapsule};
2830
use std::{future::Future, sync::OnceLock, time::Duration};
@@ -84,6 +86,22 @@ where
8486
})
8587
}
8688

89+
/// Poll `SendableRecordBatchStream::next` while checking Python signals.
90+
///
91+
/// This utility mirrors [`wait_for_future`] for streams and converts any
92+
/// Python interruptions into a [`DataFusionError`].
93+
pub fn wait_for_stream_next(
94+
py: Python,
95+
stream: &mut SendableRecordBatchStream,
96+
) -> datafusion::common::Result<Option<RecordBatch>> {
97+
match wait_for_future(py, stream.next()) {
98+
Ok(Some(Ok(batch))) => Ok(Some(batch)),
99+
Ok(Some(Err(e))) => Err(e),
100+
Ok(None) => Ok(None),
101+
Err(err) => Err(to_datafusion_err(err)),
102+
}
103+
}
104+
87105
pub fn spawn_and_wait<F, T>(py: Python, fut: F) -> PyDataFusionResult<T>
88106
where
89107
F: Future<Output = datafusion::common::Result<T>> + Send + 'static,

0 commit comments

Comments (0)