Skip to content

Commit a9333c6

Browse files
committed
Fix non-deterministic iteration in SessionStateBuilder
1 parent 5ff80e4 commit a9333c6

1 file changed

Lines changed: 58 additions & 9 deletions

File tree

datafusion/core/src/execution/session_state.rs

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -978,6 +978,20 @@ impl SessionState {
978978
/// A builder to be used for building [`SessionState`]'s. Defaults will
979979
/// be used for all values unless explicitly provided.
980980
///
981+
/// Collects exactly one [`Arc`] per logical function from a function-registry map.
///
/// The session stores one hash map entry per alias plus one entry under the
/// canonical name, so a function that has aliases appears several times in
/// `map`. Only entries whose key equals the function's canonical name are
/// kept, which yields each function once.
///
/// The result is sorted by canonical name. [`HashMap`] iteration order is
/// arbitrary and may differ between runs, so without the sort the order of
/// the returned functions would be non-deterministic.
///
/// # Arguments
///
/// * `map` - registry map keyed by canonical name and by every alias.
/// * `canonical_name` - returns the canonical name of a registered function.
///
/// # Returns
///
/// The functions registered under their canonical name, in ascending order of
/// that name. Functions that are only reachable through an alias key are
/// dropped.
fn dedup_function_registry_by_canonical_name<T>(
    map: &HashMap<String, Arc<T>>,
    canonical_name: impl Fn(&T) -> &str,
) -> Vec<Arc<T>> {
    let mut functions: Vec<Arc<T>> = map
        .iter()
        .filter(|&(key, function)| key.as_str() == canonical_name(&**function))
        .map(|(_, function)| Arc::clone(function))
        .collect();

    // Map keys are unique and every kept key equals its canonical name, so no
    // two elements compare equal: an unstable sort is fully deterministic here.
    functions.sort_unstable_by(|a, b| canonical_name(&**a).cmp(canonical_name(&**b)));
    functions
}
994+
981995
/// See example on [`SessionState`]
982996
#[derive(Clone)]
983997
pub struct SessionStateBuilder {
@@ -1088,11 +1102,18 @@ impl SessionStateBuilder {
10881102
query_planner: Some(existing.query_planner),
10891103
catalog_list: Some(existing.catalog_list),
10901104
table_functions: Some(existing.table_functions),
1091-
scalar_functions: Some(existing.scalar_functions.into_values().collect_vec()),
1092-
aggregate_functions: Some(
1093-
existing.aggregate_functions.into_values().collect_vec(),
1094-
),
1095-
window_functions: Some(existing.window_functions.into_values().collect_vec()),
1105+
scalar_functions: Some(dedup_function_registry_by_canonical_name(
1106+
&existing.scalar_functions,
1107+
|u| u.name(),
1108+
)),
1109+
aggregate_functions: Some(dedup_function_registry_by_canonical_name(
1110+
&existing.aggregate_functions,
1111+
|u| u.name(),
1112+
)),
1113+
window_functions: Some(dedup_function_registry_by_canonical_name(
1114+
&existing.window_functions,
1115+
|u| u.name(),
1116+
)),
10961117
extension_types: Some(existing.extension_types),
10971118
serializer_registry: Some(existing.serializer_registry),
10981119
file_formats: Some(existing.file_formats.into_values().collect_vec()),
@@ -1862,8 +1883,6 @@ impl ContextProvider for SessionContextProvider<'_> {
18621883
name: &str,
18631884
args: Vec<Expr>,
18641885
) -> datafusion_common::Result<Arc<dyn TableSource>> {
1865-
use datafusion_catalog::TableFunctionArgs;
1866-
18671886
let tbl_func = self
18681887
.state
18691888
.table_functions
@@ -1886,8 +1905,7 @@ impl ContextProvider for SessionContextProvider<'_> {
18861905
.and_then(|e| simplifier.simplify(e))
18871906
})
18881907
.collect::<datafusion_common::Result<Vec<_>>>()?;
1889-
let provider = tbl_func
1890-
.create_table_provider_with_args(TableFunctionArgs::new(&args, self.state))?;
1908+
let provider = tbl_func.create_table_provider(&args)?;
18911909

18921910
Ok(provider_as_source(provider))
18931911
}
@@ -2340,6 +2358,37 @@ mod tests {
23402358
Ok(())
23412359
}
23422360

2361+
/// Round-trips a [`SessionState`] through
/// [`SessionStateBuilder::new_from_existing`] and checks that a scalar UDF
/// registered with an alias still resolves under both its canonical name and
/// its alias afterwards.
#[test]
fn new_from_existing_preserves_scalar_udf_aliases() -> Result<()> {
    use arrow::datatypes::DataType;
    use datafusion_common::ScalarValue;
    use datafusion_expr::registry::FunctionRegistry;
    use datafusion_expr::{ColumnarValue, Volatility, create_udf};

    const CANONICAL: &str = "postgres_to_char";
    const ALIAS: &str = "to_char";

    // A trivial UDF: only its name and alias matter for this test.
    let aliased_udf = create_udf(
        CANONICAL,
        vec![DataType::Utf8],
        DataType::Utf8,
        Volatility::Immutable,
        Arc::new(|_args| Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)))),
    )
    .with_aliases([ALIAS]);

    let mut original = SessionStateBuilder::new().build();
    original.register_udf(Arc::new(aliased_udf))?;

    // Sanity check: both lookups work before the round trip.
    for lookup in [CANONICAL, ALIAS] {
        assert_eq!(original.udf(lookup)?.name(), CANONICAL);
    }

    // Rebuilding from the existing state must keep the alias registered.
    let rebuilt = SessionStateBuilder::new_from_existing(original).build();
    for lookup in [ALIAS, CANONICAL] {
        assert_eq!(rebuilt.udf(lookup)?.name(), CANONICAL);
    }

    Ok(())
}
2391+
23432392
#[test]
23442393
fn test_session_state_with_optimizer_rules() {
23452394
#[derive(Default, Debug)]

0 commit comments

Comments
 (0)