Skip to content

Commit 903f63d

Browse files
committed
refactor: replace SessionState with SessionContext in PyDataFrame and related structures
1 parent: debfb18 · commit: 903f63d

3 files changed

Lines changed: 62 additions & 63 deletions

File tree

src/context.rs

Lines changed: 18 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -434,7 +434,7 @@ impl PySessionContext {
434434
pub fn sql(&mut self, query: &str, py: Python) -> PyDataFusionResult<PyDataFrame> {
435435
let result = self.ctx.sql(query);
436436
let df = wait_for_future(py, result)??;
437-
Ok(PyDataFrame::new(df, self.ctx.state().into()))
437+
Ok(PyDataFrame::new(df, Arc::new(self.ctx.clone())))
438438
}
439439

440440
#[pyo3(signature = (query, options=None))]
@@ -451,7 +451,7 @@ impl PySessionContext {
451451
};
452452
let result = self.ctx.sql_with_options(query, options);
453453
let df = wait_for_future(py, result)??;
454-
Ok(PyDataFrame::new(df, self.ctx.state().into()))
454+
Ok(PyDataFrame::new(df, Arc::new(self.ctx.clone())))
455455
}
456456

457457
#[pyo3(signature = (partitions, name=None, schema=None))]
@@ -486,16 +486,14 @@ impl PySessionContext {
486486

487487
let table = wait_for_future(py, self._table(&table_name))??;
488488

489-
let df = PyDataFrame::new(table, self.ctx.state().into());
489+
let df = PyDataFrame::new(table, Arc::new(self.ctx.clone()));
490490
Ok(df)
491491
}
492492

493493
/// Create a DataFrame from an existing logical plan
494494
pub fn create_dataframe_from_logical_plan(&mut self, plan: PyLogicalPlan) -> PyDataFrame {
495-
PyDataFrame::new(
496-
DataFrame::new(self.ctx.state(), plan.plan.as_ref().clone()),
497-
self.ctx.state().into(),
498-
)
495+
let ctx = Arc::new(self.ctx.clone());
496+
PyDataFrame::new(DataFrame::new(ctx.state(), plan.plan.as_ref().clone()), ctx)
499497
}
500498

501499
/// Construct datafusion dataframe from Python list
@@ -916,7 +914,7 @@ impl PySessionContext {
916914
let res = wait_for_future(py, self.ctx.table(name))
917915
.map_err(|e| PyKeyError::new_err(e.to_string()))?;
918916
match res {
919-
Ok(df) => Ok(PyDataFrame::new(df, self.ctx.state().into())),
917+
Ok(df) => Ok(PyDataFrame::new(df, Arc::new(self.ctx.clone()))),
920918
Err(e) => {
921919
if let datafusion::error::DataFusionError::Plan(msg) = &e {
922920
if msg.contains("No table named") {
@@ -935,7 +933,7 @@ impl PySessionContext {
935933
pub fn empty_table(&self) -> PyDataFusionResult<PyDataFrame> {
936934
Ok(PyDataFrame::new(
937935
self.ctx.read_empty()?,
938-
self.ctx.state().into(),
936+
Arc::new(self.ctx.clone()),
939937
))
940938
}
941939

@@ -976,7 +974,7 @@ impl PySessionContext {
976974
let result = self.ctx.read_json(path, options);
977975
wait_for_future(py, result)??
978976
};
979-
Ok(PyDataFrame::new(df, self.ctx.state().into()))
977+
Ok(PyDataFrame::new(df, Arc::new(self.ctx.clone())))
980978
}
981979

982980
#[allow(clippy::too_many_arguments)]
@@ -1026,12 +1024,12 @@ impl PySessionContext {
10261024
let paths = path.extract::<Vec<String>>()?;
10271025
let paths = paths.iter().map(|p| p as &str).collect::<Vec<&str>>();
10281026
let result = self.ctx.read_csv(paths, options);
1029-
let df = PyDataFrame::new(wait_for_future(py, result)??, self.ctx.state().into());
1027+
let df = PyDataFrame::new(wait_for_future(py, result)??, Arc::new(self.ctx.clone()));
10301028
Ok(df)
10311029
} else {
10321030
let path = path.extract::<String>()?;
10331031
let result = self.ctx.read_csv(path, options);
1034-
let df = PyDataFrame::new(wait_for_future(py, result)??, self.ctx.state().into());
1032+
let df = PyDataFrame::new(wait_for_future(py, result)??, Arc::new(self.ctx.clone()));
10351033
Ok(df)
10361034
}
10371035
}
@@ -1074,7 +1072,7 @@ impl PySessionContext {
10741072
.collect();
10751073

10761074
let result = self.ctx.read_parquet(path, options);
1077-
let df = PyDataFrame::new(wait_for_future(py, result)??, self.ctx.state().into());
1075+
let df = PyDataFrame::new(wait_for_future(py, result)??, Arc::new(self.ctx.clone()));
10781076
Ok(df)
10791077
}
10801078

@@ -1103,12 +1101,12 @@ impl PySessionContext {
11031101
let read_future = self.ctx.read_avro(path, options);
11041102
wait_for_future(py, read_future)??
11051103
};
1106-
Ok(PyDataFrame::new(df, self.ctx.state().into()))
1104+
Ok(PyDataFrame::new(df, Arc::new(self.ctx.clone())))
11071105
}
11081106

11091107
pub fn read_table(&self, table: &PyTable) -> PyDataFusionResult<PyDataFrame> {
11101108
let df = self.ctx.read_table(table.table())?;
1111-
Ok(PyDataFrame::new(df, self.ctx.state().into()))
1109+
Ok(PyDataFrame::new(df, Arc::new(self.ctx.clone())))
11121110
}
11131111

11141112
fn __repr__(&self) -> PyResult<String> {
@@ -1135,11 +1133,12 @@ impl PySessionContext {
11351133
part: usize,
11361134
py: Python,
11371135
) -> PyDataFusionResult<PyRecordBatchStream> {
1138-
let state = self.ctx.state();
1139-
let ctx: TaskContext = TaskContext::from(&state);
1136+
let ctx = Arc::new(self.ctx.clone());
1137+
let state = ctx.state();
1138+
let task_ctx: TaskContext = TaskContext::from(&state);
11401139
let plan = plan.plan.clone();
1141-
let stream = spawn_future(py, async move { plan.execute(part, Arc::new(ctx)) })?;
1142-
Ok(PyRecordBatchStream::new(stream, state.into()))
1140+
let stream = spawn_future(py, async move { plan.execute(part, Arc::new(task_ctx)) })?;
1141+
Ok(PyRecordBatchStream::new(stream, ctx))
11431142
}
11441143
}
11451144

0 commit comments

Comments (0)