Skip to content

Commit 18c3316

Browse files
Reject use of connection / pool from a forked child process
Fork-after-init is unsupported: the inherited TCP socket is shared with the parent, so writes interleave on the wire, and the connection's asyncio primitives are bound to a loop the child cannot drive. Without a guard the child silently corrupts the wire or hangs forever.

Record os.getpid() in DqliteConnection.__init__ and ConnectionPool.__init__ and raise InterfaceError("Connection used after fork; reconstruct from configuration in the target process") from _check_in_use / acquire when the running process pid no longer matches the creator pid. Symmetric with the existing __reduce__ pickle/copy guards.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 4d786ca commit 18c3316

4 files changed

Lines changed: 125 additions & 0 deletions

File tree

src/dqliteclient/connection.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import ipaddress
66
import logging
77
import math
8+
import os
89
import re
910
import string
1011
from collections.abc import AsyncIterator, Awaitable, Callable, Mapping, Sequence
@@ -866,6 +867,13 @@ def __init__(
866867
# ``_invalidate`` so a subsequent ``close()`` can await it and
867868
# keep the reader task from outliving the connection.
868869
self._pending_drain: asyncio.Task[None] | None = None
870+
# Fork-after-init is unsupported: the inherited TCP socket
871+
# would be shared with the parent and writes would interleave
872+
# on the wire, and asyncio primitives bound to the parent's
873+
# loop are unusable in the child. Store the creator pid so
874+
# cross-fork use raises a clear ``InterfaceError`` from any
875+
# public method instead of silent corruption.
876+
self._creator_pid = os.getpid()
869877

870878
@property
871879
def address(self) -> str:
@@ -1237,6 +1245,11 @@ def _ensure_connected(self) -> tuple[DqliteProtocol, int]:
12371245

12381246
def _check_in_use(self) -> None:
12391247
"""Raise on misuse: wrong event loop, concurrent access, or use after pool release."""
1248+
if os.getpid() != self._creator_pid:
1249+
raise InterfaceError(
1250+
"Connection used after fork; reconstruct from configuration "
1251+
"in the target process."
1252+
)
12401253
if self._pool_released:
12411254
raise InterfaceError(
12421255
"This connection has been returned to the pool and can no longer "

src/dqliteclient/pool.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import asyncio
44
import contextlib
55
import logging
6+
import os
67
from collections.abc import AsyncIterator, Sequence
78
from contextlib import asynccontextmanager
89
from types import TracebackType
@@ -265,6 +266,13 @@ def __init__(
265266
self._closed_event: asyncio.Event | None = None
266267
self._close_done: asyncio.Event | None = None
267268
self._initialized = False
269+
# Fork-after-init is unsupported: pooled connections hold
270+
# shared TCP sockets and asyncio primitives bound to the
271+
# parent's loop. Store the creator pid so cross-fork
272+
# ``acquire`` raises a clear ``InterfaceError`` instead of
273+
# silently corrupting the wire by interleaving writes.
274+
# Symmetric with ``__reduce__`` and the per-connection guard.
275+
self._creator_pid = os.getpid()
268276

269277
def __repr__(self) -> str:
270278
state = "closed" if self._closed else "open"
@@ -656,6 +664,11 @@ async def _drain_remaining_after_cancel(self) -> None:
656664
@asynccontextmanager
657665
async def acquire(self) -> AsyncIterator[DqliteConnection]:
658666
"""Acquire a connection from the pool."""
667+
if os.getpid() != self._creator_pid:
668+
raise InterfaceError(
669+
"Pool used after fork; reconstruct from configuration "
670+
"in the target process."
671+
)
659672
if self._closed:
660673
raise DqliteConnectionError("Pool is closed")
661674

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""Pin: ``DqliteConnection`` and ``ConnectionPool`` raise
2+
``InterfaceError`` if used from a child process after ``os.fork``.
3+
4+
Fork-after-init is unsupported: the inherited TCP socket would be
5+
shared with the parent (writes interleaving on the wire), and asyncio
6+
primitives bound to the parent's loop are unusable in the child.
7+
8+
The fix records ``os.getpid()`` in ``__init__`` (both classes) and
9+
``_check_in_use`` / ``acquire`` raise a clear ``InterfaceError``
10+
("reconstruct from configuration in the target process") on pid
11+
mismatch — symmetric with the existing ``__reduce__`` pickle guards.
12+
13+
The test does not need a live server: the pid check fires before any
14+
network work; an unconnected instance is sufficient.
15+
"""
16+
17+
from __future__ import annotations
18+
19+
import asyncio
20+
import os
21+
22+
import pytest
23+
24+
from dqliteclient import DqliteConnection
25+
from dqliteclient.exceptions import InterfaceError
26+
from dqliteclient.pool import ConnectionPool
27+
28+
29+
def _run_in_child(check) -> bytes:
    """Fork, run ``check()`` in the child, and return the child's verdict.

    The child reports its outcome over a pipe and always terminates via
    ``os._exit(0)``, so it never unwinds back into the parent's pytest
    session.

    Args:
        check: Zero-argument callable executed in the forked child.

    Returns:
        The bytes written by the child:

        * ``b"OK"`` -- ``check`` raised ``InterfaceError`` whose message
          contains both ``"fork"`` and ``"reconstruct from configuration"``.
        * ``b"NO_RAISE"`` -- ``check`` returned without raising.
        * ``b"WRONG_MSG:<message>"`` -- ``InterfaceError`` was raised but
          with an unexpected message.
        * ``b"WRONG_TYPE:<type>:<message>"`` -- some other exception was
          raised.

    Raises:
        OSError: If the pipe cannot be created or the fork fails.
    """
    r, w = os.pipe()
    try:
        pid = os.fork()
    except BaseException:
        # Fork failed: neither process will use the pipe, so release both
        # ends before propagating.
        os.close(r)
        os.close(w)
        raise

    if pid == 0:
        # Child process: report the outcome of ``check`` and exit hard.
        try:
            os.close(r)
            try:
                check()
                os.write(w, b"NO_RAISE")
            except InterfaceError as e:
                msg = str(e)
                if "fork" in msg and "reconstruct from configuration" in msg:
                    os.write(w, b"OK")
                else:
                    os.write(w, f"WRONG_MSG:{msg}".encode())
            except BaseException as e:  # noqa: BLE001
                # ``BaseException`` rather than ``Exception``: SystemExit,
                # KeyboardInterrupt and asyncio.CancelledError would
                # otherwise leave the pipe empty and the parent with an
                # undiagnosable ``b""``. Swallowing them is safe because
                # the child exits unconditionally just below.
                os.write(w, f"WRONG_TYPE:{type(e).__name__}:{e}".encode())
            finally:
                os.close(w)
        finally:
            os._exit(0)

    # Parent process: drop our copy of the write end so the read below
    # sees EOF as soon as the child closes its end (or exits).
    os.close(w)
    try:
        with os.fdopen(r, "rb") as reader:
            return reader.read()
    finally:
        # Always reap the child, even if reading failed, to avoid a zombie.
        os.waitpid(pid, 0)
60+
61+
62+
@pytest.mark.skipif(not hasattr(os, "fork"), reason="requires os.fork")
def test_dqlite_connection_used_after_fork_raises_interface_error() -> None:
    """A connection built in the parent must refuse use in a forked child."""
    connection = DqliteConnection("127.0.0.1:9999")
    assert connection._creator_pid == os.getpid()

    async def _probe() -> None:
        # Call the guard from inside a coroutine so a running loop exists:
        # per the original author's note, ``_check_in_use`` also consults
        # ``asyncio.get_running_loop``, and the error under test should be
        # the pid mismatch rather than the "async context" one.
        connection._check_in_use()

    def _in_child() -> None:
        asyncio.run(_probe())

    result = _run_in_child(_in_child)
    assert result == b"OK", f"child reported: {result!r}"
78+
79+
80+
@pytest.mark.skipif(not hasattr(os, "fork"), reason="requires os.fork")
def test_connection_pool_acquire_after_fork_raises_interface_error() -> None:
    """A pool built in the parent must refuse ``acquire`` in a forked child."""
    pool = ConnectionPool(addresses=["127.0.0.1:9999"])
    assert pool._creator_pid == os.getpid()

    async def _acquire_once() -> None:
        # Entering the context manager is enough to trigger the guard;
        # nothing is done with the connection.
        async with pool.acquire():
            pass

    def _in_child() -> None:
        asyncio.run(_acquire_once())

    result = _run_in_child(_in_child)
    assert result == b"OK", f"child reported: {result!r}"

tests/test_interface_error_messages.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,15 @@
3131
def _make_bound_connection() -> DqliteConnection:
3232
"""A connection in the post-bind state on the running loop, with
3333
every other ``_check_in_use`` precondition relaxed."""
34+
import os as _os
35+
3436
conn = DqliteConnection.__new__(DqliteConnection)
3537
conn._pool_released = False
3638
conn._bound_loop = asyncio.get_running_loop()
3739
conn._in_use = False
3840
conn._in_transaction = False
3941
conn._tx_owner = None
42+
conn._creator_pid = _os.getpid()
4043
return conn
4144

4245

@@ -65,12 +68,15 @@ def test_called_from_sync_context_branch_message_substring() -> None:
6568
message rather than crashing on ``get_running_loop``'s
6669
RuntimeError. Run as a SYNC test so the surrounding code does
6770
not have a running loop."""
71+
import os as _os
72+
6873
conn = DqliteConnection.__new__(DqliteConnection)
6974
conn._pool_released = False
7075
conn._bound_loop = None
7176
conn._in_use = False
7277
conn._in_transaction = False
7378
conn._tx_owner = None
79+
conn._creator_pid = _os.getpid()
7480
with pytest.raises(InterfaceError, match="from within an async context"):
7581
conn._check_in_use()
7682

0 commit comments

Comments
 (0)