Skip to content

Commit e1729d1

Browse files
Accept injected ClusterClient or NodeStore on ConnectionPool
ConnectionPool constructed its own ClusterClient from raw addresses with no way to share one across pools. Multi-database apps, persistent node stores, and testability all suffered. Add optional cluster= and node_store= kwargs; keep addresses= the default path. When cluster= is provided, the pool reuses its node store and (future) leader cache. Validate that callers pass exactly one source. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent e2bd29d commit e1729d1

3 files changed

Lines changed: 65 additions & 6 deletions

File tree

src/dqliteclient/__init__.py

Lines changed: 10 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -63,21 +63,27 @@ async def connect(
6363

6464

6565
async def create_pool(
66-
addresses: list[str],
66+
addresses: list[str] | None = None,
6767
*,
6868
database: str = "default",
6969
min_size: int = 1,
7070
max_size: int = 10,
7171
timeout: float = 10.0,
72+
cluster: ClusterClient | None = None,
73+
node_store: NodeStore | None = None,
7274
) -> ConnectionPool:
7375
"""Create a connection pool with automatic leader detection.
7476
7577
Args:
76-
addresses: List of node addresses in "host:port" format
78+
addresses: List of node addresses in "host:port" format. Ignored if
79+
``cluster`` or ``node_store`` is provided.
7780
database: Database name to open
7881
min_size: Minimum number of connections to maintain
7982
max_size: Maximum number of connections
8083
timeout: Connection timeout in seconds
84+
cluster: Externally-owned ClusterClient shared across pools.
85+
node_store: Externally-owned NodeStore used to build a new
86+
ClusterClient. Mutually exclusive with ``cluster``.
8187
8288
Returns:
8389
An initialized ConnectionPool
@@ -88,6 +94,8 @@ async def create_pool(
8894
min_size=min_size,
8995
max_size=max_size,
9096
timeout=timeout,
97+
cluster=cluster,
98+
node_store=node_store,
9199
)
92100
await pool.initialize()
93101
return pool

src/dqliteclient/pool.py

Lines changed: 22 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -9,6 +9,7 @@
99
from dqliteclient.cluster import ClusterClient
1010
from dqliteclient.connection import DqliteConnection
1111
from dqliteclient.exceptions import DqliteConnectionError
12+
from dqliteclient.node_store import NodeStore
1213

1314

1415
def _socket_looks_dead(conn: DqliteConnection) -> bool:
@@ -48,21 +49,29 @@ class ConnectionPool:
4849

4950
def __init__(
5051
self,
51-
addresses: list[str],
52+
addresses: list[str] | None = None,
5253
*,
5354
database: str = "default",
5455
min_size: int = 1,
5556
max_size: int = 10,
5657
timeout: float = 10.0,
58+
cluster: ClusterClient | None = None,
59+
node_store: NodeStore | None = None,
5760
) -> None:
5861
"""Initialize connection pool.
5962
6063
Args:
61-
addresses: List of node addresses
64+
addresses: List of node addresses. Ignored if ``cluster`` or
65+
``node_store`` is provided; required otherwise.
6266
database: Database name
6367
min_size: Minimum connections to maintain
6468
max_size: Maximum connections allowed
6569
timeout: Connection timeout
70+
cluster: Externally-owned ClusterClient. Lets multiple pools
71+
share one ClusterClient (and thus its node store, leader
72+
cache, etc.) across databases or tenants.
73+
node_store: Externally-owned NodeStore used to build a new
74+
ClusterClient. Mutually exclusive with ``cluster``.
6675
"""
6776
if min_size < 0:
6877
raise ValueError(f"min_size must be non-negative, got {min_size}")
@@ -72,14 +81,23 @@ def __init__(
7281
raise ValueError(f"min_size ({min_size}) must not exceed max_size ({max_size})")
7382
if timeout <= 0:
7483
raise ValueError(f"timeout must be positive, got {timeout}")
84+
if cluster is not None and node_store is not None:
85+
raise ValueError("pass only one of cluster= or node_store=")
86+
if cluster is None and node_store is None and not addresses:
87+
raise ValueError("pass one of addresses, cluster, or node_store")
7588

76-
self._addresses = addresses
89+
self._addresses = addresses or []
7790
self._database = database
7891
self._min_size = min_size
7992
self._max_size = max_size
8093
self._timeout = timeout
8194

82-
self._cluster = ClusterClient.from_addresses(addresses, timeout=timeout)
95+
if cluster is not None:
96+
self._cluster = cluster
97+
elif node_store is not None:
98+
self._cluster = ClusterClient(node_store, timeout=timeout)
99+
else:
100+
self._cluster = ClusterClient.from_addresses(self._addresses, timeout=timeout)
83101
self._pool: asyncio.Queue[DqliteConnection] = asyncio.Queue(maxsize=max_size)
84102
self._size = 0
85103
self._lock = asyncio.Lock()

tests/test_pool.py

Lines changed: 33 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -873,6 +873,39 @@ async def slow_connect(**kwargs):
873873

874874
await pool.close()
875875

876+
async def test_pool_accepts_injected_cluster_client(self) -> None:
877+
"""Callers must be able to share a ClusterClient across multiple pools."""
878+
from dqliteclient.cluster import ClusterClient
879+
880+
shared_cluster = ClusterClient.from_addresses(["localhost:9001"])
881+
pool_a = ConnectionPool(cluster=shared_cluster, min_size=0, max_size=3)
882+
pool_b = ConnectionPool(cluster=shared_cluster, min_size=0, max_size=3)
883+
assert pool_a._cluster is shared_cluster
884+
assert pool_b._cluster is shared_cluster
885+
886+
async def test_pool_accepts_injected_node_store(self) -> None:
887+
"""Callers with a persistent NodeStore must be able to thread it in."""
888+
from dqliteclient.node_store import MemoryNodeStore
889+
890+
store = MemoryNodeStore(["localhost:9001", "localhost:9002"])
891+
pool = ConnectionPool(node_store=store, min_size=0, max_size=1)
892+
nodes = await pool._cluster._node_store.get_nodes()
893+
assert [n.address for n in nodes] == ["localhost:9001", "localhost:9002"]
894+
895+
async def test_pool_requires_some_cluster_source(self) -> None:
896+
"""Constructing with neither addresses nor cluster/node_store must raise."""
897+
with pytest.raises(ValueError, match="addresses.*cluster.*node_store"):
898+
ConnectionPool()
899+
900+
async def test_pool_rejects_cluster_and_node_store_together(self) -> None:
901+
"""Passing both cluster= and node_store= must raise — pick one."""
902+
from dqliteclient.cluster import ClusterClient
903+
from dqliteclient.node_store import MemoryNodeStore
904+
905+
cluster = ClusterClient.from_addresses(["localhost:9001"])
906+
store = MemoryNodeStore(["localhost:9001"])
907+
with pytest.raises(ValueError, match="only one"):
908+
ConnectionPool(cluster=cluster, node_store=store)
876909

877910
async def test_reset_connection_returns_false_on_cancelled_error(self) -> None:
878911
"""_reset_connection must return False (not raise) when ROLLBACK is cancelled."""

0 commit comments

Comments (0)