Skip to content

Commit 405b50b

Browse files
authored
[Refactor] Expose a higher-level API (#30)
We expect that our users will be more interested in machine learning than in quantum computers, at least at first, so we expose a new, higher-level API that hides most of the quantum details. With this API, switching between the QutipEmulator, emu-mps, or a QPU takes just a few lines of code (well, one line of code plus the connection details: username, password, project id). This results in a tutorial that spends less time on the quantum aspects and more on the machine learning. This commit also adds more tests.
1 parent 4400491 commit 405b50b

9 files changed

Lines changed: 1258 additions & 496 deletions

File tree

examples/tutorial.ipynb

Lines changed: 186 additions & 209 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ dependencies = [
3636
"torch",
3737
"torch_geometric",
3838
"matplotlib",
39+
"emu-mps",
3940
]
4041

4142
[tool.hatch.metadata]
@@ -59,6 +60,7 @@ dependencies = [
5960
"pytest",
6061
"pytest-cov",
6162
"pytest-xdist",
63+
"pytest-asyncio",
6264
"nbconvert",
6365
"ipykernel",
6466
"pre-commit",
@@ -84,6 +86,8 @@ filterwarnings = [
8486
"ignore:Call to deprecated create function OneofDescriptor",
8587
"ignore:distutils Version classes are deprecated.",
8688
]
89+
asyncio_mode="auto"
90+
asyncio_default_fixture_loop_scope="function"
8791

8892
[tool.hatch.envs.docs]
8993
dependencies = [

qek/data/dataset.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1-
from __future__ import annotations
2-
31
import collections
42
import json
3+
from typing import cast
54
import matplotlib
65

6+
import logging
77
import numpy as np
88
import pulser as pl
99

10+
from qek.data.graphs import EPSILON_RADIUS_UM
11+
12+
logger = logging.getLogger(__name__)
13+
1014

1115
class ProcessedData:
1216
"""
@@ -17,7 +21,8 @@ class ProcessedData:
1721
executed on the device.
1822
state_dict: A dictionary {bitstring: number of instances}
1923
for this graph.
20-
target: The target, i.e. the identifier of the graph.
24+
target: The machine-learning target (in this case, a value
25+
in {0, 1}, as specified by the original graph).
2126
2227
The state dictionary represents an approximation of the quantum
2328
state of the device for this graph after completion of the
@@ -89,7 +94,12 @@ def draw_register(self) -> None:
8994
"""
9095
Draw the register on screen
9196
"""
92-
self.sequence.register.draw(blockade_radius=self.sequence.device.min_atom_distance + 0.01)
97+
register = cast(pl.Register, self.sequence.register)
98+
register.draw(
99+
# We slightly increase the blockade radius to account for rounding errors.
100+
blockade_radius=self.sequence.device.min_atom_distance
101+
+ EPSILON_RADIUS_UM
102+
)
93103

94104
def draw_excitation(self) -> None:
95105
"""

qek/data/datatools.py

Lines changed: 0 additions & 268 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,5 @@
1-
from __future__ import annotations
2-
3-
from typing import Final
4-
5-
import networkx as nx
6-
import numpy as np
7-
import pulser as pl
8-
import rdkit.Chem as Chem
91
import torch
102
import torch.utils.data as torch_data
11-
import torch_geometric.data as pyg_data
12-
import torch_geometric.utils as pyg_utils
13-
from rdkit.Chem import AllChem
14-
15-
from qek.utils import graph_to_mol
163

174

185
def split_train_test(
@@ -40,258 +27,3 @@ def split_train_test(
4027
generator = torch.Generator()
4128
train, val = torch_data.random_split(dataset=dataset, lengths=lengths, generator=generator)
4229
return train, val
43-
44-
45-
EPSILON_DISTANCE_UM = 0.01
46-
47-
48-
class BaseGraph:
49-
"""
50-
A graph being prepared for embedding on a quantum device.
51-
"""
52-
53-
device: Final[pl.devices.Device]
54-
55-
def __init__(self, data: pyg_data.Data, device: pl.devices.Device):
56-
"""
57-
Create a graph from geometric data.
58-
59-
Args:
60-
data: A homogeneous graph, in PyTorch Geometric format. Unchanged.
61-
It MUST have attributes 'pos'
62-
device: The device for which the graph is prepared.
63-
"""
64-
if not hasattr(data, "pos"):
65-
raise AttributeError("The graph should have an attribute 'pos'.")
66-
67-
# The device for which the graph is prepared.
68-
self.device = device
69-
70-
# The graph in torch geometric format.
71-
self.pyg = data.clone()
72-
73-
# The graph in networkx format, undirected.
74-
self.nx_graph = pyg_utils.to_networkx(
75-
data=data,
76-
node_attrs=["x"],
77-
edge_attrs=["edge_attr"] if data.edge_attr is not None else None,
78-
to_undirected=True,
79-
)
80-
81-
def is_disk_graph(self, radius: float) -> bool:
82-
"""
83-
A predicate to check if `self` is a disk graph with the specified
84-
radius, i.e. `self` is a connected graph and, for every pair of nodes
85-
`A` and `B` within `self`, there exists an edge between `A` and `B`
86-
if and only if the positions of `A` and `B` within `self` are such
87-
that `|AB| <= radius`.
88-
89-
Args:
90-
radius: The maximal distance between two nodes of `self`
91-
connected by an edge.
92-
93-
Returns:
94-
`True` if the graph is a disk graph with the specified radius,
95-
`False` otherwise.
96-
"""
97-
98-
if self.pyg.num_nodes == 0 or self.pyg.num_nodes is None:
99-
return False
100-
101-
# Check if the graph is connected.
102-
if len(self.nx_graph) == 0 or not nx.is_connected(self.nx_graph):
103-
return False
104-
105-
pos = self.pyg.pos
106-
assert pos is not None
107-
108-
# Check the distances between all pairs of nodes.
109-
for u, v in nx.non_edges(self.nx_graph):
110-
distance_um = np.linalg.norm(np.array(pos[u]) - np.array(pos[v]))
111-
if distance_um <= radius:
112-
# These disjointed nodes would interact with each other, so
113-
# this is not an embeddable graph.
114-
return False
115-
116-
for u, v in self.nx_graph.edges():
117-
distance_um = np.linalg.norm(np.array(pos[u]) - np.array(pos[v]))
118-
if distance_um > radius:
119-
# These joined nodes would not interact with each other, so
120-
# this is not an embeddable graph.
121-
return False
122-
123-
return True
124-
125-
def is_embeddable(self) -> bool:
126-
"""
127-
A predicate to check if the graph can be embedded in the
128-
quantum device.
129-
130-
For a graph to be embeddable on a device, all the following
131-
criteria must be fulfilled:
132-
- the graph must be non-empty;
133-
- the device must have at least as many atoms as the graph has
134-
nodes;
135-
- the device must be physically large enough to place all the
136-
nodes (device.max_radial_distance);
137-
- the nodes must be distant enough that quantum interactions
138-
may take place (device.min_atom_distance)
139-
140-
Returns:
141-
bool: True if possible, False if not
142-
"""
143-
144-
# Reject empty graphs.
145-
if self.pyg.num_nodes == 0 or self.pyg.num_nodes is None:
146-
return False
147-
148-
# Reject graphs that have more nodes than can be represented
149-
# on the device.
150-
if self.pyg.num_nodes > self.device.max_atom_num:
151-
return False
152-
153-
# Check the distance from the center
154-
155-
pos = self.pyg.pos
156-
assert pos is not None
157-
distance_from_center = np.linalg.norm(pos, ord=2, axis=-1)
158-
if any(distance_from_center > self.device.max_radial_distance):
159-
return False
160-
161-
# Check distance between nodes
162-
if not self.is_disk_graph(self.device.min_atom_distance + EPSILON_DISTANCE_UM):
163-
return False
164-
165-
for u, v in self.nx_graph.edges():
166-
distance_um = np.linalg.norm(np.array(pos[u]) - np.array(pos[v]))
167-
if distance_um < self.device.min_atom_distance:
168-
# These nodes are too close to each other, preventing quantum
169-
# interactions on the device.
170-
return False
171-
172-
return True
173-
174-
def compute_register(self) -> pl.Register:
175-
"""Create a Quantum Register based on a graph.
176-
177-
Returns:
178-
pulser.Register: register
179-
"""
180-
pos = self.pyg.pos
181-
assert pos is not None
182-
return pl.Register.from_coordinates(coords=pos)
183-
184-
def compute_sequence(self) -> pl.Sequence:
185-
"""
186-
Compile a Quantum Sequence from a graph for a specific device.
187-
188-
Raises:
189-
ValueError if the graph cannot be embedded on the given device.
190-
"""
191-
if not self.is_embeddable():
192-
raise ValueError(f"The graph is not compatible with {self.device}")
193-
reg = self.compute_register()
194-
if self.device.requires_layout:
195-
reg = reg.with_automatic_layout(device=self.device)
196-
197-
seq = pl.Sequence(register=reg, device=self.device)
198-
199-
# See the companion paper for an explanation on these constants.
200-
Omega_max = 1.0 * 2 * np.pi
201-
t_max = 660
202-
pulse = pl.Pulse.ConstantAmplitude(
203-
amplitude=Omega_max,
204-
detuning=pl.waveforms.RampWaveform(t_max, 0, 0),
205-
phase=0.0,
206-
)
207-
seq.declare_channel("ising", "rydberg_global")
208-
seq.add(pulse, "ising")
209-
return seq
210-
211-
212-
class MoleculeGraph(BaseGraph):
213-
"""
214-
A graph based on molecular data, being prepared for embedding on a
215-
quantum device.
216-
"""
217-
218-
# Constants used to decode the PTC-FM dataset, mapping
219-
# integers (used as node attributes) to atom names.
220-
PTCFM_ATOM_NAMES: Final[dict[int, str]] = {
221-
0: "In",
222-
1: "P",
223-
2: "C",
224-
3: "O",
225-
4: "N",
226-
5: "Cl",
227-
6: "S",
228-
7: "Br",
229-
8: "Na",
230-
9: "F",
231-
10: "As",
232-
11: "K",
233-
12: "Cu",
234-
13: "I",
235-
14: "Ba",
236-
15: "Sn",
237-
16: "Pb",
238-
17: "Ca",
239-
}
240-
241-
# Constants used to decode the PTC-FM dataset, mapping
242-
# integers (used as edge attributes) to bond types.
243-
PTCFM_BOND_TYPES: Final[dict[int, Chem.BondType]] = {
244-
0: Chem.BondType.TRIPLE,
245-
1: Chem.BondType.SINGLE,
246-
2: Chem.BondType.DOUBLE,
247-
3: Chem.BondType.AROMATIC,
248-
}
249-
250-
def __init__(
251-
self,
252-
data: pyg_data.Data,
253-
device: pl.devices.Device,
254-
node_mapping: dict[int, str] = PTCFM_ATOM_NAMES,
255-
edge_mapping: dict[int, Chem.BondType] = PTCFM_BOND_TYPES,
256-
):
257-
"""
258-
Compute the geometry for a molecule graph.
259-
260-
Args:
261-
data: A homogeneous graph, in PyTorch Geometric format. Unchanged.
262-
blockade_radius: The radius of the Rydberg Blockade. Two
263-
connected nodes should be at a distance < blockade_radius,
264-
while two disconnected nodes should be at a
265-
distance > blockade_radius.
266-
node_mapping: A mapping of node labels from numbers to strings,
267-
e.g. `5 => "Cl"`. Used when building molecules, e.g. to compute
268-
distances between nodes.
269-
edge_mapping: A mapping of edge labels from number to chemical
270-
bond types, e.g. `2 => DOUBLE`. Used when building molecules,
271-
e.g. to compute distances between nodes.
272-
"""
273-
pyg = data.clone()
274-
pyg.pos = None # Placeholder
275-
super().__init__(pyg, device)
276-
277-
# Reconstruct the molecule.
278-
tmp_mol = graph_to_mol(
279-
graph=self.nx_graph,
280-
node_mapping=node_mapping,
281-
edge_mapping=edge_mapping,
282-
)
283-
284-
# Extract the geometry.
285-
AllChem.Compute2DCoords(tmp_mol, useRingTemplates=True)
286-
pos = tmp_mol.GetConformer().GetPositions()[..., :2] # Convert to 2D
287-
288-
# Scale the geometry so that the longest edge is as long as
289-
# `device.min_atom_distance`.
290-
dist_list = []
291-
for start, end in self.nx_graph.edges():
292-
dist_list.append(np.linalg.norm(pos[start] - pos[end]))
293-
norm_factor = np.max(dist_list)
294-
pos = pos * device.min_atom_distance / norm_factor
295-
296-
# Finally, store the position.
297-
self.pyg.pos = pos

0 commit comments

Comments
 (0)