Skip to content

Commit 2a0dfc9

Browse files
feat: accept injected ClusterClient or NodeStore on ConnectionPool
ConnectionPool constructed its own ClusterClient from raw addresses with no way to share one across pools. Multi-database apps, persistent node stores, and testability all suffered. Add optional cluster= and node_store= kwargs; keep addresses= the default path. When cluster= is provided, the pool reuses its node store and (future) leader cache. Validate that callers pass exactly one source. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent e9eeaf0 commit 2a0dfc9

File tree

3 files changed

+65
-6
lines changed

3 files changed

+65
-6
lines changed

src/dqliteclient/__init__.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,21 +63,27 @@ async def connect(
6363

6464

6565
async def create_pool(
66-
addresses: list[str],
66+
addresses: list[str] | None = None,
6767
*,
6868
database: str = "default",
6969
min_size: int = 1,
7070
max_size: int = 10,
7171
timeout: float = 10.0,
72+
cluster: ClusterClient | None = None,
73+
node_store: NodeStore | None = None,
7274
) -> ConnectionPool:
7375
"""Create a connection pool with automatic leader detection.
7476
7577
Args:
76-
addresses: List of node addresses in "host:port" format
78+
addresses: List of node addresses in "host:port" format. Ignored if
79+
``cluster`` or ``node_store`` is provided.
7780
database: Database name to open
7881
min_size: Minimum number of connections to maintain
7982
max_size: Maximum number of connections
8083
timeout: Connection timeout in seconds
84+
cluster: Externally-owned ClusterClient shared across pools.
85+
node_store: Externally-owned NodeStore used to build a new
86+
ClusterClient. Mutually exclusive with ``cluster``.
8187
8288
Returns:
8389
An initialized ConnectionPool
@@ -88,6 +94,8 @@ async def create_pool(
8894
min_size=min_size,
8995
max_size=max_size,
9096
timeout=timeout,
97+
cluster=cluster,
98+
node_store=node_store,
9199
)
92100
await pool.initialize()
93101
return pool

src/dqliteclient/pool.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from dqliteclient.cluster import ClusterClient
1010
from dqliteclient.connection import DqliteConnection
1111
from dqliteclient.exceptions import DqliteConnectionError
12+
from dqliteclient.node_store import NodeStore
1213

1314

1415
def _socket_looks_dead(conn: DqliteConnection) -> bool:
@@ -48,21 +49,29 @@ class ConnectionPool:
4849

4950
def __init__(
5051
self,
51-
addresses: list[str],
52+
addresses: list[str] | None = None,
5253
*,
5354
database: str = "default",
5455
min_size: int = 1,
5556
max_size: int = 10,
5657
timeout: float = 10.0,
58+
cluster: ClusterClient | None = None,
59+
node_store: NodeStore | None = None,
5760
) -> None:
5861
"""Initialize connection pool.
5962
6063
Args:
61-
addresses: List of node addresses
64+
addresses: List of node addresses. Ignored if ``cluster`` or
65+
``node_store`` is provided; required otherwise.
6266
database: Database name
6367
min_size: Minimum connections to maintain
6468
max_size: Maximum connections allowed
6569
timeout: Connection timeout
70+
cluster: Externally-owned ClusterClient. Lets multiple pools
71+
share one ClusterClient (and thus its node store, leader
72+
cache, etc.) across databases or tenants.
73+
node_store: Externally-owned NodeStore used to build a new
74+
ClusterClient. Mutually exclusive with ``cluster``.
6675
"""
6776
if min_size < 0:
6877
raise ValueError(f"min_size must be non-negative, got {min_size}")
@@ -72,14 +81,23 @@ def __init__(
7281
raise ValueError(f"min_size ({min_size}) must not exceed max_size ({max_size})")
7382
if timeout <= 0:
7483
raise ValueError(f"timeout must be positive, got {timeout}")
84+
if cluster is not None and node_store is not None:
85+
raise ValueError("pass only one of cluster= or node_store=")
86+
if cluster is None and node_store is None and not addresses:
87+
raise ValueError("pass one of addresses, cluster, or node_store")
7588

76-
self._addresses = addresses
89+
self._addresses = addresses or []
7790
self._database = database
7891
self._min_size = min_size
7992
self._max_size = max_size
8093
self._timeout = timeout
8194

82-
self._cluster = ClusterClient.from_addresses(addresses, timeout=timeout)
95+
if cluster is not None:
96+
self._cluster = cluster
97+
elif node_store is not None:
98+
self._cluster = ClusterClient(node_store, timeout=timeout)
99+
else:
100+
self._cluster = ClusterClient.from_addresses(self._addresses, timeout=timeout)
83101
self._pool: asyncio.Queue[DqliteConnection] = asyncio.Queue(maxsize=max_size)
84102
self._size = 0
85103
self._lock = asyncio.Lock()

tests/test_pool.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -873,6 +873,39 @@ async def slow_connect(**kwargs):
873873

874874
await pool.close()
875875

876+
async def test_pool_accepts_injected_cluster_client(self) -> None:
877+
"""Callers must be able to share a ClusterClient across multiple pools."""
878+
from dqliteclient.cluster import ClusterClient
879+
880+
shared_cluster = ClusterClient.from_addresses(["localhost:9001"])
881+
pool_a = ConnectionPool(cluster=shared_cluster, min_size=0, max_size=3)
882+
pool_b = ConnectionPool(cluster=shared_cluster, min_size=0, max_size=3)
883+
assert pool_a._cluster is shared_cluster
884+
assert pool_b._cluster is shared_cluster
885+
886+
async def test_pool_accepts_injected_node_store(self) -> None:
887+
"""Callers with a persistent NodeStore must be able to thread it in."""
888+
from dqliteclient.node_store import MemoryNodeStore
889+
890+
store = MemoryNodeStore(["localhost:9001", "localhost:9002"])
891+
pool = ConnectionPool(node_store=store, min_size=0, max_size=1)
892+
nodes = await pool._cluster._node_store.get_nodes()
893+
assert [n.address for n in nodes] == ["localhost:9001", "localhost:9002"]
894+
895+
async def test_pool_requires_some_cluster_source(self) -> None:
896+
"""Constructing with neither addresses nor cluster/node_store must raise."""
897+
with pytest.raises(ValueError, match="addresses.*cluster.*node_store"):
898+
ConnectionPool()
899+
900+
async def test_pool_rejects_cluster_and_node_store_together(self) -> None:
901+
"""Passing both cluster= and node_store= must raise — pick one."""
902+
from dqliteclient.cluster import ClusterClient
903+
from dqliteclient.node_store import MemoryNodeStore
904+
905+
cluster = ClusterClient.from_addresses(["localhost:9001"])
906+
store = MemoryNodeStore(["localhost:9001"])
907+
with pytest.raises(ValueError, match="only one"):
908+
ConnectionPool(cluster=cluster, node_store=store)
876909

877910
async def test_reset_connection_returns_false_on_cancelled_error(self) -> None:
878911
"""_reset_connection must return False (not raise) when ROLLBACK is cancelled."""

0 commit comments

Comments
 (0)