Skip to content

Commit 97ce996

Browse files
fix: reject bare str/bytes as params with clear TypeError
params was typed as Sequence[Any], which includes str/bytes; calling execute("SELECT ?", "alice") type-checked, iterated "alice" character by character, and sent 5 parameters to the server. The user saw a confusing "wrong parameter count" server error instead of a clean client-side type mismatch. Add a _validate_params guard at every client entry point (execute, query_raw, fetch, fetchall, fetchval) that raises TypeError with a helpful hint when params is str or bytes. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 9dcca52 commit 97ce996

2 files changed

Lines changed: 24 additions & 0 deletions

File tree

src/dqliteclient/connection.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,16 @@ def _invalidate(self, cause: BaseException | None = None) -> None:
238238
if cause is not None:
239239
self._invalidation_cause = cause
240240

241+
@staticmethod
242+
def _validate_params(params: Sequence[Any] | None) -> None:
243+
"""Reject bare str/bytes params to catch the ``execute("?", "x")`` footgun.
244+
245+
``str`` and ``bytes`` are both ``Sequence[Any]``, so they type-check
246+
but would silently split into N single-character parameters.
247+
"""
248+
if isinstance(params, str | bytes):
249+
raise TypeError("params must be a list or tuple, not str/bytes; did you mean [value]?")
250+
241251
async def _run_protocol[T](self, fn: Callable[[DqliteProtocol, int], Awaitable[T]]) -> T:
242252
"""Run a protocol operation with standard error handling.
243253
@@ -270,6 +280,7 @@ async def execute(self, sql: str, params: Sequence[Any] | None = None) -> tuple[
270280
271281
Returns (last_insert_id, rows_affected).
272282
"""
283+
self._validate_params(params)
273284
return await self._run_protocol(lambda p, db: p.exec_sql(db, sql, params))
274285

275286
async def query_raw(
@@ -281,6 +292,7 @@ async def query_raw(
281292
of (column_names, rows) from the wire protocol. Intended for
282293
DBAPI cursor implementations that need column names separately.
283294
"""
295+
self._validate_params(params)
284296
return await self._run_protocol(lambda p, db: p.query_sql(db, sql, params))
285297

286298
async def query_raw_typed(
@@ -296,11 +308,13 @@ async def query_raw_typed(
296308

297309
async def fetch(self, sql: str, params: Sequence[Any] | None = None) -> list[dict[str, Any]]:
298310
"""Execute a query and return results as list of dicts."""
311+
self._validate_params(params)
299312
columns, rows = await self._run_protocol(lambda p, db: p.query_sql(db, sql, params))
300313
return [dict(zip(columns, row, strict=True)) for row in rows]
301314

302315
async def fetchall(self, sql: str, params: Sequence[Any] | None = None) -> list[list[Any]]:
303316
"""Execute a query and return results as list of lists."""
317+
self._validate_params(params)
304318
_, rows = await self._run_protocol(lambda p, db: p.query_sql(db, sql, params))
305319
return rows
306320

@@ -318,6 +332,7 @@ async def fetchone(
318332

319333
async def fetchval(self, sql: str, params: Sequence[Any] | None = None) -> Any:
320334
"""Execute a query and return the first column of the first row."""
335+
self._validate_params(params)
321336
_, rows = await self._run_protocol(lambda p, db: p.query_sql(db, sql, params))
322337
if rows and rows[0]:
323338
return rows[0][0]

tests/test_connection.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -841,6 +841,15 @@ async def test_commit_failure_invalidates_connection(self, connected_connection)
841841
pass
842842
assert not conn.is_connected, "failed COMMIT must invalidate the connection"
843843

844+
async def test_string_params_rejected_with_clear_error(self, connected_connection) -> None:
845+
"""Passing a bare string as params silently splits it into N character
846+
parameters. Guard at the client boundary so the user gets a clear
847+
TypeError instead of a confusing server-side "wrong parameter count".
848+
"""
849+
conn, _, _ = connected_connection
850+
with pytest.raises(TypeError, match="list or tuple"):
851+
await conn.execute("SELECT ?", "alice") # type: ignore[arg-type]
852+
844853
async def test_cross_event_loop_raises_interface_error(self) -> None:
845854
"""Using a connection from a different event loop must raise InterfaceError."""
846855
import asyncio

0 commit comments

Comments
 (0)