Skip to content

Commit 78a9daf

Browse files
[BugFix, Feature] Fix KeyError in NXGraphCompiler and extending the class of embeddable graphs to unit disk graphs (#116)
* Added a method to check if BaseGraph is a unit disk graph * is_unit_disk_graph: return True if graph is fully connected and False if empty or disconnected * Testing is_unit_disk_graph * FIX: node_attrs is correctly initialized when using a networkx graph * Using is_unit_disk_graph rather than is_disk_graph * Update qek/data/graphs.py Co-authored-by: RolandMacDoland <9250798+RolandMacDoland@users.noreply.github.com> * Update tests/test_graphs.py Co-authored-by: RolandMacDoland <9250798+RolandMacDoland@users.noreply.github.com> --------- Co-authored-by: RolandMacDoland <9250798+RolandMacDoland@users.noreply.github.com>
1 parent ddc8cee commit 78a9daf

2 files changed

Lines changed: 177 additions & 3 deletions

File tree

qek/data/graphs.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,53 @@ def __init__(
6767
# The graph in networkx format, undirected.
6868
self.nx_graph: nx.Graph = pyg_utils.to_networkx(
6969
data=data,
70-
node_attrs=["x"],
70+
node_attrs=["x"] if data.x is not None else None,
7171
edge_attrs=["edge_attr"] if data.edge_attr is not None else None,
7272
to_undirected=True,
7373
)
7474
self.target = target
7575
self.id = id
7676

77+
def is_unit_disk_graph(self) -> bool:
78+
"""
79+
A predicate to check if `self` is a unit disk graph.
80+
81+
Returns:
82+
`True` if the graph is a unit disk graph.
83+
`False` otherwise.
84+
"""
85+
86+
if self.pyg.num_nodes == 0 or self.pyg.num_nodes is None:
87+
logger.debug("graph %s doesn't have any nodes, it's not a disk graph", self.id, self.id)
88+
return False
89+
90+
# Check if the graph is connected.
91+
if len(self.nx_graph) == 0 or not nx.is_connected(self.nx_graph):
92+
logger.debug("graph %s is not connected, it's not a disk graph", self.id)
93+
return False
94+
95+
# Check the distances between all pairs of nodes.
96+
pos = self.pyg.pos
97+
assert pos is not None
98+
99+
non_connected_distances_um = [
100+
np.linalg.norm(np.array(pos[u]) - np.array(pos[v]))
101+
for u, v in nx.non_edges(self.nx_graph)
102+
]
103+
104+
# Fully connected graphs are always unit disk graphs
105+
if len(non_connected_distances_um) == 0:
106+
return True
107+
108+
connected_distances_um = [
109+
np.linalg.norm(np.array(pos[u]) - np.array(pos[v])) for u, v in self.nx_graph.edges()
110+
]
111+
112+
if min(non_connected_distances_um) < max(connected_distances_um):
113+
return False
114+
115+
return True
116+
77117
def is_disk_graph(self, radius: float) -> bool:
78118
"""
79119
A predicate to check if `self` is a disk graph with the specified
@@ -175,8 +215,10 @@ def is_embeddable(self) -> bool:
175215
return False
176216

177217
# Check distance between nodes
178-
if not self.is_disk_graph(self.device.min_atom_distance + EPSILON_RADIUS_UM):
179-
logger.debug("graph %s is not a disk graph, it's not embeddable", self.id)
218+
if not self.is_unit_disk_graph():
219+
logger.debug(
220+
"graph %s is not a unit disk graph, therefore it's not embeddable", self.id
221+
)
180222
return False
181223

182224
for u, v in self.nx_graph.edges():

tests/test_graphs.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,126 @@ def test_graph_init() -> None:
4949
assert graph.is_disk_graph(pulser.AnalogDevice.min_atom_distance + 0.01)
5050

5151

52+
def test_is_unit_disk_graph_false() -> None:
53+
"""Testing is_unit_disk_graph: these graphs are *not* unit disk graphs"""
54+
55+
# The empty graph is not a disk graph.
56+
graph_empty = BaseGraph(
57+
id=0,
58+
data=Data(
59+
x=torch.tensor([], dtype=torch.float),
60+
edge_index=torch.tensor([], dtype=torch.int),
61+
pos=torch.tensor([], dtype=torch.float),
62+
y=1,
63+
),
64+
device=pulser.AnalogDevice,
65+
)
66+
assert not graph_empty.is_unit_disk_graph()
67+
68+
# This graph has three nodes, each pair of nodes is closer than
69+
# the diameter, but it's not a disk graph because one of the nodes
70+
# is not connected.
71+
graph_disconnected_close = BaseGraph(
72+
id=0,
73+
data=Data(
74+
x=torch.tensor([[0], [1], [2]], dtype=torch.float),
75+
edge_index=torch.tensor(
76+
[
77+
[0, 1], # edge 0 -> 1
78+
[1, 0], # edge 1 -> 0
79+
],
80+
dtype=torch.int,
81+
),
82+
pos=torch.tensor([[0], [1], [2]], dtype=torch.float),
83+
y=1,
84+
),
85+
device=pulser.AnalogDevice,
86+
)
87+
assert not graph_disconnected_close.is_unit_disk_graph()
88+
89+
# This graph has three nodes, two nodes are connected, but it's
90+
# not a disk graph because the non connected node distance is shorter than the connected edges
91+
graph_non_udg = BaseGraph(
92+
id=0,
93+
data=Data(
94+
x=torch.tensor([[0], [1], [2]], dtype=torch.float),
95+
edge_index=torch.tensor(
96+
[
97+
[
98+
0,
99+
1, # edge 0 -> 1
100+
1,
101+
2, # edge 1 -> 2
102+
],
103+
[
104+
1,
105+
0, # edge 1 -> 0
106+
2,
107+
1, # edge 2 -> 1
108+
],
109+
],
110+
dtype=torch.int,
111+
),
112+
pos=torch.tensor([[0, 0], [0, 1], [0.1, 0]], dtype=torch.float),
113+
y=1,
114+
),
115+
device=pulser.AnalogDevice,
116+
)
117+
assert not graph_non_udg.is_unit_disk_graph()
118+
119+
120+
def test_is_unit_disk_graph_true() -> None:
121+
"""
122+
Testing is_disk_graph: these graphs are unit disk graphs
123+
"""
124+
# Single node graph is always a unit disk graph.
125+
graph_single_node = BaseGraph(
126+
id=0,
127+
data=Data(
128+
x=torch.tensor([0], dtype=torch.float),
129+
edge_index=torch.tensor([]),
130+
pos=torch.tensor([], dtype=torch.float),
131+
y=1,
132+
),
133+
device=pulser.AnalogDevice,
134+
)
135+
assert graph_single_node.is_unit_disk_graph()
136+
137+
# A complete graph with three nodes, each of the edges
138+
# is shorter than the disk's diameter.
139+
graph_connected_close = BaseGraph(
140+
id=0,
141+
data=Data(
142+
x=torch.tensor([[0], [1], [2]], dtype=torch.float),
143+
edge_index=torch.tensor(
144+
[
145+
[
146+
0,
147+
1, # edge 0 -> 1
148+
1,
149+
2, # edge 1 -> 2
150+
0,
151+
2, # edge 0 -> 2
152+
],
153+
[
154+
1,
155+
0, # edge 1 -> 0
156+
2,
157+
1, # edge 2 -> 1
158+
2,
159+
0, # edge 2 -> 0
160+
],
161+
],
162+
dtype=torch.int,
163+
),
164+
pos=torch.tensor([[0], [1], [2]], dtype=torch.float),
165+
y=1,
166+
),
167+
device=pulser.AnalogDevice,
168+
)
169+
assert graph_connected_close.is_unit_disk_graph()
170+
171+
52172
def test_is_disk_graph_false() -> None:
53173
"""
54174
Testing is_disk_graph: these graphs are *not* disk graphs
@@ -269,3 +389,15 @@ def test_basic_compile() -> None:
269389
assert isinstance(sample, Data)
270390
ingested = molecule_graph_compiler.ingest(graph=sample, device=pulser.AnalogDevice, id=9)
271391
assert ingested.id == 9
392+
393+
394+
def test_NXGraphCompiler_ingest_from_networkx() -> None:
395+
"""Testing the NXGraphCompiler ingest starting from a networkx graph"""
396+
397+
nx_graph = nx.Graph()
398+
nx_graph.add_edges_from([(1, 2), (2, 3)])
399+
nx.set_node_attributes(nx_graph, {1: [0, 0], 2: [5.02, 0], 3: [11.02, 0]}, "pos")
400+
nx_with_pos = NXWithPos(nx_graph, positions=nx.get_node_attributes(nx_graph, "pos"), target=0)
401+
compiler = NXGraphCompiler()
402+
graph = compiler.ingest(nx_with_pos, device=pulser.AnalogDevice, id=0)
403+
assert graph.is_unit_disk_graph()

0 commit comments

Comments
 (0)