Skip to content

Commit e70b16e

Browse files
committed
feat: enhance stream execution to maintain session context
1 parent 1043b8d commit e70b16e

4 files changed

Lines changed: 47 additions & 5 deletions

File tree

python/tests/test_dataframe.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1573,6 +1573,30 @@ async def test_execute_stream_partitioned_async(df):
15731573
assert not remaining_batches
15741574

15751575

1576+
def test_stream_keeps_context_alive():
    """Streams must stay usable after all Python refs to the context die."""
    session = SessionContext()
    source_batch = pa.record_batch([pa.array([1])], names=["a"])
    frame = session.create_dataframe([[source_batch]])

    batch_stream = frame.execute_stream()
    c_stream_capsule = frame.__arrow_c_stream__()

    # Drop every Python-side reference to the dataframe and its context,
    # then force a collection so only the streams keep them alive.
    del frame
    del session
    gc.collect()

    # The RecordBatchStream must still produce the original batch.
    collected = list(batch_stream)
    assert len(collected) == 1
    assert collected[0].equals(source_batch)

    # The exported Arrow C stream must remain consumable as well.
    imported = pa.RecordBatchReader._import_from_c_capsule(c_stream_capsule)
    result_table = pa.Table.from_batches(imported)
    expected_table = pa.Table.from_batches([source_batch])
    assert result_table.equals(expected_table)
1599+
15761600
def test_empty_to_arrow_table(df):
15771601
# Convert empty datafusion dataframe to pyarrow Table
15781602
pyarrow_table = df.limit(0).to_arrow_table()

src/context.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1129,10 +1129,11 @@ impl PySessionContext {
11291129
part: usize,
11301130
py: Python,
11311131
) -> PyDataFusionResult<PyRecordBatchStream> {
1132-
let ctx: TaskContext = TaskContext::from(&self.ctx.state());
1132+
let state = self.ctx.state();
1133+
let ctx: TaskContext = TaskContext::from(&state);
11331134
let plan = plan.plan.clone();
11341135
let stream = spawn_future(py, async move { plan.execute(part, Arc::new(ctx)) })?;
1135-
Ok(PyRecordBatchStream::new(stream))
1136+
Ok(PyRecordBatchStream::new(stream, state))
11361137
}
11371138
}
11381139

src/dataframe.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
3434
use datafusion::datasource::TableProvider;
3535
use datafusion::error::DataFusionError;
3636
use datafusion::execution::SendableRecordBatchStream;
37+
use datafusion::execution::session_state::SessionState;
3738
use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
3839
use datafusion::prelude::*;
3940
use datafusion_ffi::table_provider::FFI_TableProvider;
@@ -360,6 +361,8 @@ impl PyDataFrame {
360361
/// converted via `record_batch_into_schema` to apply schema changes per batch.
361362
struct PartitionedDataFrameStreamReader {
362363
streams: Vec<SendableRecordBatchStream>,
364+
// Hold a reference to the session state to keep the context alive
365+
_state: Arc<SessionState>,
363366
schema: SchemaRef,
364367
projection: Option<SchemaRef>,
365368
current: usize,
@@ -951,6 +954,7 @@ impl PyDataFrame {
951954
requested_schema: Option<Bound<'py, PyCapsule>>,
952955
) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
953956
let df = self.df.as_ref().clone();
957+
let state = df.session_state().clone();
954958
let streams = spawn_future(py, async move { df.execute_stream_partitioned().await })?;
955959

956960
let mut schema: Schema = self.df.schema().to_owned().into();
@@ -969,6 +973,7 @@ impl PyDataFrame {
969973
let schema_ref = Arc::new(schema.clone());
970974

971975
let reader = PartitionedDataFrameStreamReader {
976+
_state: state,
972977
streams,
973978
schema: schema_ref,
974979
projection,
@@ -985,14 +990,21 @@ impl PyDataFrame {
985990

986991
fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {
987992
let df = self.df.as_ref().clone();
993+
let state = df.session_state().clone();
988994
let stream = spawn_future(py, async move { df.execute_stream().await })?;
989-
Ok(PyRecordBatchStream::new(stream))
995+
Ok(PyRecordBatchStream::new(stream, state))
990996
}
991997

992998
fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
993999
let df = self.df.as_ref().clone();
1000+
let state = df.session_state().clone();
9941001
let streams = spawn_future(py, async move { df.execute_stream_partitioned().await })?;
995-
Ok(streams.into_iter().map(PyRecordBatchStream::new).collect())
1002+
Ok(
1003+
streams
1004+
.into_iter()
1005+
.map(|stream| PyRecordBatchStream::new(stream, state.clone()))
1006+
.collect(),
1007+
)
9961008
}
9971009

9981010
/// Convert to pandas dataframe with pyarrow

src/record_batch.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use crate::errors::PyDataFusionError;
2121
use crate::utils::wait_for_future;
2222
use datafusion::arrow::pyarrow::ToPyArrow;
2323
use datafusion::arrow::record_batch::RecordBatch;
24+
use datafusion::execution::session_state::SessionState;
2425
use datafusion::physical_plan::SendableRecordBatchStream;
2526
use futures::StreamExt;
2627
use pyo3::exceptions::{PyStopAsyncIteration, PyStopIteration};
@@ -66,12 +67,16 @@ pub(crate) fn record_batches_to_pyarrow(
6667
#[pyclass(name = "RecordBatchStream", module = "datafusion", subclass)]
pub struct PyRecordBatchStream {
    // Shared, lock-guarded handle to the underlying DataFusion batch stream.
    stream: Arc<Mutex<SendableRecordBatchStream>>,
    // Hold on to the session state to ensure the underlying context
    // remains alive for the duration of the stream; never read, only
    // kept for its Drop-ordering effect.
    _state: Arc<SessionState>,
}
7074

7175
impl PyRecordBatchStream {
72-
pub fn new(stream: SendableRecordBatchStream) -> Self {
76+
pub fn new(stream: SendableRecordBatchStream, state: Arc<SessionState>) -> Self {
7377
Self {
7478
stream: Arc::new(Mutex::new(stream)),
79+
_state: state,
7580
}
7681
}
7782
}

0 commit comments

Comments
 (0)