Skip to content

Commit 2cb7abe

Browse files
Escape identifiers in relation aggregations (#272)
This validates expressions that are passed into functions of the relational api.
2 parents 2fbcd17 + 75fdbd3 commit 2cb7abe

2 files changed

Lines changed: 172 additions & 3 deletions

File tree

src/duckdb_py/pyrelation.cpp

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -395,10 +395,36 @@ string DuckDBPyRelation::GenerateExpressionList(const string &function_name, vec
395395
function_name + "(" + function_parameter + ((ignore_nulls) ? " ignore nulls) " : ") ") + window_spec;
396396
}
397397
for (idx_t i = 0; i < input.size(); i++) {
398+
// We parse the input as an expression to validate it.
399+
auto trimmed_input = input[i];
400+
StringUtil::Trim(trimmed_input);
401+
402+
unique_ptr<ParsedExpression> expression;
403+
try {
404+
auto expressions = Parser::ParseExpressionList(trimmed_input);
405+
if (expressions.size() == 1) {
406+
expression = std::move(expressions[0]);
407+
}
408+
} catch (const ParserException &) {
409+
// First attempt at parsing failed, the input might be a column name that needs quoting.
410+
auto quoted_input = KeywordHelper::WriteQuoted(trimmed_input, '"');
411+
auto expressions = Parser::ParseExpressionList(quoted_input);
412+
if (expressions.size() == 1 && expressions[0]->GetExpressionClass() == ExpressionClass::COLUMN_REF) {
413+
expression = std::move(expressions[0]);
414+
}
415+
}
416+
417+
if (!expression) {
418+
throw ParserException("Invalid column expression: %s", trimmed_input);
419+
}
420+
421+
// ToString() handles escaping for all expression types
422+
auto escaped_input = expression->ToString();
423+
398424
if (function_parameter.empty()) {
399-
expr += function_name + "(" + input[i] + ((ignore_nulls) ? " ignore nulls) " : ") ") + window_spec;
425+
expr += function_name + "(" + escaped_input + ((ignore_nulls) ? " ignore nulls) " : ") ") + window_spec;
400426
} else {
401-
expr += function_name + "(" + input[i] + "," + function_parameter +
427+
expr += function_name + "(" + escaped_input + "," + function_parameter +
402428
((ignore_nulls) ? " ignore nulls) " : ") ") + window_spec;
403429
}
404430

@@ -587,7 +613,7 @@ unique_ptr<DuckDBPyRelation> DuckDBPyRelation::Product(const std::string &column
587613
unique_ptr<DuckDBPyRelation> DuckDBPyRelation::StringAgg(const std::string &column, const std::string &sep,
588614
const std::string &groups, const std::string &window_spec,
589615
const std::string &projected_columns) {
590-
auto string_agg_params = "\'" + sep + "\'";
616+
auto string_agg_params = KeywordHelper::WriteOptionallyQuoted(sep, '\'');
591617
return ApplyAggOrWin("string_agg", column, string_agg_params, groups, window_spec, projected_columns);
592618
}
593619

tests/fast/relational_api/test_rapi_aggregations.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,3 +416,146 @@ def test_var_samp(self, table, f):
416416

417417
def test_describe(self, table):
418418
assert table.describe().fetchall() is not None
419+
420+
421+
class TestRAPIAggregationsColumnEscaping:
422+
"""Test that aggregate functions properly escape column names that need quoting."""
423+
424+
def test_reserved_keyword_column_name(self, duckdb_cursor):
425+
# Column name "select" is a reserved SQL keyword
426+
rel = duckdb_cursor.sql('select 1 as "select", 2 as "order"')
427+
result = rel.sum("select").fetchall()
428+
assert result == [(1,)]
429+
430+
result = rel.avg("order").fetchall()
431+
assert result == [(2.0,)]
432+
433+
def test_column_name_with_space(self, duckdb_cursor):
434+
rel = duckdb_cursor.sql('select 10 as "my column"')
435+
result = rel.sum("my column").fetchall()
436+
assert result == [(10,)]
437+
438+
def test_column_name_with_quotes(self, duckdb_cursor):
439+
# Column name containing a double quote
440+
rel = duckdb_cursor.sql('select 5 as "col""name"')
441+
result = rel.sum('col"name').fetchall()
442+
assert result == [(5,)]
443+
444+
def test_qualified_column_name(self, duckdb_cursor):
445+
# Qualified column name like table.column
446+
rel = duckdb_cursor.sql("select 42 as value")
447+
# When using qualified names, they should be properly escaped
448+
result = rel.sum("value").fetchall()
449+
assert result == [(42,)]
450+
451+
452+
class TestRAPIAggregationsExpressionPassthrough:
453+
"""Test that aggregate functions correctly pass through SQL expressions without escaping."""
454+
455+
def test_cast_expression(self, duckdb_cursor):
456+
# Cast expressions should pass through without being quoted
457+
rel = duckdb_cursor.sql("select 1 as v, 0 as f")
458+
result = rel.bool_and("v::BOOL").fetchall()
459+
assert result == [(True,)]
460+
461+
result = rel.bool_or("f::BOOL").fetchall()
462+
assert result == [(False,)]
463+
464+
def test_star_expression(self, duckdb_cursor):
465+
# Star (*) should pass through for count
466+
rel = duckdb_cursor.sql("select 1 as a union all select 2")
467+
result = rel.count("*").fetchall()
468+
assert result == [(2,)]
469+
470+
def test_arithmetic_expression(self, duckdb_cursor):
471+
# Arithmetic expressions should pass through
472+
rel = duckdb_cursor.sql("select 10 as a, 5 as b")
473+
result = rel.sum("a + b").fetchall()
474+
assert result == [(15,)]
475+
476+
def test_function_expression(self, duckdb_cursor):
477+
# Function calls should pass through
478+
rel = duckdb_cursor.sql("select -5 as v")
479+
result = rel.sum("abs(v)").fetchall()
480+
assert result == [(5,)]
481+
482+
def test_case_expression(self, duckdb_cursor):
483+
# CASE expressions should pass through
484+
rel = duckdb_cursor.sql("select 1 as v union all select 2 union all select 3")
485+
result = rel.sum("case when v > 1 then v else 0 end").fetchall()
486+
assert result == [(5,)]
487+
488+
489+
class TestRAPIAggregationsWithInvalidInput:
490+
"""Test that only expression can be used."""
491+
492+
def test_injection_with_semicolon_is_neutralized(self, duckdb_cursor):
493+
# Semicolon injection fails to parse as expression, gets quoted as identifier
494+
rel = duckdb_cursor.sql("select 1 as v")
495+
with pytest.raises(duckdb.BinderException, match="not found in FROM clause"):
496+
rel.sum("v; drop table agg; --").fetchall()
497+
498+
def test_injection_with_union_is_neutralized(self, duckdb_cursor):
499+
# UNION fails to parse as single expression, gets quoted
500+
rel = duckdb_cursor.sql("select 1 as v")
501+
with pytest.raises(duckdb.BinderException, match="not found in FROM clause"):
502+
rel.sum("v union select * from agg").fetchall()
503+
504+
def test_subquery_is_contained(self, duckdb_cursor):
505+
# Subqueries are valid expressions - they're contained within the aggregate
506+
# and cannot break out of the expression context
507+
rel = duckdb_cursor.sql("select 1 as v")
508+
# This executes sum((select 1)) = sum(1) = 1 - contained, not an injection
509+
result = rel.sum("(select 1)").fetchall()
510+
assert result == [(1,)]
511+
512+
def test_injection_closing_paren_is_neutralized(self, duckdb_cursor):
513+
# Adding a closing paren fails to parse, gets quoted
514+
rel = duckdb_cursor.sql("select 1 as v")
515+
with pytest.raises(duckdb.BinderException, match="not found in FROM clause"):
516+
rel.sum("v) from agg; drop table agg; --").fetchall()
517+
518+
def test_comment_is_harmless(self, duckdb_cursor):
519+
# SQL comments are stripped during parsing, so "v -- comment" parses as just "v"
520+
rel = duckdb_cursor.sql("select 1 as v")
521+
result = rel.sum("v -- this is ignored").fetchall()
522+
assert result == [(1,)]
523+
524+
def test_empty_expression_rejected(self, duckdb_cursor):
525+
# Empty or whitespace-only expressions should be rejected
526+
rel = duckdb_cursor.sql("select 1 as v")
527+
with pytest.raises(duckdb.ParserException):
528+
rel.sum("").fetchall()
529+
530+
def test_whitespace_only_expression_rejected(self, duckdb_cursor):
531+
# Whitespace-only expressions should be rejected
532+
rel = duckdb_cursor.sql("select 1 as v")
533+
with pytest.raises(duckdb.ParserException):
534+
rel.sum(" ").fetchall()
535+
536+
537+
class TestRAPIStringAggSeparatorEscaping:
538+
"""Test that string_agg separator is properly escaped as a string literal."""
539+
540+
def test_simple_separator(self, duckdb_cursor):
541+
rel = duckdb_cursor.sql("select 'a' as s union all select 'b' union all select 'c'")
542+
result = rel.string_agg("s", ",").fetchall()
543+
assert result == [("a,b,c",)]
544+
545+
def test_separator_with_single_quote(self, duckdb_cursor):
546+
# Separator containing a single quote should be properly escaped
547+
rel = duckdb_cursor.sql("select 'a' as s union all select 'b'")
548+
result = rel.string_agg("s", "','").fetchall()
549+
assert result == [("a','b",)]
550+
551+
def test_separator_with_special_chars(self, duckdb_cursor):
552+
rel = duckdb_cursor.sql("select 'x' as s union all select 'y'")
553+
result = rel.string_agg("s", " | ").fetchall()
554+
assert result == [("x | y",)]
555+
556+
def test_separator_injection_attempt(self, duckdb_cursor):
557+
# Attempt to inject via separator - should be safely quoted as string literal
558+
rel = duckdb_cursor.sql("select 'a' as s union all select 'b'")
559+
# This should NOT execute the injection - separator becomes a literal string
560+
result = rel.string_agg("s", "'); drop table agg; --").fetchall()
561+
assert result == [("a'); drop table agg; --b",)]

0 commit comments

Comments
 (0)