|
1 | | -from __future__ import annotations |
2 | | - |
3 | | -from typing import Final |
4 | | - |
5 | | -import networkx as nx |
6 | | -import numpy as np |
7 | | -import pulser as pl |
8 | | -import rdkit.Chem as Chem |
9 | 1 | import torch |
10 | 2 | import torch.utils.data as torch_data |
11 | | -import torch_geometric.data as pyg_data |
12 | | -import torch_geometric.utils as pyg_utils |
13 | | -from rdkit.Chem import AllChem |
14 | | - |
15 | | -from qek.utils import graph_to_mol |
16 | 3 |
|
17 | 4 |
|
18 | 5 | def split_train_test( |
@@ -40,258 +27,3 @@ def split_train_test( |
40 | 27 | generator = torch.Generator() |
41 | 28 | train, val = torch_data.random_split(dataset=dataset, lengths=lengths, generator=generator) |
42 | 29 | return train, val |
43 | | - |
44 | | - |
45 | | -EPSILON_DISTANCE_UM = 0.01 |
46 | | - |
47 | | - |
48 | | -class BaseGraph: |
49 | | - """ |
50 | | - A graph being prepared for embedding on a quantum device. |
51 | | - """ |
52 | | - |
53 | | - device: Final[pl.devices.Device] |
54 | | - |
55 | | - def __init__(self, data: pyg_data.Data, device: pl.devices.Device): |
56 | | - """ |
57 | | - Create a graph from geometric data. |
58 | | -
|
59 | | - Args: |
60 | | - data: A homogeneous graph, in PyTorch Geometric format. Unchanged. |
61 | | - It MUST have attributes 'pos' |
62 | | - device: The device for which the graph is prepared. |
63 | | - """ |
64 | | - if not hasattr(data, "pos"): |
65 | | - raise AttributeError("The graph should have an attribute 'pos'.") |
66 | | - |
67 | | - # The device for which the graph is prepared. |
68 | | - self.device = device |
69 | | - |
70 | | - # The graph in torch geometric format. |
71 | | - self.pyg = data.clone() |
72 | | - |
73 | | - # The graph in networkx format, undirected. |
74 | | - self.nx_graph = pyg_utils.to_networkx( |
75 | | - data=data, |
76 | | - node_attrs=["x"], |
77 | | - edge_attrs=["edge_attr"] if data.edge_attr is not None else None, |
78 | | - to_undirected=True, |
79 | | - ) |
80 | | - |
81 | | - def is_disk_graph(self, radius: float) -> bool: |
82 | | - """ |
83 | | - A predicate to check if `self` is a disk graph with the specified |
84 | | - radius, i.e. `self` is a connected graph and, for every pair of nodes |
85 | | - `A` and `B` within `graph`, there exists an edge between `A` and `B` |
86 | | - if and only if the positions of `A` and `B` within `self` are such |
87 | | - that `|AB| <= radius`. |
88 | | -
|
89 | | - Args: |
90 | | - radius: The maximal distance between two nodes of `self` |
91 | | - connected be an edge. |
92 | | -
|
93 | | - Returns: |
94 | | - `True` if the graph is a disk graph with the specified radius, |
95 | | - `False` otherwise. |
96 | | - """ |
97 | | - |
98 | | - if self.pyg.num_nodes == 0 or self.pyg.num_nodes is None: |
99 | | - return False |
100 | | - |
101 | | - # Check if the graph is connected. |
102 | | - if len(self.nx_graph) == 0 or not nx.is_connected(self.nx_graph): |
103 | | - return False |
104 | | - |
105 | | - pos = self.pyg.pos |
106 | | - assert pos is not None |
107 | | - |
108 | | - # Check the distances between all pairs of nodes. |
109 | | - for u, v in nx.non_edges(self.nx_graph): |
110 | | - distance_um = np.linalg.norm(np.array(pos[u]) - np.array(pos[v])) |
111 | | - if distance_um <= radius: |
112 | | - # These disjointed nodes would interact with each other, so |
113 | | - # this is not an embeddable graph. |
114 | | - return False |
115 | | - |
116 | | - for u, v in self.nx_graph.edges(): |
117 | | - distance_um = np.linalg.norm(np.array(pos[u]) - np.array(pos[v])) |
118 | | - if distance_um > radius: |
119 | | - # These joined nodes would not interact with each other, so |
120 | | - # this is not an embeddable graph. |
121 | | - return False |
122 | | - |
123 | | - return True |
124 | | - |
125 | | - def is_embeddable(self) -> bool: |
126 | | - """ |
127 | | - A predicate to check if the graph can be embedded in the |
128 | | - quantum device. |
129 | | -
|
130 | | - For a graph to be embeddable on a device, all the following |
131 | | - criteria must be fulfilled: |
132 | | - - the graph must be non-empty; |
133 | | - - the device must have at least as many atoms as the graph has |
134 | | - nodes; |
135 | | - - the device must be physically large enough to place all the |
136 | | - nodes (device.max_radial_distance); |
137 | | - - the nodes must be distant enough that quantum interactions |
138 | | - may take place (device.min_atom_distance) |
139 | | -
|
140 | | - Returns: |
141 | | - bool: True if possible, False if not |
142 | | - """ |
143 | | - |
144 | | - # Reject empty graphs. |
145 | | - if self.pyg.num_nodes == 0 or self.pyg.num_nodes is None: |
146 | | - return False |
147 | | - |
148 | | - # Reject graphs that have more nodes than can be represented |
149 | | - # on the device. |
150 | | - if self.pyg.num_nodes > self.device.max_atom_num: |
151 | | - return False |
152 | | - |
153 | | - # Check the distance from the center |
154 | | - |
155 | | - pos = self.pyg.pos |
156 | | - assert pos is not None |
157 | | - distance_from_center = np.linalg.norm(pos, ord=2, axis=-1) |
158 | | - if any(distance_from_center > self.device.max_radial_distance): |
159 | | - return False |
160 | | - |
161 | | - # Check distance between nodes |
162 | | - if not self.is_disk_graph(self.device.min_atom_distance + EPSILON_DISTANCE_UM): |
163 | | - return False |
164 | | - |
165 | | - for u, v in self.nx_graph.edges(): |
166 | | - distance_um = np.linalg.norm(np.array(pos[u]) - np.array(pos[v])) |
167 | | - if distance_um < self.device.min_atom_distance: |
168 | | - # These nodes are too close to each other, preventing quantum |
169 | | - # interactions on the device. |
170 | | - return False |
171 | | - |
172 | | - return True |
173 | | - |
174 | | - def compute_register(self) -> pl.Register: |
175 | | - """Create a Quantum Register based on a graph. |
176 | | -
|
177 | | - Returns: |
178 | | - pulser.Register: register |
179 | | - """ |
180 | | - pos = self.pyg.pos |
181 | | - assert pos is not None |
182 | | - return pl.Register.from_coordinates(coords=pos) |
183 | | - |
184 | | - def compute_sequence(self) -> pl.Sequence: |
185 | | - """ |
186 | | - Compile a Quantum Sequence from a graph for a specific device. |
187 | | -
|
188 | | - Raises: |
189 | | - ValueError if the graph cannot be embedded on the given device. |
190 | | - """ |
191 | | - if not self.is_embeddable(): |
192 | | - raise ValueError(f"The graph is not compatible with {self.device}") |
193 | | - reg = self.compute_register() |
194 | | - if self.device.requires_layout: |
195 | | - reg = reg.with_automatic_layout(device=self.device) |
196 | | - |
197 | | - seq = pl.Sequence(register=reg, device=self.device) |
198 | | - |
199 | | - # See the companion paper for an explanation on these constants. |
200 | | - Omega_max = 1.0 * 2 * np.pi |
201 | | - t_max = 660 |
202 | | - pulse = pl.Pulse.ConstantAmplitude( |
203 | | - amplitude=Omega_max, |
204 | | - detuning=pl.waveforms.RampWaveform(t_max, 0, 0), |
205 | | - phase=0.0, |
206 | | - ) |
207 | | - seq.declare_channel("ising", "rydberg_global") |
208 | | - seq.add(pulse, "ising") |
209 | | - return seq |
210 | | - |
211 | | - |
212 | | -class MoleculeGraph(BaseGraph): |
213 | | - """ |
214 | | - A graph based on molecular data, being prepared for embedding on a |
215 | | - quantum device. |
216 | | - """ |
217 | | - |
218 | | - # Constants used to decode the PTC-FM dataset, mapping |
219 | | - # integers (used as node attributes) to atom names. |
220 | | - PTCFM_ATOM_NAMES: Final[dict[int, str]] = { |
221 | | - 0: "In", |
222 | | - 1: "P", |
223 | | - 2: "C", |
224 | | - 3: "O", |
225 | | - 4: "N", |
226 | | - 5: "Cl", |
227 | | - 6: "S", |
228 | | - 7: "Br", |
229 | | - 8: "Na", |
230 | | - 9: "F", |
231 | | - 10: "As", |
232 | | - 11: "K", |
233 | | - 12: "Cu", |
234 | | - 13: "I", |
235 | | - 14: "Ba", |
236 | | - 15: "Sn", |
237 | | - 16: "Pb", |
238 | | - 17: "Ca", |
239 | | - } |
240 | | - |
241 | | - # Constants used to decode the PTC-FM dataset, mapping |
242 | | - # integers (used as edge attributes) to bond types. |
243 | | - PTCFM_BOND_TYPES: Final[dict[int, Chem.BondType]] = { |
244 | | - 0: Chem.BondType.TRIPLE, |
245 | | - 1: Chem.BondType.SINGLE, |
246 | | - 2: Chem.BondType.DOUBLE, |
247 | | - 3: Chem.BondType.AROMATIC, |
248 | | - } |
249 | | - |
250 | | - def __init__( |
251 | | - self, |
252 | | - data: pyg_data.Data, |
253 | | - device: pl.devices.Device, |
254 | | - node_mapping: dict[int, str] = PTCFM_ATOM_NAMES, |
255 | | - edge_mapping: dict[int, Chem.BondType] = PTCFM_BOND_TYPES, |
256 | | - ): |
257 | | - """ |
258 | | - Compute the geometry for a molecule graph. |
259 | | -
|
260 | | - Args: |
261 | | - data: A homogeneous graph, in PyTorch Geometric format. Unchanged. |
262 | | - blockade_radius: The radius of the Rydberg Blockade. Two |
263 | | - connected nodes should be at a distance < blockade_radius, |
264 | | - while two disconnected nodes should be at a |
265 | | - distance > blockade_radius. |
266 | | - node_mapping: A mapping of node labels from numbers to strings, |
267 | | - e.g. `5 => "Cl"`. Used when building molecules, e.g. to compute |
268 | | - distances between nodes. |
269 | | - edge_mapping: A mapping of edge labels from number to chemical |
270 | | - bond types, e.g. `2 => DOUBLE`. Used when building molecules, |
271 | | - e.g. to compute distances between nodes. |
272 | | - """ |
273 | | - pyg = data.clone() |
274 | | - pyg.pos = None # Placeholder |
275 | | - super().__init__(pyg, device) |
276 | | - |
277 | | - # Reconstruct the molecule. |
278 | | - tmp_mol = graph_to_mol( |
279 | | - graph=self.nx_graph, |
280 | | - node_mapping=node_mapping, |
281 | | - edge_mapping=edge_mapping, |
282 | | - ) |
283 | | - |
284 | | - # Extract the geometry. |
285 | | - AllChem.Compute2DCoords(tmp_mol, useRingTemplates=True) |
286 | | - pos = tmp_mol.GetConformer().GetPositions()[..., :2] # Convert to 2D |
287 | | - |
288 | | - # Scale the geometry so that the longest edge is as long as |
289 | | - # `device.min_atom_distance`. |
290 | | - dist_list = [] |
291 | | - for start, end in self.nx_graph.edges(): |
292 | | - dist_list.append(np.linalg.norm(pos[start] - pos[end])) |
293 | | - norm_factor = np.max(dist_list) |
294 | | - pos = pos * device.min_atom_distance / norm_factor |
295 | | - |
296 | | - # Finally, store the position. |
297 | | - self.pyg.pos = pos |
0 commit comments