Skip to content

Commit 77a810e

Browse files
committed
Refactor async stream execution to use spawn_streams utility for improved signal handling
1 parent 9147e4d commit 77a810e

2 files changed

Lines changed: 26 additions & 14 deletions

File tree

src/dataframe.rs

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -42,17 +42,17 @@ use pyo3::exceptions::PyValueError;
4242
use pyo3::prelude::*;
4343
use pyo3::pybacked::PyBackedStr;
4444
use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
45-
use tokio::{runtime::Handle, task::JoinHandle};
45+
use tokio::runtime::Handle;
4646

4747
use crate::catalog::PyTable;
48-
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError};
48+
use crate::errors::{py_datafusion_err, PyDataFusionError};
4949
use crate::expr::sort_expr::to_sort_expressions;
5050
use crate::physical_plan::PyExecutionPlan;
5151
use crate::record_batch::PyRecordBatchStream;
5252
use crate::sql::logical::PyLogicalPlan;
5353
use crate::utils::{
54-
get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, spawn_stream, validate_pycapsule,
55-
wait_for_future,
54+
get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, spawn_stream, spawn_streams,
55+
validate_pycapsule, wait_for_future,
5656
};
5757
use crate::{
5858
errors::PyDataFusionResult,
@@ -960,16 +960,9 @@ impl PyDataFrame {
960960
}
961961

962962
fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
963-
// create a Tokio runtime to run the async code
964-
let rt = &get_tokio_runtime().0;
965963
let df = self.df.as_ref().clone();
966-
let fut: JoinHandle<datafusion::common::Result<Vec<SendableRecordBatchStream>>> =
967-
rt.spawn(async move { df.execute_stream_partitioned().await });
968-
let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })?
969-
.map_err(py_datafusion_err)?
970-
.map_err(py_datafusion_err)?;
971-
972-
Ok(stream.into_iter().map(PyRecordBatchStream::new).collect())
964+
let streams = spawn_streams(py, async move { df.execute_stream_partitioned().await })?;
965+
Ok(streams.into_iter().map(PyRecordBatchStream::new).collect())
973966
}
974967

975968
/// Convert to pandas dataframe with pyarrow

src/utils.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,26 @@ where
9393
{
9494
let rt = &get_tokio_runtime().0;
9595
let handle: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> = rt.spawn(fut);
96-
wait_for_future(py, async { handle.await.map_err(to_datafusion_err) })???
96+
Ok(wait_for_future(py, async {
97+
handle.await.map_err(to_datafusion_err)
98+
})???)
99+
}
100+
101+
/// Spawn a partitioned [`SendableRecordBatchStream`] on the Tokio runtime and wait for completion
102+
/// while respecting Python signal handling.
103+
pub(crate) fn spawn_streams<F>(
104+
py: Python,
105+
fut: F,
106+
) -> PyDataFusionResult<Vec<SendableRecordBatchStream>>
107+
where
108+
F: Future<Output = datafusion::common::Result<Vec<SendableRecordBatchStream>>> + Send + 'static,
109+
{
110+
let rt = &get_tokio_runtime().0;
111+
let handle: JoinHandle<datafusion::common::Result<Vec<SendableRecordBatchStream>>> =
112+
rt.spawn(fut);
113+
Ok(wait_for_future(py, async {
114+
handle.await.map_err(to_datafusion_err)
115+
})???)
97116
}
98117

99118
pub(crate) fn parse_volatility(value: &str) -> PyDataFusionResult<Volatility> {

0 commit comments

Comments
 (0)