Skip to content

Commit 236ae71

Browse files
Tighten DBAPI conformance: type-object compare, qmark validation, dedupe
Three related PEP 249 / cleanup fixes: ISSUE-06 — cursor.description[i][1] is the wire ValueType integer (e.g. 10 for ISO8601). PEP 249 says the module type objects must compare equal to that value. _DBAPIType now accepts both uppercase SQL type name strings and wire ValueType ints. Added mappings for STRING/BINARY/NUMBER/DATETIME/ROWID to the matching wire codes. ISSUE-07 — reject mapping and set/frozenset bind parameters with ProgrammingError. PEP 249 specifies qmark requires a sequence. A dict silently iterated as keys before; sets would silently scramble positional bindings. ISSUE-13 — removed the duplicate _strip_leading_comments from aio/cursor.py; it now imports the helper from the sync cursor module. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 2242238 commit 236ae71

5 files changed

Lines changed: 201 additions & 29 deletions

File tree

src/dqlitedbapi/aio/cursor.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,30 +3,14 @@
33
from collections.abc import Sequence
44
from typing import TYPE_CHECKING, Any
55

6-
from dqlitedbapi.cursor import _convert_params, _convert_row
6+
from dqlitedbapi.cursor import _convert_params, _convert_row, _strip_leading_comments
77
from dqlitedbapi.exceptions import InterfaceError
88

99
if TYPE_CHECKING:
1010
from dqlitedbapi.aio.connection import AsyncConnection
1111

1212

13-
def _strip_leading_comments(sql: str) -> str:
14-
"""Strip leading SQL comments (-- and /* */) and whitespace."""
15-
s = sql.strip()
16-
while True:
17-
if s.startswith("--"):
18-
newline = s.find("\n")
19-
if newline == -1:
20-
return ""
21-
s = s[newline + 1 :].strip()
22-
elif s.startswith("/*"):
23-
end = s.find("*/")
24-
if end == -1:
25-
return s
26-
s = s[end + 2 :].strip()
27-
else:
28-
break
29-
return s
13+
__all__ = ["AsyncCursor", "_strip_leading_comments"]
3014

3115

3216
class AsyncCursor:

src/dqlitedbapi/cursor.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
"""PEP 249 Cursor implementation for dqlite."""
22

3-
from collections.abc import Callable, Sequence
3+
from collections.abc import Callable, Mapping, Sequence
44
from typing import TYPE_CHECKING, Any
55

66
from dqlitewire.constants import ValueType
77

8-
from dqlitedbapi.exceptions import InterfaceError
8+
from dqlitedbapi.exceptions import InterfaceError, ProgrammingError
99
from dqlitedbapi.types import (
1010
_convert_bind_param,
1111
_datetime_from_iso8601,
@@ -35,8 +35,31 @@ def _convert_row(row: Sequence[Any], column_types: Sequence[int]) -> tuple[Any,
3535
return tuple(result)
3636

3737

38+
def _reject_non_sequence_params(params: Any) -> None:
39+
"""Reject mappings and unordered containers per PEP 249 qmark rules.
40+
41+
PEP 249: for ``qmark`` paramstyle "the sequence is mandatory and the
42+
driver will not accept mappings." We also reject ``set`` / ``frozenset``
43+
— they are sequences structurally but unordered, which silently
44+
scrambles positional bindings.
45+
"""
46+
if params is None:
47+
return
48+
if isinstance(params, Mapping):
49+
raise ProgrammingError(
50+
"qmark paramstyle requires a sequence; got a mapping. "
51+
"Use a list or tuple positionally matching the ? placeholders."
52+
)
53+
if isinstance(params, (set, frozenset)):
54+
raise ProgrammingError(
55+
"qmark paramstyle requires an ordered sequence; got a set. "
56+
"Use a list or tuple positionally matching the ? placeholders."
57+
)
58+
59+
3860
def _convert_params(params: Sequence[Any] | None) -> list[Any] | None:
3961
"""Convert driver-level bind parameters (e.g. datetime) to wire primitives."""
62+
_reject_non_sequence_params(params)
4063
if params is None:
4164
return None
4265
return [_convert_bind_param(p) for p in params]

src/dqlitedbapi/types.py

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import datetime
44
from typing import Any
55

6+
from dqlitewire.constants import ValueType
7+
68

79
# Type constructors
810
def Date(year: int, month: int, day: int) -> datetime.date: # noqa: N802
@@ -42,27 +44,69 @@ def Binary(data: bytes) -> bytes: # noqa: N802
4244
return bytes(data)
4345

4446

45-
# Type objects for column type checking
47+
# Type objects for column type checking.
48+
#
49+
# PEP 249: "These objects represent a data type as represented in the
50+
# database. The module exports these objects: STRING, BINARY, NUMBER,
51+
# DATETIME, ROWID. The module should export a comparison for these types
52+
# and the object returned in Cursor.description[i][1]."
53+
#
54+
# Cursor.description[i][1] here is a wire-level ``ValueType`` integer
55+
# (e.g. 10 for ISO8601). The type objects below compare equal to both
56+
# the uppercase SQL type name strings (for declared-type matching) and
57+
# the matching ``ValueType`` ints.
4658
class _DBAPIType:
47-
"""Base type for DB-API type objects."""
59+
"""Base type for DB-API type objects. Compares equal to matching
60+
uppercase SQL type names (str) and wire-level ``ValueType`` codes
61+
(int).
62+
"""
4863

49-
def __init__(self, *values: str) -> None:
50-
self.values = set(values)
64+
def __init__(self, *values: str | int | ValueType) -> None:
65+
normalized: set[str | int] = set()
66+
for v in values:
67+
if isinstance(v, ValueType):
68+
normalized.add(int(v))
69+
else:
70+
normalized.add(v)
71+
self.values = normalized
5172

5273
def __eq__(self, other: object) -> bool:
5374
if isinstance(other, str):
5475
return other.upper() in self.values
76+
if isinstance(other, ValueType):
77+
return int(other) in self.values
78+
if isinstance(other, int) and not isinstance(other, bool):
79+
return other in self.values
5580
return NotImplemented
5681

5782
def __hash__(self) -> int:
5883
return hash(frozenset(self.values))
5984

6085

61-
STRING = _DBAPIType("TEXT", "VARCHAR", "CHAR", "CLOB")
62-
BINARY = _DBAPIType("BLOB", "BINARY", "VARBINARY")
63-
NUMBER = _DBAPIType("INTEGER", "INT", "SMALLINT", "BIGINT", "REAL", "FLOAT", "DOUBLE", "NUMERIC")
64-
DATETIME = _DBAPIType("DATE", "TIME", "TIMESTAMP", "DATETIME")
65-
ROWID = _DBAPIType("ROWID", "INTEGER PRIMARY KEY")
86+
STRING = _DBAPIType("TEXT", "VARCHAR", "CHAR", "CLOB", ValueType.TEXT)
87+
BINARY = _DBAPIType("BLOB", "BINARY", "VARBINARY", ValueType.BLOB)
88+
NUMBER = _DBAPIType(
89+
"INTEGER",
90+
"INT",
91+
"SMALLINT",
92+
"BIGINT",
93+
"REAL",
94+
"FLOAT",
95+
"DOUBLE",
96+
"NUMERIC",
97+
ValueType.INTEGER,
98+
ValueType.FLOAT,
99+
ValueType.BOOLEAN,
100+
)
101+
DATETIME = _DBAPIType(
102+
"DATE",
103+
"TIME",
104+
"TIMESTAMP",
105+
"DATETIME",
106+
ValueType.ISO8601,
107+
ValueType.UNIXTIME,
108+
)
109+
ROWID = _DBAPIType("ROWID", "INTEGER PRIMARY KEY", ValueType.INTEGER)
66110

67111

68112
# Internal conversion helpers.

tests/test_param_validation.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
"""PEP 249 qmark parameter validation tests (ISSUE-07).
2+
3+
PEP 249: for ``qmark`` paramstyle, "the sequence is mandatory and the
4+
driver will not accept mappings." We additionally reject ``set``/
5+
``frozenset`` because they are unordered and would silently scramble
6+
positional bindings.
7+
"""
8+
9+
import pytest
10+
11+
from dqlitedbapi.cursor import _reject_non_sequence_params
12+
from dqlitedbapi.exceptions import ProgrammingError
13+
14+
15+
class TestRejectMappings:
16+
def test_dict_rejected(self) -> None:
17+
with pytest.raises(ProgrammingError, match="mapping"):
18+
_reject_non_sequence_params({"x": 1})
19+
20+
def test_ordered_dict_rejected(self) -> None:
21+
from collections import OrderedDict
22+
23+
with pytest.raises(ProgrammingError, match="mapping"):
24+
_reject_non_sequence_params(OrderedDict(a=1))
25+
26+
27+
class TestRejectUnorderedSequences:
28+
def test_set_rejected(self) -> None:
29+
with pytest.raises(ProgrammingError, match="set"):
30+
_reject_non_sequence_params({1, 2, 3})
31+
32+
def test_frozenset_rejected(self) -> None:
33+
with pytest.raises(ProgrammingError, match="set"):
34+
_reject_non_sequence_params(frozenset({1, 2, 3}))
35+
36+
37+
class TestAccept:
38+
def test_list_accepted(self) -> None:
39+
_reject_non_sequence_params([1, 2, 3]) # no raise
40+
41+
def test_tuple_accepted(self) -> None:
42+
_reject_non_sequence_params((1, 2, 3)) # no raise
43+
44+
def test_none_accepted(self) -> None:
45+
_reject_non_sequence_params(None) # no raise
46+
47+
def test_empty_list_accepted(self) -> None:
48+
_reject_non_sequence_params([]) # no raise

tests/test_types_dbapi_compare.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"""PEP 249 type-object comparison tests (ISSUE-06).
2+
3+
Per PEP 249 the module's type objects (STRING, BINARY, NUMBER, DATETIME,
4+
ROWID) must compare equal to whatever `cursor.description[i][1]` carries.
5+
We carry the wire `ValueType` integer there, so the type objects have to
6+
compare equal to the integer code too — not just to uppercase SQL type
7+
name strings.
8+
"""
9+
10+
import pytest
11+
from dqlitewire.constants import ValueType
12+
13+
from dqlitedbapi import BINARY, DATETIME, NUMBER, ROWID, STRING
14+
15+
16+
class TestStringType:
17+
def test_equals_sql_type_names(self) -> None:
18+
assert STRING == "TEXT"
19+
assert STRING == "varchar"
20+
assert STRING == "CLOB"
21+
22+
def test_equals_wire_text_value_type(self) -> None:
23+
assert STRING == ValueType.TEXT
24+
assert STRING == int(ValueType.TEXT)
25+
26+
def test_does_not_equal_non_text_wire_types(self) -> None:
27+
assert not (STRING == ValueType.INTEGER)
28+
assert not (STRING == ValueType.BLOB)
29+
assert not (STRING == ValueType.ISO8601)
30+
31+
32+
class TestBinaryType:
33+
def test_equals_blob_wire_type(self) -> None:
34+
assert BINARY == ValueType.BLOB
35+
assert BINARY == int(ValueType.BLOB)
36+
assert BINARY == "BLOB"
37+
38+
39+
class TestNumberType:
40+
@pytest.mark.parametrize("vt", [ValueType.INTEGER, ValueType.FLOAT, ValueType.BOOLEAN])
41+
def test_equals_numeric_wire_types(self, vt: ValueType) -> None:
42+
assert NUMBER == vt
43+
assert NUMBER == int(vt)
44+
45+
def test_does_not_equal_text(self) -> None:
46+
assert not (NUMBER == ValueType.TEXT)
47+
48+
49+
class TestDatetimeType:
50+
@pytest.mark.parametrize("vt", [ValueType.ISO8601, ValueType.UNIXTIME])
51+
def test_equals_datetime_wire_types(self, vt: ValueType) -> None:
52+
assert DATETIME == vt
53+
assert DATETIME == int(vt)
54+
55+
def test_equals_declared_type_names(self) -> None:
56+
assert DATETIME == "DATETIME"
57+
assert DATETIME == "DATE"
58+
assert DATETIME == "timestamp"
59+
60+
61+
class TestRowidType:
62+
def test_equals_integer_wire_type(self) -> None:
63+
assert ROWID == ValueType.INTEGER
64+
assert ROWID == int(ValueType.INTEGER)
65+
66+
67+
class TestHashability:
68+
def test_types_still_hashable_after_int_mix(self) -> None:
69+
# Mixed str/int contents mustn't break dict-key or set-member use.
70+
types_set = {STRING, BINARY, NUMBER, DATETIME, ROWID}
71+
assert len(types_set) == 5
72+
d = {STRING: "s", NUMBER: "n"}
73+
assert d[STRING] == "s"

0 commit comments

Comments
 (0)