diff --git a/examples/cli.rs b/examples/cli.rs
index 77e0566..e6e9562 100644
--- a/examples/cli.rs
+++ b/examples/cli.rs
@@ -1,12 +1,13 @@
 use anyhow::{Context, Result};
 use arrow::array::{ArrayRef, RecordBatch, StringArray};
 use arrow::datatypes::{DataType, Field, Schema};
+use datafusion::execution::FunctionRegistry;
 use datafusion::logical_expr::ScalarUDF;
 use datafusion::prelude::*;
 use datafusion_variant::{
-    CastToVariantUdf, IsVariantNullUdf, JsonToVariantUdf, VariantGetUdf, VariantListConstruct,
-    VariantListInsert, VariantObjectConstruct, VariantObjectInsert, VariantObjectKeys,
-    VariantPretty, VariantToJsonUdf,
+    CastToVariantUdf, IsVariantNullUdf, JsonToVariantUdf, VariantExprPlanner, VariantGetUdf,
+    VariantListConstruct, VariantListInsert, VariantObjectConstruct, VariantObjectInsert,
+    VariantObjectKeys, VariantPretty, VariantToJsonUdf,
 };
 use flate2::read::GzDecoder;
 use rustyline::error::ReadlineError;
@@ -97,7 +98,7 @@ async fn main() -> Result<()> {
     let ctx = {
         let setup_start = Instant::now();
 
-        let ctx = SessionContext::new();
+        let mut ctx = SessionContext::new();
         let schema = Schema::new(vec![Field::new("json_data", DataType::Utf8, false)]);
         let string_array: ArrayRef = Arc::new(StringArray::from(json_strings));
         let batch = RecordBatch::try_new(Arc::new(schema), vec![string_array])?;
@@ -119,6 +120,7 @@ async fn main() -> Result<()> {
         ctx.register_udf(ScalarUDF::new_from_impl(VariantListInsert::default()));
         ctx.register_udf(ScalarUDF::new_from_impl(VariantObjectInsert::default()));
         ctx.register_udf(ScalarUDF::new_from_impl(VariantObjectKeys::default()));
+        ctx.register_expr_planner(Arc::new(VariantExprPlanner))?;
 
         let setup_duration = setup_start.elapsed();
         println!(
diff --git a/src/expr_planner.rs b/src/expr_planner.rs
new file mode 100644
index 0000000..63e07c6
--- /dev/null
+++ b/src/expr_planner.rs
@@ -0,0 +1,45 @@
+use std::sync::Arc;
+
+use datafusion::{
+    common::DFSchema,
+    error::DataFusionError,
+    logical_expr::{
+        ScalarUDF,
+        expr::ScalarFunction,
+        planner::{ExprPlanner, PlannerResult, RawBinaryExpr},
+    },
+    prelude::Expr,
+    sql::sqlparser::ast::BinaryOperator,
+};
+
+use crate::VariantGetUdf;
+
+/// Custom [`ExprPlanner`] used to handle variant-specific syntax such as colon operator.
+///
+/// Currently implements:
+/// - Colon operator: short-hand syntax for `variant_get`.
+#[derive(Debug)]
+pub struct VariantExprPlanner;
+
+impl ExprPlanner for VariantExprPlanner {
+    /// Plans the custom `:` binary operator as a [`VariantGetUdf`] call.
+    ///
+    /// `left : right` becomes a scalar function call to [`VariantGetUdf`] with the
+    /// arguments `[left, right]`; every other operator is handed back unchanged as
+    /// [`PlannerResult::Original`]. The schema argument is not consulted.
+    fn plan_binary_op(
+        &self,
+        expr: RawBinaryExpr,
+        _schema: &DFSchema,
+    ) -> Result<PlannerResult<RawBinaryExpr>, DataFusionError> {
+        match &expr.op {
+            BinaryOperator::Custom(s) if s == ":" => Ok(PlannerResult::Planned(
+                Expr::ScalarFunction(ScalarFunction::new_udf(
+                    Arc::new(ScalarUDF::new_from_impl(VariantGetUdf::default())),
+                    vec![expr.left, expr.right],
+                )),
+            )),
+            _ => Ok(PlannerResult::Original(expr)),
+        }
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index f944f84..9f798e4 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -3,6 +3,7 @@
 mod shared;
 
 mod cast_to_variant;
+mod expr_planner;
 mod impl_variant_get;
 mod is_variant_null;
 mod json_to_variant;
@@ -19,6 +20,7 @@ mod variant_pretty;
 mod variant_to_json;
 
 pub use cast_to_variant::*;
+pub use expr_planner::*;
 pub use is_variant_null::*;
 pub use json_to_variant::*;
 pub use variant_get::*;
diff --git a/tests/sqllogictests.rs b/tests/sqllogictests.rs
index f808d8c..5d2dfc8 100644
--- a/tests/sqllogictests.rs
+++ b/tests/sqllogictests.rs
@@ -1,14 +1,17 @@
+use datafusion::execution::FunctionRegistry;
 use datafusion::{logical_expr::ScalarUDF, prelude::*};
 use datafusion_sqllogictest::{DataFusion, TestContext};
 use datafusion_variant::{
-    CastToVariantUdf, IsVariantNullUdf, JsonToVariantUdf, VariantGetBoolUdf, VariantGetFieldUdf,
-    VariantGetFloatUdf, VariantGetIntUdf, VariantGetJsonUdf, VariantGetStrUdf, VariantGetUdf,
-    VariantListConstruct, VariantListDelete, VariantListInsert, VariantObjectConstruct,
-    VariantObjectDelete, VariantObjectInsert, VariantObjectKeys, VariantPretty, VariantToJsonUdf,
+    CastToVariantUdf, IsVariantNullUdf, JsonToVariantUdf, VariantExprPlanner, VariantGetBoolUdf,
+    VariantGetFieldUdf, VariantGetFloatUdf, VariantGetIntUdf, VariantGetJsonUdf, VariantGetStrUdf,
+    VariantGetUdf, VariantListConstruct, VariantListDelete, VariantListInsert,
+    VariantObjectConstruct, VariantObjectDelete, VariantObjectInsert, VariantObjectKeys,
+    VariantPretty, VariantToJsonUdf,
 };
 use indicatif::ProgressBar;
 use sqllogictest::strict_column_validator;
 use std::path::PathBuf;
+use std::sync::Arc;
 
 #[tokio::test]
 async fn run_sqllogictests() -> Result<(), Box<dyn std::error::Error>> {
@@ -38,11 +41,12 @@ async fn run_sqllogictests() -> Result<(), Box<dyn std::error::Error>> {
             .unwrap_or(&test_file)
             .to_path_buf();
 
-        let ctx = if let Some(test_ctx) = TestContext::try_new_for_test_file(&relative_path).await {
-            test_ctx.session_ctx().clone()
-        } else {
-            SessionContext::new()
-        };
+        let mut ctx =
+            if let Some(test_ctx) = TestContext::try_new_for_test_file(&relative_path).await {
+                test_ctx.session_ctx().clone()
+            } else {
+                SessionContext::new()
+            };
 
         // register variant udfs
         ctx.register_udf(ScalarUDF::new_from_impl(VariantToJsonUdf::default()));
@@ -64,6 +68,7 @@ async fn run_sqllogictests() -> Result<(), Box<dyn std::error::Error>> {
         ctx.register_udf(ScalarUDF::new_from_impl(VariantObjectInsert::default()));
         ctx.register_udf(ScalarUDF::new_from_impl(VariantObjectDelete::default()));
         ctx.register_udf(ScalarUDF::new_from_impl(VariantObjectKeys::default()));
+        ctx.register_expr_planner(Arc::new(VariantExprPlanner))?;
 
         let pb = ProgressBar::new(24);
 
diff --git a/tests/test_files/colon_operator.slt b/tests/test_files/colon_operator.slt
new file mode 100644
index 0000000..903709d
--- /dev/null
+++ b/tests/test_files/colon_operator.slt
@@ -0,0 +1,75 @@
+statement ok
+CREATE TABLE json_data (id INT, json_str TEXT) AS VALUES
+  (1, '{"name": "Alice", "age": 30}'),
+  (2, '{"name": "Bob", "age": 25}'),
+  (3, '{"items": [1, 2, 3], "count": 3}'),
+  (4, '{"items": [{"name": "Sam", "age": 33}, "a", 3], "count": 3}'),
+  (5, 'null'),
+  (6, '"simple string"'),
+  (7, '["looooooooong string", "hehe", "a"]'),
+  (8, 'true');
+
+# field access
+query T
+select variant_pretty(json_to_variant(json_str):name) from json_data;
+----
+ShortString(ShortString("Alice"))
+ShortString(ShortString("Bob"))
+NULL
+NULL
+NULL
+NULL
+NULL
+NULL
+
+# field + array access
+query T
+select variant_pretty(json_to_variant(json_str):items[0]) from json_data;
+----
+NULL
+NULL
+Int8(1)
+{"age": Int8(33), "name": ShortString(ShortString("Sam"))}
+NULL
+NULL
+NULL
+NULL
+
+# array access with non-array
+query T
+select variant_pretty(json_to_variant(json_str):age[0]) from json_data;
+----
+NULL
+NULL
+NULL
+NULL
+NULL
+NULL
+NULL
+NULL
+
+# field -> array -> field access
+query T
+select variant_pretty(json_to_variant(json_str):items[0]["name"]) from json_data;
+----
+NULL
+NULL
+NULL
+ShortString(ShortString("Sam"))
+NULL
+NULL
+NULL
+NULL
+
+# field -> array -> field access but with single quotes
+query T
+select variant_pretty(json_to_variant(json_str):items[0]['name']) from json_data;
+----
+NULL
+NULL
+NULL
+ShortString(ShortString("Sam"))
+NULL
+NULL
+NULL
+NULL