Skip to content

Commit 7fbeeac

Browse files
committed
Refactor record batch streaming to use poll_next_batch for improved error handling
1 parent 58f4bd5 commit 7fbeeac

2 files changed

Lines changed: 15 additions & 9 deletions

File tree

src/dataframe.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ use crate::catalog::PyTable;
4848
use crate::errors::{py_datafusion_err, PyDataFusionError};
4949
use crate::expr::sort_expr::to_sort_expressions;
5050
use crate::physical_plan::PyExecutionPlan;
51-
use crate::record_batch::PyRecordBatchStream;
51+
use crate::record_batch::{poll_next_batch, PyRecordBatchStream};
5252
use crate::sql::logical::PyLogicalPlan;
5353
use crate::utils::{
5454
get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, spawn_stream, spawn_streams,
@@ -390,11 +390,11 @@ impl Iterator for DataFrameStreamReader {
390390
// respecting Python signal handling (e.g. ``KeyboardInterrupt``).
391391
// This mirrors the behaviour of other synchronous wrappers and
392392
// prevents blocking indefinitely when a Python interrupt is raised.
393-
let fut = self.stream.next();
393+
let fut = poll_next_batch(&mut self.stream);
394394
let result = Python::with_gil(|py| wait_for_future(py, fut));
395395

396396
match result {
397-
Ok(Some(Ok(batch))) => {
397+
Ok(Ok(Some(batch))) => {
398398
let batch = if let Some(ref schema) = self.projection {
399399
match record_batch_into_schema(batch, schema.as_ref()) {
400400
Ok(b) => b,
@@ -405,8 +405,8 @@ impl Iterator for DataFrameStreamReader {
405405
};
406406
Some(Ok(batch))
407407
}
408-
Ok(Some(Err(e))) => Some(Err(ArrowError::ExternalError(Box::new(e)))),
409-
Ok(None) => None,
408+
Ok(Ok(None)) => None,
409+
Ok(Err(e)) => Some(Err(ArrowError::ExternalError(Box::new(e)))),
410410
Err(e) => Some(Err(ArrowError::ExternalError(Box::new(e)))),
411411
}
412412
}

src/record_batch.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,20 @@ impl PyRecordBatchStream {
8484
}
8585
}
8686

87+
pub(crate) async fn poll_next_batch(
88+
stream: &mut SendableRecordBatchStream,
89+
) -> datafusion::error::Result<Option<RecordBatch>> {
90+
stream.next().await.transpose()
91+
}
92+
8793
async fn next_stream(
8894
stream: Arc<Mutex<SendableRecordBatchStream>>,
8995
sync: bool,
9096
) -> PyResult<PyRecordBatch> {
9197
let mut stream = stream.lock().await;
92-
match stream.next().await {
93-
Some(Ok(batch)) => Ok(batch.into()),
94-
Some(Err(e)) => Err(PyDataFusionError::from(e))?,
95-
None => {
98+
match poll_next_batch(&mut stream).await {
99+
Ok(Some(batch)) => Ok(batch.into()),
100+
Ok(None) => {
96101
// Depending on whether the iteration is sync or not, we raise either a
97102
// StopIteration or a StopAsyncIteration
98103
if sync {
@@ -101,5 +106,6 @@ async fn next_stream(
101106
Err(PyStopAsyncIteration::new_err("stream exhausted"))
102107
}
103108
}
109+
Err(e) => Err(PyDataFusionError::from(e))?,
104110
}
105111
}

0 commit comments

Comments (0)