|
| 1 | +#include "duckdb_python/arrow/polars_filter_pushdown.hpp" |
| 2 | + |
| 3 | +#include "duckdb/planner/filter/in_filter.hpp" |
| 4 | +#include "duckdb/planner/filter/optional_filter.hpp" |
| 5 | +#include "duckdb/planner/filter/conjunction_filter.hpp" |
| 6 | +#include "duckdb/planner/filter/constant_filter.hpp" |
| 7 | +#include "duckdb/planner/filter/struct_filter.hpp" |
| 8 | +#include "duckdb/planner/table_filter.hpp" |
| 9 | + |
| 10 | +#include "duckdb_python/pyconnection/pyconnection.hpp" |
| 11 | +#include "duckdb_python/python_objects.hpp" |
| 12 | + |
| 13 | +namespace duckdb { |
| 14 | + |
| 15 | +static py::object TransformFilterRecursive(TableFilter &filter, py::object col_expr, |
| 16 | + const ClientProperties &client_properties) { |
| 17 | + auto &import_cache = *DuckDBPyConnection::ImportCache(); |
| 18 | + |
| 19 | + switch (filter.filter_type) { |
| 20 | + case TableFilterType::CONSTANT_COMPARISON: { |
| 21 | + auto &constant_filter = filter.Cast<ConstantFilter>(); |
| 22 | + auto &constant = constant_filter.constant; |
| 23 | + auto &constant_type = constant.type(); |
| 24 | + |
| 25 | + // Check for NaN |
| 26 | + bool is_nan = false; |
| 27 | + if (constant_type.id() == LogicalTypeId::FLOAT) { |
| 28 | + is_nan = Value::IsNan(constant.GetValue<float>()); |
| 29 | + } else if (constant_type.id() == LogicalTypeId::DOUBLE) { |
| 30 | + is_nan = Value::IsNan(constant.GetValue<double>()); |
| 31 | + } |
| 32 | + |
| 33 | + if (is_nan) { |
| 34 | + switch (constant_filter.comparison_type) { |
| 35 | + case ExpressionType::COMPARE_EQUAL: |
| 36 | + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: |
| 37 | + return col_expr.attr("is_nan")(); |
| 38 | + case ExpressionType::COMPARE_LESSTHAN: |
| 39 | + case ExpressionType::COMPARE_NOTEQUAL: |
| 40 | + return col_expr.attr("is_nan")().attr("__invert__")(); |
| 41 | + case ExpressionType::COMPARE_GREATERTHAN: |
| 42 | + return import_cache.polars.lit()(false); |
| 43 | + case ExpressionType::COMPARE_LESSTHANOREQUALTO: |
| 44 | + return import_cache.polars.lit()(true); |
| 45 | + default: |
| 46 | + return py::none(); |
| 47 | + } |
| 48 | + } |
| 49 | + |
| 50 | + // Convert DuckDB Value to Python object |
| 51 | + auto py_value = PythonObject::FromValue(constant, constant_type, client_properties); |
| 52 | + |
| 53 | + switch (constant_filter.comparison_type) { |
| 54 | + case ExpressionType::COMPARE_EQUAL: |
| 55 | + return col_expr.attr("__eq__")(py_value); |
| 56 | + case ExpressionType::COMPARE_LESSTHAN: |
| 57 | + return col_expr.attr("__lt__")(py_value); |
| 58 | + case ExpressionType::COMPARE_GREATERTHAN: |
| 59 | + return col_expr.attr("__gt__")(py_value); |
| 60 | + case ExpressionType::COMPARE_LESSTHANOREQUALTO: |
| 61 | + return col_expr.attr("__le__")(py_value); |
| 62 | + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: |
| 63 | + return col_expr.attr("__ge__")(py_value); |
| 64 | + case ExpressionType::COMPARE_NOTEQUAL: |
| 65 | + return col_expr.attr("__ne__")(py_value); |
| 66 | + default: |
| 67 | + return py::none(); |
| 68 | + } |
| 69 | + } |
| 70 | + case TableFilterType::IS_NULL: { |
| 71 | + return col_expr.attr("is_null")(); |
| 72 | + } |
| 73 | + case TableFilterType::IS_NOT_NULL: { |
| 74 | + return col_expr.attr("is_not_null")(); |
| 75 | + } |
| 76 | + case TableFilterType::CONJUNCTION_AND: { |
| 77 | + auto &and_filter = filter.Cast<ConjunctionAndFilter>(); |
| 78 | + py::object expression = py::none(); |
| 79 | + for (idx_t i = 0; i < and_filter.child_filters.size(); i++) { |
| 80 | + auto child_expression = TransformFilterRecursive(*and_filter.child_filters[i], col_expr, client_properties); |
| 81 | + if (child_expression.is(py::none())) { |
| 82 | + continue; |
| 83 | + } |
| 84 | + if (expression.is(py::none())) { |
| 85 | + expression = std::move(child_expression); |
| 86 | + } else { |
| 87 | + expression = expression.attr("__and__")(child_expression); |
| 88 | + } |
| 89 | + } |
| 90 | + return expression; |
| 91 | + } |
| 92 | + case TableFilterType::CONJUNCTION_OR: { |
| 93 | + auto &or_filter = filter.Cast<ConjunctionOrFilter>(); |
| 94 | + py::object expression = py::none(); |
| 95 | + for (idx_t i = 0; i < or_filter.child_filters.size(); i++) { |
| 96 | + auto child_expression = TransformFilterRecursive(*or_filter.child_filters[i], col_expr, client_properties); |
| 97 | + if (child_expression.is(py::none())) { |
| 98 | + // Can't skip children in OR |
| 99 | + return py::none(); |
| 100 | + } |
| 101 | + if (expression.is(py::none())) { |
| 102 | + expression = std::move(child_expression); |
| 103 | + } else { |
| 104 | + expression = expression.attr("__or__")(child_expression); |
| 105 | + } |
| 106 | + } |
| 107 | + return expression; |
| 108 | + } |
| 109 | + case TableFilterType::STRUCT_EXTRACT: { |
| 110 | + auto &struct_filter = filter.Cast<StructFilter>(); |
| 111 | + auto child_col = col_expr.attr("struct").attr("field")(struct_filter.child_name); |
| 112 | + return TransformFilterRecursive(*struct_filter.child_filter, child_col, client_properties); |
| 113 | + } |
| 114 | + case TableFilterType::IN_FILTER: { |
| 115 | + auto &in_filter = filter.Cast<InFilter>(); |
| 116 | + py::list py_values; |
| 117 | + for (const auto &value : in_filter.values) { |
| 118 | + py_values.append(PythonObject::FromValue(value, value.type(), client_properties)); |
| 119 | + } |
| 120 | + return col_expr.attr("is_in")(py_values); |
| 121 | + } |
| 122 | + case TableFilterType::OPTIONAL_FILTER: { |
| 123 | + auto &optional_filter = filter.Cast<OptionalFilter>(); |
| 124 | + if (!optional_filter.child_filter) { |
| 125 | + return py::none(); |
| 126 | + } |
| 127 | + return TransformFilterRecursive(*optional_filter.child_filter, col_expr, client_properties); |
| 128 | + } |
| 129 | + default: |
| 130 | + // We skip DYNAMIC_FILTER, EXPRESSION_FILTER, BLOOM_FILTER |
| 131 | + return py::none(); |
| 132 | + } |
| 133 | +} |
| 134 | + |
| 135 | +py::object PolarsFilterPushdown::TransformFilter(const TableFilterSet &filter_collection, |
| 136 | + unordered_map<idx_t, string> &columns, |
| 137 | + const unordered_map<idx_t, idx_t> &filter_to_col, |
| 138 | + const ClientProperties &client_properties) { |
| 139 | + auto &import_cache = *DuckDBPyConnection::ImportCache(); |
| 140 | + auto &filters_map = filter_collection.filters; |
| 141 | + |
| 142 | + py::object expression = py::none(); |
| 143 | + for (auto &it : filters_map) { |
| 144 | + auto column_idx = it.first; |
| 145 | + auto &column_name = columns[column_idx]; |
| 146 | + auto col_expr = import_cache.polars.col()(column_name); |
| 147 | + |
| 148 | + auto child_expression = TransformFilterRecursive(*it.second, col_expr, client_properties); |
| 149 | + if (child_expression.is(py::none())) { |
| 150 | + continue; |
| 151 | + } |
| 152 | + if (expression.is(py::none())) { |
| 153 | + expression = std::move(child_expression); |
| 154 | + } else { |
| 155 | + expression = expression.attr("__and__")(child_expression); |
| 156 | + } |
| 157 | + } |
| 158 | + return expression; |
| 159 | +} |
| 160 | + |
| 161 | +} // namespace duckdb |
0 commit comments