Skip to content

Commit 3573a3e

Browse files
Validate role against NodeRole when decoding ServersResponse
The wire-layer NodeInfo.role was a raw uint64 round-tripped without validation — a misbehaving or tampered server could plant arbitrary integers in the cluster-topology response and every downstream caller had to remember to wrap with NodeRole(...). Tighten the type annotation to NodeRole and raise DecodeError (chained from ValueError) at the wire seam when the value is outside {VOTER, STANDBY, SPARE}. The error message lists the offending value and the valid set so the failure is actionable in logs. NodeRole is an IntEnum subclass of int, so encode output is byte-identical and downstream ``role == 0`` comparisons continue to work. Three new tests pin the rejection path, the round-trip type, and the int-equality guarantee. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 190de54 commit 3573a3e

2 files changed

Lines changed: 57 additions & 2 deletions

File tree

src/dqlitewire/messages/responses.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
ROW_DONE_MARKER,
88
ROW_PART_MARKER,
99
WORD_SIZE,
10+
NodeRole,
1011
ResponseType,
1112
ValueType,
1213
)
@@ -559,7 +560,7 @@ class NodeInfo:
559560

560561
node_id: int
561562
address: str
562-
role: int
563+
role: NodeRole
563564

564565

565566
@dataclass
@@ -603,8 +604,15 @@ def decode_body(cls, data: bytes, schema: int = 0) -> "ServersResponse":
603604
offset += 8
604605
address, consumed = decode_text(view[offset:])
605606
offset += consumed
606-
role = decode_uint64(view[offset:])
607+
raw_role = decode_uint64(view[offset:])
607608
offset += 8
609+
try:
610+
role = NodeRole(raw_role)
611+
except ValueError as exc:
612+
valid = sorted(r.value for r in NodeRole)
613+
raise DecodeError(
614+
f"Invalid node role {raw_role} at offset {offset - 8}; expected one of {valid}"
615+
) from exc
608616
nodes.append(NodeInfo(node_id, address, role))
609617
return cls(nodes)
610618

tests/test_messages_responses.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,6 +1221,53 @@ def test_node_count_exceeds_hard_limit(self) -> None:
12211221
with pytest.raises(DecodeError, match="Node count.*exceeds maximum"):
12221222
ServersResponse.decode_body(body)
12231223

1224+
def test_decode_rejects_unknown_role(self) -> None:
1225+
"""An unknown ``role`` value must raise DecodeError at the wire seam.
1226+
1227+
Upstream C (``src/roles.c``) only ever emits VOTER/STANDBY/SPARE
1228+
(0/1/2). A server that sends anything else is either buggy or
1229+
hostile; either way we refuse to build a NodeInfo with an
1230+
unvalidated enum value — the failure must surface at the wire
1231+
boundary, not silently propagate.
1232+
"""
1233+
import pytest
1234+
1235+
from dqlitewire.exceptions import DecodeError
1236+
from dqlitewire.types import encode_text, encode_uint64
1237+
1238+
body = (
1239+
encode_uint64(1) # count = 1
1240+
+ encode_uint64(7) # node_id
1241+
+ encode_text("n1:9001")
1242+
+ encode_uint64(999) # role = invalid
1243+
)
1244+
with pytest.raises(DecodeError, match="Invalid node role 999"):
1245+
ServersResponse.decode_body(body)
1246+
1247+
def test_roundtrip_preserves_noderole_type(self) -> None:
1248+
"""Decoded role must be a NodeRole member, not a bare int."""
1249+
from dqlitewire.constants import NodeRole
1250+
1251+
nodes = [
1252+
NodeInfo(node_id=1, address="n1:9001", role=NodeRole.VOTER),
1253+
NodeInfo(node_id=2, address="n2:9002", role=NodeRole.STANDBY),
1254+
]
1255+
encoded = ServersResponse(nodes=nodes).encode()
1256+
decoded = ServersResponse.decode_body(encoded[HEADER_SIZE:])
1257+
assert [n.role for n in decoded.nodes] == [NodeRole.VOTER, NodeRole.STANDBY]
1258+
assert all(isinstance(n.role, NodeRole) for n in decoded.nodes)
1259+
1260+
def test_int_equality_survives_enum_typing(self) -> None:
1261+
"""Downstream code compares ``role == 0`` / ``== 1``; IntEnum
1262+
subclassing of int must keep those comparisons true even after
1263+
we tightened the type annotation.
1264+
"""
1265+
from dqlitewire.constants import NodeRole
1266+
1267+
info = NodeInfo(node_id=1, address="n1:9001", role=NodeRole.VOTER)
1268+
assert info.role == 0
1269+
assert info.role == NodeRole.VOTER
1270+
12241271

12251272
class TestMetadataResponse:
12261273
def test_roundtrip(self) -> None:

0 commit comments

Comments
 (0)