Skip to content

Commit f9b71ba

Browse files
Fix query detection for CTEs and SQL comments
CTE queries (WITH ... SELECT) were misrouted to exec_sql because the query detection only checked for SELECT/PRAGMA/EXPLAIN prefixes. SQL comments before statements (-- and /* */) also broke detection. Add _strip_leading_comments() helper and add WITH to the prefix list. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent e08b0c3 commit f9b71ba

3 files changed

Lines changed: 114 additions & 6 deletions

File tree

src/dqlitedbapi/aio/cursor.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,25 @@
99
from dqlitedbapi.aio.connection import AsyncConnection
1010

1111

12+
def _strip_leading_comments(sql: str) -> str:
13+
"""Strip leading SQL comments (-- and /* */) and whitespace."""
14+
s = sql.strip()
15+
while True:
16+
if s.startswith("--"):
17+
newline = s.find("\n")
18+
if newline == -1:
19+
return ""
20+
s = s[newline + 1 :].strip()
21+
elif s.startswith("/*"):
22+
end = s.find("*/")
23+
if end == -1:
24+
return s
25+
s = s[end + 2 :].strip()
26+
else:
27+
break
28+
return s
29+
30+
1231
class AsyncCursor:
1332
"""Async database cursor."""
1433

@@ -61,9 +80,11 @@ async def execute(
6180
conn = await self._connection._ensure_connection()
6281
params = list(parameters) if parameters is not None else None
6382

64-
# Determine if this is a query that returns rows
65-
normalized = operation.strip().upper()
66-
is_query = normalized.startswith(("SELECT", "PRAGMA", "EXPLAIN")) or (
83+
# Determine if this is a query that returns rows.
84+
# Note: WITH ... INSERT/UPDATE/DELETE (without RETURNING) will be
85+
# misrouted to query_sql. This is a known limitation of the heuristic.
86+
normalized = _strip_leading_comments(operation).upper()
87+
is_query = normalized.startswith(("SELECT", "PRAGMA", "EXPLAIN", "WITH")) or (
6788
" RETURNING " in normalized or normalized.endswith(" RETURNING")
6889
)
6990

src/dqlitedbapi/cursor.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,25 @@
99
from dqlitedbapi.connection import Connection
1010

1111

12+
def _strip_leading_comments(sql: str) -> str:
13+
"""Strip leading SQL comments (-- and /* */) and whitespace."""
14+
s = sql.strip()
15+
while True:
16+
if s.startswith("--"):
17+
newline = s.find("\n")
18+
if newline == -1:
19+
return ""
20+
s = s[newline + 1 :].strip()
21+
elif s.startswith("/*"):
22+
end = s.find("*/")
23+
if end == -1:
24+
return s
25+
s = s[end + 2 :].strip()
26+
else:
27+
break
28+
return s
29+
30+
1231
class Cursor:
1332
"""PEP 249 compliant database cursor."""
1433

@@ -73,9 +92,11 @@ async def _execute_async(self, operation: str, parameters: Sequence[Any] | None
7392
conn = await self._connection._get_async_connection()
7493
params = list(parameters) if parameters is not None else None
7594

76-
# Determine if this is a query that returns rows
77-
normalized = operation.strip().upper()
78-
is_query = normalized.startswith(("SELECT", "PRAGMA", "EXPLAIN")) or (
95+
# Determine if this is a query that returns rows.
96+
# Note: WITH ... INSERT/UPDATE/DELETE (without RETURNING) will be
97+
# misrouted to query_sql. This is a known limitation of the heuristic.
98+
normalized = _strip_leading_comments(operation).upper()
99+
is_query = normalized.startswith(("SELECT", "PRAGMA", "EXPLAIN", "WITH")) or (
79100
" RETURNING " in normalized or normalized.endswith(" RETURNING")
80101
)
81102

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
"""Integration tests for query detection: CTEs, comments, etc."""
2+
3+
import pytest
4+
5+
from dqlitedbapi import connect
6+
7+
8+
@pytest.mark.integration
9+
class TestQueryDetection:
10+
def test_cte_select(self, cluster_address: str) -> None:
11+
"""WITH ... SELECT (CTE) should return rows via query path."""
12+
with connect(cluster_address, database="test_cte_select") as conn:
13+
cursor = conn.cursor()
14+
cursor.execute("DROP TABLE IF EXISTS cte_test")
15+
cursor.execute("CREATE TABLE cte_test (id INTEGER PRIMARY KEY, name TEXT)")
16+
cursor.execute("INSERT INTO cte_test (id, name) VALUES (1, 'alice')")
17+
18+
cursor.execute("WITH names AS (SELECT id, name FROM cte_test) SELECT * FROM names")
19+
rows = cursor.fetchall()
20+
assert len(rows) == 1
21+
assert rows[0] == (1, "alice")
22+
23+
cursor.execute("DROP TABLE cte_test")
24+
25+
def test_comment_before_select(self, cluster_address: str) -> None:
26+
"""-- comment before SELECT should still return rows."""
27+
with connect(cluster_address, database="test_comment_select") as conn:
28+
cursor = conn.cursor()
29+
cursor.execute("DROP TABLE IF EXISTS comment_test")
30+
cursor.execute("CREATE TABLE comment_test (id INTEGER PRIMARY KEY)")
31+
cursor.execute("INSERT INTO comment_test (id) VALUES (1)")
32+
33+
cursor.execute("-- this is a comment\nSELECT * FROM comment_test")
34+
rows = cursor.fetchall()
35+
assert len(rows) == 1
36+
assert rows[0] == (1,)
37+
38+
cursor.execute("DROP TABLE comment_test")
39+
40+
def test_block_comment_before_select(self, cluster_address: str) -> None:
41+
"""/* block comment */ before SELECT should still return rows."""
42+
with connect(cluster_address, database="test_block_comment_select") as conn:
43+
cursor = conn.cursor()
44+
cursor.execute("DROP TABLE IF EXISTS bcomment_test")
45+
cursor.execute("CREATE TABLE bcomment_test (id INTEGER PRIMARY KEY)")
46+
cursor.execute("INSERT INTO bcomment_test (id) VALUES (42)")
47+
48+
cursor.execute("/* request_id=abc123 */ SELECT * FROM bcomment_test")
49+
rows = cursor.fetchall()
50+
assert len(rows) == 1
51+
assert rows[0] == (42,)
52+
53+
cursor.execute("DROP TABLE bcomment_test")
54+
55+
def test_recursive_cte(self, cluster_address: str) -> None:
56+
"""WITH RECURSIVE ... SELECT should return rows."""
57+
with connect(cluster_address, database="test_recursive_cte_select") as conn:
58+
cursor = conn.cursor()
59+
cursor.execute(
60+
"WITH RECURSIVE cnt(x) AS "
61+
"(VALUES(1) UNION ALL SELECT x+1 FROM cnt WHERE x<5) "
62+
"SELECT x FROM cnt"
63+
)
64+
rows = cursor.fetchall()
65+
assert len(rows) == 5
66+
assert rows == [(1,), (2,), (3,), (4,), (5,)]

0 commit comments

Comments
 (0)