Skip to content

Commit c6af791

Browse files
committed
Refactor spawn_future and spawn_stream functions for improved flexibility and code reuse
1 parent 9237400 commit c6af791

1 file changed

Lines changed: 16 additions & 11 deletions

File tree

src/utils.rs

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -85,17 +85,27 @@ where
8585
})
8686
}
8787

88-
/// Spawn a [`SendableRecordBatchStream`] on the Tokio runtime and wait for completion
88+
/// Spawn a [`Future`] on the Tokio runtime and wait for completion
8989
/// while respecting Python signal handling.
90-
pub(crate) fn spawn_stream<F>(py: Python, fut: F) -> PyDataFusionResult<SendableRecordBatchStream>
90+
pub(crate) fn spawn_future<F, T>(py: Python, fut: F) -> PyDataFusionResult<T>
9191
where
92-
F: Future<Output = datafusion::common::Result<SendableRecordBatchStream>> + Send + 'static,
92+
F: Future<Output = datafusion::common::Result<T>> + Send + 'static,
93+
T: Send + 'static,
9394
{
9495
let rt = &get_tokio_runtime().0;
95-
let handle: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> = rt.spawn(fut);
96+
let handle: JoinHandle<datafusion::common::Result<T>> = rt.spawn(fut);
9697
Ok(wait_for_future(py, async {
9798
handle.await.map_err(to_datafusion_err)
98-
})???)
99+
})??)
100+
}
101+
102+
/// Spawn a [`SendableRecordBatchStream`] on the Tokio runtime and wait for completion
103+
/// while respecting Python signal handling.
104+
pub(crate) fn spawn_stream<F>(py: Python, fut: F) -> PyDataFusionResult<SendableRecordBatchStream>
105+
where
106+
F: Future<Output = datafusion::common::Result<SendableRecordBatchStream>> + Send + 'static,
107+
{
108+
spawn_future(py, fut)
99109
}
100110

101111
/// Spawn a partitioned [`SendableRecordBatchStream`] on the Tokio runtime and wait for completion
@@ -107,12 +117,7 @@ pub(crate) fn spawn_streams<F>(
107117
where
108118
F: Future<Output = datafusion::common::Result<Vec<SendableRecordBatchStream>>> + Send + 'static,
109119
{
110-
let rt = &get_tokio_runtime().0;
111-
let handle: JoinHandle<datafusion::common::Result<Vec<SendableRecordBatchStream>>> =
112-
rt.spawn(fut);
113-
Ok(wait_for_future(py, async {
114-
handle.await.map_err(to_datafusion_err)
115-
})???)
120+
spawn_future(py, fut)
116121
}
117122

118123
pub(crate) fn parse_volatility(value: &str) -> PyDataFusionResult<Volatility> {

0 commit comments

Comments
 (0)