Skip to content

Commit 6823613

Browse files
Fix RETURNING
1 parent 7ae10b8 commit 6823613

File tree

6 files changed

+74
-4
lines changed

6 files changed

+74
-4
lines changed
163 Bytes
Binary file not shown.

src/dqlitedbapi/aio/cursor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,11 @@ async def execute(
6464

6565
params = list(parameters) if parameters else None
6666

67-
# Determine if this is a SELECT query
68-
is_query = operation.strip().upper().startswith(("SELECT", "PRAGMA", "EXPLAIN"))
67+
# Determine if this is a query that returns rows
68+
normalized = operation.strip().upper()
69+
is_query = normalized.startswith(("SELECT", "PRAGMA", "EXPLAIN")) or (
70+
" RETURNING " in normalized or normalized.endswith(" RETURNING")
71+
)
6972

7073
if is_query:
7174
assert conn._protocol is not None and conn._db_id is not None

src/dqlitedbapi/cursor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,11 @@ async def _execute_async(self, operation: str, parameters: Sequence[Any] | None
7979
conn = await self._connection._get_async_connection()
8080
params = list(parameters) if parameters else None
8181

82-
# Determine if this is a SELECT query
83-
is_query = operation.strip().upper().startswith(("SELECT", "PRAGMA", "EXPLAIN"))
82+
# Determine if this is a query that returns rows
83+
normalized = operation.strip().upper()
84+
is_query = normalized.startswith(("SELECT", "PRAGMA", "EXPLAIN")) or (
85+
" RETURNING " in normalized or normalized.endswith(" RETURNING")
86+
)
8487

8588
if is_query:
8689
assert conn._protocol is not None and conn._db_id is not None

tests/__init__.py

Whitespace-only changes.

tests/integration/__init__.py

Whitespace-only changes.
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""Integration tests for INSERT/UPDATE/DELETE ... RETURNING support."""
2+
3+
import pytest
4+
5+
from dqlitedbapi import connect
6+
7+
8+
@pytest.mark.integration
9+
class TestReturning:
10+
def test_insert_returning(self, cluster_address: str) -> None:
11+
"""INSERT ... RETURNING should return rows via query path."""
12+
conn = connect(cluster_address, database="test_returning")
13+
cursor = conn.cursor()
14+
15+
cursor.execute("CREATE TABLE ret_test (id INTEGER PRIMARY KEY, name TEXT)")
16+
cursor.execute("INSERT INTO ret_test (id, name) VALUES (1, 'alice') RETURNING id, name")
17+
row = cursor.fetchone()
18+
assert row == (1, "alice")
19+
20+
cursor.execute("DROP TABLE ret_test")
21+
conn.close()
22+
23+
def test_insert_returning_multiple(self, cluster_address: str) -> None:
24+
"""INSERT ... RETURNING should support fetching all returned rows."""
25+
conn = connect(cluster_address, database="test_returning_multi")
26+
cursor = conn.cursor()
27+
28+
cursor.execute("CREATE TABLE ret_multi (id INTEGER PRIMARY KEY, val TEXT)")
29+
cursor.execute("INSERT INTO ret_multi (id, val) VALUES (1, 'a') RETURNING id, val")
30+
assert cursor.fetchone() == (1, "a")
31+
32+
cursor.execute("INSERT INTO ret_multi (id, val) VALUES (2, 'b') RETURNING val")
33+
assert cursor.fetchone() == ("b",)
34+
35+
cursor.execute("DROP TABLE ret_multi")
36+
conn.close()
37+
38+
def test_delete_returning(self, cluster_address: str) -> None:
39+
"""DELETE ... RETURNING should return deleted rows."""
40+
conn = connect(cluster_address, database="test_del_returning")
41+
cursor = conn.cursor()
42+
43+
cursor.execute("CREATE TABLE del_ret (id INTEGER PRIMARY KEY, name TEXT)")
44+
cursor.execute("INSERT INTO del_ret (id, name) VALUES (1, 'alice')")
45+
cursor.execute("DELETE FROM del_ret WHERE id = 1 RETURNING id, name")
46+
row = cursor.fetchone()
47+
assert row == (1, "alice")
48+
49+
cursor.execute("DROP TABLE del_ret")
50+
conn.close()
51+
52+
def test_update_returning(self, cluster_address: str) -> None:
53+
"""UPDATE ... RETURNING should return updated rows."""
54+
conn = connect(cluster_address, database="test_upd_returning")
55+
cursor = conn.cursor()
56+
57+
cursor.execute("CREATE TABLE upd_ret (id INTEGER PRIMARY KEY, name TEXT)")
58+
cursor.execute("INSERT INTO upd_ret (id, name) VALUES (1, 'alice')")
59+
cursor.execute("UPDATE upd_ret SET name = 'bob' WHERE id = 1 RETURNING id, name")
60+
row = cursor.fetchone()
61+
assert row == (1, "bob")
62+
63+
cursor.execute("DROP TABLE upd_ret")
64+
conn.close()

0 commit comments

Comments
 (0)