
Commit 760969e

Add tests for KeyboardInterrupt handling in __arrow_c_stream__ and improve async stream signal handling
1 parent: 77a810e

2 files changed: 117 additions & 8 deletions


python/tests/test_dataframe.py

Lines changed: 104 additions & 0 deletions
@@ -2689,6 +2689,110 @@ def trigger_interrupt():
     interrupt_thread.join(timeout=1.0)


+def test_arrow_c_stream_interrupted():
+    """__arrow_c_stream__ responds to ``KeyboardInterrupt`` signals.
+
+    Similar to ``test_collect_interrupted`` this test issues a long running
+    query, but consumes the results via ``__arrow_c_stream__``. It then raises
+    ``KeyboardInterrupt`` in the main thread and verifies that the stream
+    iteration stops promptly with the appropriate exception.
+    """
+
+    ctx = SessionContext()
+
+    batches = []
+    for i in range(10):
+        batch = pa.RecordBatch.from_arrays(
+            [
+                pa.array(list(range(i * 1000, (i + 1) * 1000))),
+                pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]),
+            ],
+            names=["a", "b"],
+        )
+        batches.append(batch)
+
+    ctx.register_record_batches("t1", [batches])
+    ctx.register_record_batches("t2", [batches])
+
+    df = ctx.sql(
+        """
+        WITH t1_expanded AS (
+            SELECT
+                a,
+                b,
+                CAST(a AS DOUBLE) / 1.5 AS c,
+                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
+            FROM t1
+            CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
+        ),
+        t2_expanded AS (
+            SELECT
+                a,
+                b,
+                CAST(a AS DOUBLE) * 2.5 AS e,
+                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
+            FROM t2
+            CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
+        )
+        SELECT
+            t1.a, t1.b, t1.c, t1.d,
+            t2.a AS a2, t2.b AS b2, t2.e, t2.f
+        FROM t1_expanded t1
+        JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
+        WHERE t1.a > 100 AND t2.a > 100
+        """
+    )
+
+    reader = pa.RecordBatchReader._import_from_c(df.__arrow_c_stream__())
+
+    interrupted = False
+    interrupt_error = None
+    query_started = threading.Event()
+    max_wait_time = 5.0
+
+    def trigger_interrupt():
+        start_time = time.time()
+        while not query_started.is_set():
+            time.sleep(0.1)
+            if time.time() - start_time > max_wait_time:
+                msg = f"Query did not start within {max_wait_time} seconds"
+                raise RuntimeError(msg)
+
+        thread_id = threading.main_thread().ident
+        if thread_id is None:
+            msg = "Cannot get main thread ID"
+            raise RuntimeError(msg)
+
+        exception = ctypes.py_object(KeyboardInterrupt)
+        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
+            ctypes.c_long(thread_id), exception
+        )
+        if res != 1:
+            ctypes.pythonapi.PyThreadState_SetAsyncExc(
+                ctypes.c_long(thread_id), ctypes.py_object(0)
+            )
+            msg = "Failed to raise KeyboardInterrupt in main thread"
+            raise RuntimeError(msg)
+
+    interrupt_thread = threading.Thread(target=trigger_interrupt)
+    interrupt_thread.daemon = True
+    interrupt_thread.start()
+
+    try:
+        query_started.set()
+        # consume the reader which should block and be interrupted
+        reader.read_all()
+    except KeyboardInterrupt:
+        interrupted = True
+    except Exception as e:  # pragma: no cover - unexpected errors
+        interrupt_error = e
+
+    if not interrupted:
+        pytest.fail(f"Stream was not interrupted; got error: {interrupt_error}")
+
+    interrupt_thread.join(timeout=1.0)
+
+
 def test_show_select_where_no_rows(capsys) -> None:
     ctx = SessionContext()
     df = ctx.sql("SELECT 1 WHERE 1=0")
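The test's interrupt machinery relies on CPython's PyThreadState_SetAsyncExc, reached through ctypes: a helper thread schedules a KeyboardInterrupt in the main thread, and the interpreter raises it at that thread's next bytecode boundary. A minimal standalone sketch of the same trick (assumptions: CPython only; the helper name raise_keyboard_interrupt_in_main is illustrative and not part of the test suite):

import ctypes
import threading
import time


def raise_keyboard_interrupt_in_main():
    # Schedule a KeyboardInterrupt in the main thread; CPython delivers it
    # at that thread's next bytecode boundary.
    main_id = threading.main_thread().ident
    res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
        ctypes.c_long(main_id), ctypes.py_object(KeyboardInterrupt)
    )
    if res != 1:
        # 0 means the thread id was not found; >1 means several thread
        # states were modified, so undo the request by passing NULL.
        ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(main_id), None)
        raise RuntimeError("failed to schedule KeyboardInterrupt")


threading.Timer(0.2, raise_keyboard_interrupt_in_main).start()
try:
    while True:  # stands in for the blocking reader.read_all() call
        time.sleep(0.01)
except KeyboardInterrupt:
    print("main thread interrupted as expected")

Because the pending exception is only delivered while the target thread is executing Python bytecode, a thread parked inside a native call never sees it; that is the gap the Rust change below closes by polling the stream through wait_for_future, which periodically checks for pending Python signals, instead of blocking in tokio.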

src/dataframe.rs

Lines changed: 13 additions & 8 deletions
@@ -42,7 +42,7 @@ use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
 use pyo3::pybacked::PyBackedStr;
 use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
-use tokio::runtime::Handle;
+use tokio::task::JoinHandle;

 use crate::catalog::PyTable;
 use crate::errors::{py_datafusion_err, PyDataFusionError};

@@ -363,7 +363,6 @@ impl PyDataFrame {
 /// changes per batch.
 struct DataFrameStreamReader {
     stream: SendableRecordBatchStream,
-    runtime: Handle,
     schema: SchemaRef,
     projection: Option<SchemaRef>,
 }

@@ -372,8 +371,15 @@ impl Iterator for DataFrameStreamReader {
     type Item = Result<RecordBatch, ArrowError>;

     fn next(&mut self) -> Option<Self::Item> {
-        match self.runtime.block_on(self.stream.next()) {
-            Some(Ok(batch)) => {
+        // Use wait_for_future to poll the underlying async stream while
+        // respecting Python signal handling (e.g. ``KeyboardInterrupt``).
+        // This mirrors the behaviour of other synchronous wrappers and
+        // prevents blocking indefinitely when a Python interrupt is raised.
+        let fut = self.stream.next();
+        let result = Python::with_gil(|py| wait_for_future(py, fut));
+
+        match result {
+            Ok(Some(Ok(batch))) => {
                 let batch = if let Some(ref schema) = self.projection {
                     match record_batch_into_schema(batch, schema.as_ref()) {
                         Ok(b) => b,

@@ -384,8 +390,9 @@
                 };
                 Some(Ok(batch))
             }
-            Some(Err(e)) => Some(Err(ArrowError::ExternalError(Box::new(e)))),
-            None => None,
+            Ok(Some(Err(e))) => Some(Err(ArrowError::ExternalError(Box::new(e)))),
+            Ok(None) => None,
+            Err(e) => Some(Err(ArrowError::ExternalError(Box::new(e)))),
         }
     }
 }

@@ -921,7 +928,6 @@
         py: Python<'py>,
         requested_schema: Option<Bound<'py, PyCapsule>>,
     ) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
-        let rt = &get_tokio_runtime().0;
         let df = self.df.as_ref().clone();
         let stream = spawn_stream(py, async move { df.execute_stream().await })?;

@@ -942,7 +948,6 @@

         let reader = DataFrameStreamReader {
             stream,
-            runtime: rt.handle().clone(),
             schema: schema_ref,
             projection,
         };
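From the user's perspective, this makes stream consumption interruptible. A minimal usage sketch (assumptions: the datafusion Python package is installed, and the one-row query is a stand-in for a long-running one; the consumption pattern mirrors the new test):

import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()
df = ctx.sql("SELECT 1 AS a")

# Export the DataFrame through the Arrow C stream interface and consume it
# with pyarrow, as the new test does. A Ctrl+C during read_all() should now
# surface as KeyboardInterrupt instead of hanging inside tokio's block_on.
reader = pa.RecordBatchReader._import_from_c(df.__arrow_c_stream__())
try:
    table = reader.read_all()
    print(table.num_rows)
except KeyboardInterrupt:
    print("stream consumption interrupted")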
