Skip to content

Commit 4400491

Browse files
authored
[Tweak] Exposing the excitation histogram (#25)
1 parent c1fe208 commit 4400491

7 files changed

Lines changed: 335 additions & 145 deletions

File tree

examples/tutorial.ipynb

Lines changed: 87 additions & 47 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,14 @@ classifiers = [
2828
# always specify a version for each package
2929
# to maintain consistency
3030
dependencies = [
31-
"matplotlib",
3231
"networkx",
3332
"numpy",
3433
"pulser==1.1.1",
3534
"rdkit",
3635
"scikit-learn",
3736
"torch",
3837
"torch_geometric",
39-
"tqdm",
38+
"matplotlib",
4039
]
4140

4241
[tool.hatch.metadata]
@@ -46,6 +45,7 @@ allow-ambiguous-features = true
4645
[project.optional-dependencies]
4746
extras = [
4847
"jupyter",
48+
"tqdm",
4949
]
5050

5151
[project.urls]

qek/data/dataset.py

Lines changed: 133 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from __future__ import annotations
22

3+
import collections
34
import json
4-
from dataclasses import dataclass
5+
import matplotlib
56

67
import numpy as np
78
import pulser as pl
89

910

10-
@dataclass
1111
class ProcessedData:
1212
"""
1313
Data on a single graph obtained from the Quantum Device.
@@ -37,10 +37,14 @@ class ProcessedData:
3737

3838
sequence: pl.Sequence
3939
state_dict: dict[str, int]
40+
_dist_excitation: np.ndarray
4041
target: int
4142

42-
def __post_init__(self) -> None:
43-
self.state_dict = _convert_np_int64_to_int(data=self.state_dict)
43+
def __init__(self, sequence: pl.Sequence, state_dict: dict[str, np.int64], target: int):
    """
    Build the processed data for a single graph.

    Arguments:
        sequence: Stored unchanged on the instance.
        state_dict: Dictionary of bitstrings to counts. Numpy integer
            values are converted to built-in `int` before being stored.
        target: Stored unchanged on the instance.
    """
    self.sequence = sequence
    # The excitation histogram is computed once, from the normalized counts.
    counts = _convert_np_int64_to_int(data=state_dict)
    self.state_dict = counts
    self._dist_excitation = dist_excitation(counts)
    self.target = target
4448

4549
def save_to_file(self, file_path: str) -> None:
4650
with open(file_path, "w") as file:
@@ -61,6 +65,20 @@ def load_from_file(cls, file_path: str) -> "ProcessedData":
6165
target=tmp_data["target"],
6266
)
6367

68+
def dist_excitation(self, size: int | None = None) -> np.ndarray:
    """
    Give the excitation histogram stored for this graph.

    Arguments:
        size: Optional length of the returned array. A shorter length
            truncates the histogram, a longer one pads it with zeros.
            When omitted, the histogram is returned at its own length.

    Returns:
        A new array; the array stored on the instance is not handed out.
    """
    histogram = self._dist_excitation
    available = len(histogram)
    if size is not None and size != available:
        if size > available:
            # Default padding mode fills the tail with zeros.
            return np.pad(histogram, (0, size - available))
        return np.resize(histogram, size)
    return histogram.copy()
81+
6482
def draw_sequence(self) -> None:
6583
"""
6684
Draw the sequence on screen
@@ -73,8 +91,119 @@ def draw_register(self) -> None:
7391
"""
7492
self.sequence.register.draw(blockade_radius=self.sequence.device.min_atom_distance + 0.01)
7593

94+
def draw_excitation(self) -> None:
    """
    Draw a histogram of the excitation levels.

    One bar is plotted per entry of the excitation distribution stored on
    this instance; each bar is labelled with its index in that
    distribution.
    """
    # `import matplotlib` alone does not load the `pyplot` submodule, so
    # import it explicitly instead of relying on some other package having
    # imported it beforehand.
    import matplotlib.pyplot as plt

    labels = [str(level) for level in range(len(self._dist_excitation))]
    plt.bar(labels, self._dist_excitation)
100+
101+
102+
def dist_excitation(state_dict: dict[str, int], size: int | None = None) -> np.ndarray:
    """
    Calculates the distribution of excitation energies from a dictionary of
    bitstrings to their respective counts.

    Args:
        state_dict (dict[str, int]): Maps each bitstring to the number of
            times it was sampled.
        size (int | None): If specified, the highest excitation level kept
            in the output; samples with more excitations are left out of
            the histogram (but still count towards the normalization).
            Otherwise, use the length of the bitstrings, which keeps all
            values.

    Returns:
        A histogram of excitation energies, of length `size + 1` (or of
        length 0 if `state_dict` is empty).
        - index: an excitation level (i.e. a number of `1` bits in a
          bitstring)
        - value: normalized count of samples with this excitation level.
    """
    if not state_dict:
        return np.zeros(0, dtype=float)

    if size is None:
        # If size is not specified, it's the length of bitstrings.
        # We assume that all bitstrings in `state_dict` have the same
        # length; we have just checked that it is not empty, so picking
        # the first one is safe.
        size = len(next(iter(state_dict)))

    result = np.zeros(size + 1, dtype=float)
    total = 0.0
    for bitstring, number in state_dict.items():
        total += number
        occupation = bitstring.count("1")
        # Levels run from 0 to `size` inclusive, which is why the array
        # has `size + 1` slots: the fully excited level must be kept.
        if occupation <= size:
            result[occupation] += number

    if total == 0:
        # No sample at all: return the all-zero histogram rather than
        # dividing by zero.
        return result
    return result / total
147+
76148

77149
def _convert_np_int64_to_int(data: dict[str, np.int64]) -> dict[str, int]:
    """
    Utility function: build a new dict in which every numpy integer value
    of `data` has been turned into a built-in `int`, for serialization
    purposes.

    Values that are not numpy integers are carried over unchanged.
    """
    converted: dict[str, int] = {}
    for key, value in data.items():
        if isinstance(value, np.integer):
            value = int(value)
        converted[key] = value
    return converted
157+
158+
159+
def save_dataset(dataset: list[ProcessedData], file_path: str) -> None:
    """Saves a dataset to a JSON file.

    Args:
        dataset (list[ProcessedData]): The dataset to be saved, containing
            ProcessedData instances.
        file_path (str): The path where the dataset will be saved as a JSON
            file.

    Note:
        The data is stored in a format suitable for loading with load_dataset.

        The whole dataset is serialized in memory before the file is
        opened, so a failure while serializing does not leave an empty or
        truncated file behind.

    Returns:
        None
    """
    data = [
        {
            "sequence": instance.sequence.to_abstract_repr(),
            "state_dict": instance.state_dict,
            "target": instance.target,
        }
        for instance in dataset
    ]
    # Encode first: opening with "w" truncates any existing file, so only
    # touch it once we know we have something valid to write.
    payload = json.dumps(data)
    with open(file_path, "w") as file:
        file.write(payload)
184+
185+
186+
def load_dataset(file_path: str) -> list[ProcessedData]:
    """Loads a dataset from a JSON file.

    Args:
        file_path (str): The path to the JSON file containing the dataset.

    Note:
        The file is expected to hold a list of entries, each with the keys
        "sequence", "state_dict" and "target", i.e. the format written by
        save_dataset.

    Returns:
        A list of ProcessedData instances, one per entry of the JSON file,
        in the same order.
    """
    with open(file_path) as file:
        records = json.load(file)

    dataset: list[ProcessedData] = []
    for record in records:
        # The sequence is stored in its abstract representation; rebuild it.
        sequence = pl.Sequence.from_abstract_repr(record["sequence"])
        dataset.append(
            ProcessedData(
                sequence=sequence,
                state_dict=record["state_dict"],
                target=record["target"],
            )
        )
    return dataset

qek/data/datatools.py

Lines changed: 8 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import json
43
from typing import Final
54

65
import networkx as nx
@@ -13,7 +12,6 @@
1312
import torch_geometric.utils as pyg_utils
1413
from rdkit.Chem import AllChem
1514

16-
from qek.data.dataset import ProcessedData
1715
from qek.utils import graph_to_mol
1816

1917

@@ -44,59 +42,6 @@ def split_train_test(
4442
return train, val
4543

4644

47-
def save_dataset(dataset: list[ProcessedData], file_path: str) -> None:
48-
"""Saves a dataset to a JSON file.
49-
50-
Args:
51-
dataset (list[ProcessedData]): The dataset to be saved, containing
52-
RegisterData instances.
53-
file_path (str): The path where the dataset will be saved as a JSON
54-
file.
55-
56-
Note:
57-
The data is stored in a format suitable for loading with load_dataset.
58-
59-
Returns:
60-
None
61-
"""
62-
with open(file_path, "w") as file:
63-
data = [
64-
{
65-
"sequence": instance.sequence.to_abstract_repr(),
66-
"state_dict": instance.state_dict,
67-
"target": instance.target,
68-
}
69-
for instance in dataset
70-
]
71-
json.dump(data, file)
72-
73-
74-
def load_dataset(file_path: str) -> list[ProcessedData]:
75-
"""Loads a dataset from a JSON file.
76-
77-
Args:
78-
file_path (str): The path to the JSON file containing the dataset.
79-
80-
Note:
81-
The data is loaded in the format that was used when saving with
82-
save_dataset.
83-
84-
Returns:
85-
A list of ProcessedData instances, corresponding to the data stored in
86-
the JSON file.
87-
"""
88-
with open(file_path) as file:
89-
data = json.load(file)
90-
return [
91-
ProcessedData(
92-
sequence=pl.Sequence.from_abstract_repr(item["sequence"]),
93-
state_dict=item["state_dict"],
94-
target=item["target"],
95-
)
96-
for item in data
97-
]
98-
99-
10045
EPSILON_DISTANCE_UM = 0.01
10146

10247

@@ -157,8 +102,10 @@ def is_disk_graph(self, radius: float) -> bool:
157102
if len(self.nx_graph) == 0 or not nx.is_connected(self.nx_graph):
158103
return False
159104

160-
# Check the distances between all pairs of nodes.
161105
pos = self.pyg.pos
106+
assert pos is not None
107+
108+
# Check the distances between all pairs of nodes.
162109
for u, v in nx.non_edges(self.nx_graph):
163110
distance_um = np.linalg.norm(np.array(pos[u]) - np.array(pos[v]))
164111
if distance_um <= radius:
@@ -204,7 +151,9 @@ def is_embeddable(self) -> bool:
204151
return False
205152

206153
# Check the distance from the center
154+
207155
pos = self.pyg.pos
156+
assert pos is not None
208157
distance_from_center = np.linalg.norm(pos, ord=2, axis=-1)
209158
if any(distance_from_center > self.device.max_radial_distance):
210159
return False
@@ -228,7 +177,9 @@ def compute_register(self) -> pl.Register:
228177
Returns:
229178
pulser.Register: register
230179
"""
231-
return pl.Register.from_coordinates(coords=self.pyg.pos)
180+
pos = self.pyg.pos
181+
assert pos is not None
182+
return pl.Register.from_coordinates(coords=pos)
232183

233184
def compute_sequence(self) -> pl.Sequence:
234185
"""

0 commit comments

Comments
 (0)