Skip to content

Commit 65417ca

Browse files
Short-circuit close() / initialize() in forked child to spare inherited FDs
The first fork-guard commit covered the use-after-fork case (acquire, _check_in_use), but close() and pool.initialize() were left to fail late. In a child, the inherited connection FDs are shared with the parent: writer.close() in DqliteConnection.close() or the per-conn drain inside ConnectionPool.close() would send FIN on sockets the parent still uses, breaking the parent's connections from the child's GC sweep.

Pid-check at the top of:
- DqliteConnection.close: flip _protocol/_db_id to None and bail — no wire teardown; the FD is left open in the child so the parent's connection is unaffected.
- ConnectionPool.close: flip _closed=True and bail — no _drain_idle.
- ConnectionPool.initialize: raise InterfaceError up-front so a child cannot kick off TCP work against asyncio primitives bound to the parent's loop.

Tests cover all three paths plus the existing acquire/check_in_use short-circuits.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 644c530 commit 65417ca

3 files changed

Lines changed: 119 additions & 0 deletions

File tree

src/dqliteclient/connection.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1087,6 +1087,19 @@ async def close(self) -> None:
10871087
# close path has already run under pool ownership.
10881088
if self._pool_released:
10891089
return
1090+
# Fork-after-init: the inherited socket FD is shared with the
1091+
# parent. ``writer.close()`` would send a FIN on a socket the parent
1092+
# still depends on. Flip the local state to closed without
1093+
# touching the wire so the child can clean up its references
1094+
# quietly, then bail. The pid-aware ``_check_in_use`` would
1095+
# also raise here, but close() is documented as idempotent and
1096+
# silent on already-closed inputs — silently no-oping in the
1097+
# child preserves that contract for the GC / __del__ path that
1098+
# commonly drives close in a forked worker.
1099+
if os.getpid() != self._creator_pid:
1100+
self._protocol = None
1101+
self._db_id = None
1102+
return
10901103
# Run the in-use guard BEFORE the ``_protocol is None``
10911104
# early-return so a concurrent ``connect()`` racing with
10921105
# ``close()`` surfaces as ``InterfaceError`` instead of a silent

src/dqliteclient/pool.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,11 @@ async def initialize(self) -> None:
316316
``return_exceptions=True`` so every task resolves, then close
317317
survivors explicitly before re-raising the first failure.
318318
"""
319+
if os.getpid() != self._creator_pid:
320+
raise InterfaceError(
321+
"Pool used after fork; reconstruct from configuration "
322+
"in the target process."
323+
)
319324
# Hold the lock across the gather so a second concurrent
320325
# initialize() call observes _initialized=True after the first
321326
# completes and returns without re-creating.
@@ -1307,6 +1312,16 @@ async def close(self) -> None:
13071312
if self._close_done is not None:
13081313
await self._close_done.wait()
13091314
return
1315+
# Fork-after-init: the inherited connection FDs are shared with
1316+
# the parent. Draining and writer.close() in the child would
1317+
# send FIN on sockets the parent still uses. Flip the closed
1318+
# flag so the child's references can be GC'd quietly without
1319+
# touching the wire. The child cannot acquire new connections
1320+
# either way (pid-aware ``acquire`` rejects). Symmetric with
1321+
# ``DqliteConnection.close``'s fork short-circuit.
1322+
if os.getpid() != self._creator_pid:
1323+
self._closed = True
1324+
return
13101325
# Publish the drain-done event BEFORE flipping the closed flag
13111326
# so any second caller observing ``_closed=True`` is guaranteed
13121327
# to see a valid ``_close_done`` to wait on. Under single-task

tests/test_connection_after_fork_raises.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,94 @@ async def run() -> None:
9191

9292
result = _run_in_child(child_check)
9393
assert result == b"OK", f"child reported: {result!r}"
94+
95+
96+
@pytest.mark.skipif(not hasattr(os, "fork"), reason="requires os.fork")
97+
def test_connection_pool_initialize_after_fork_raises_interface_error() -> None:
98+
pool = ConnectionPool(addresses=["127.0.0.1:9999"])
99+
100+
def child_check() -> None:
101+
async def run() -> None:
102+
await pool.initialize()
103+
104+
asyncio.run(run())
105+
106+
result = _run_in_child(child_check)
107+
assert result == b"OK", f"child reported: {result!r}"
108+
109+
110+
@pytest.mark.skipif(not hasattr(os, "fork"), reason="requires os.fork")
111+
def test_dqlite_connection_close_after_fork_short_circuits() -> None:
112+
"""``close()`` in the child must not touch the inherited socket
113+
(which is shared with the parent — sending FIN would close it
114+
for the parent too). Short-circuits to a quiet local-state flip."""
115+
conn = DqliteConnection("127.0.0.1:9999")
116+
117+
r, w = os.pipe()
118+
pid = os.fork()
119+
if pid == 0:
120+
try:
121+
os.close(r)
122+
try:
123+
124+
async def run() -> None:
125+
await conn.close()
126+
127+
asyncio.run(run())
128+
# close() did not raise — the child's local state is
129+
# marked closed without touching the wire.
130+
os.write(w, b"OK")
131+
except Exception as e: # noqa: BLE001
132+
os.write(w, f"WRONG:{type(e).__name__}:{e}".encode())
133+
finally:
134+
os.close(w)
135+
finally:
136+
os._exit(0)
137+
os.close(w)
138+
result = b""
139+
while True:
140+
chunk = os.read(r, 4096)
141+
if not chunk:
142+
break
143+
result += chunk
144+
os.close(r)
145+
os.waitpid(pid, 0)
146+
assert result == b"OK", f"child reported: {result!r}"
147+
148+
149+
@pytest.mark.skipif(not hasattr(os, "fork"), reason="requires os.fork")
150+
def test_connection_pool_close_after_fork_short_circuits() -> None:
151+
"""``pool.close()`` in the child must not drain inherited connection
152+
FDs — doing so would close sockets the parent still uses. Short-
153+
circuits to a quiet local-state flip."""
154+
pool = ConnectionPool(addresses=["127.0.0.1:9999"])
155+
156+
r, w = os.pipe()
157+
pid = os.fork()
158+
if pid == 0:
159+
try:
160+
os.close(r)
161+
try:
162+
163+
async def run() -> None:
164+
await pool.close()
165+
assert pool._closed is True
166+
167+
asyncio.run(run())
168+
os.write(w, b"OK")
169+
except Exception as e: # noqa: BLE001
170+
os.write(w, f"WRONG:{type(e).__name__}:{e}".encode())
171+
finally:
172+
os.close(w)
173+
finally:
174+
os._exit(0)
175+
os.close(w)
176+
result = b""
177+
while True:
178+
chunk = os.read(r, 4096)
179+
if not chunk:
180+
break
181+
result += chunk
182+
os.close(r)
183+
os.waitpid(pid, 0)
184+
assert result == b"OK", f"child reported: {result!r}"

0 commit comments

Comments
 (0)