Skip to content

Commit 9455daa

Browse files
authored
[API] Rework the return of a compiler to avoid async and make target optional. (#36)
We don't want our users to need to understand `async`/`await` until they need to, so we rework the output of extractors to make them usable without `async`/`await`. Only users who write code for servers or interactive applications will need to `await`. Also, we make `target` optional, because we want to be able to use our kernel with data that does not come from the training material! Note: breaking API changes, we'll need to bump version.
1 parent bbb881b commit 9455daa

7 files changed

Lines changed: 634 additions & 238 deletions

File tree

examples/tutorial.ipynb

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
"source": [
5555
"To extract machine-learning features from our dataset, we will need to configure a feature extractor. This library provides several feature extractors to either make use of a physical quantum device (QPU), or a variety of emulators.\n",
5656
"\n",
57-
"To configure a feature extractor, we will need to give it a _compiler_, whose task is to take a list of graphs, extract embeddings and compile these embeddings to _sequences of pulses_, the format that can be executed by either a QPU or an emulator. For this tutorial, our dataset is composed of molecule graphs, so we will use the `MoleculeGraphCompiler`:"
57+
"To configure a feature extractor, we will need to give it a _compiler_, whose task is to take a list of graphs, extract embeddings and compile these embeddings to _sequences of pulses_, the format that can be executed by either a QPU or an emulator. For this tutorial, our dataset is composed of molecule graphs encoded with the PTC-FM conventions, so we will use the `PTCFMGraphCompiler`:"
5858
]
5959
},
6060
{
@@ -65,14 +65,14 @@
6565
"source": [
6666
"import qek.data.graphs as qek_graphs\n",
6767
"\n",
68-
"compiler = qek_graphs.MoleculeGraphCompiler()"
68+
"compiler = qek_graphs.PTCFMCompiler()"
6969
]
7070
},
7171
{
7272
"cell_type": "markdown",
7373
"metadata": {},
7474
"source": [
75-
"This library provides other compilers from other formats of graphs."
75+
"This library provides other compilers from other formats of graphs, including the `MoleculeGraphCompiler` and general-purpose graph compilers for pytorch_geometric or networkx graphs."
7676
]
7777
},
7878
{
@@ -126,8 +126,8 @@
126126
"# You can increase this value to higher number of qubits, but this\n",
127127
"# notebook will take longer to execute and may run out of memory.\n",
128128
"max_qubits = 5\n",
129-
"processed_dataset = await extractor.run(max_qubits=max_qubits) # Don't forget to `await`!\n",
130-
"display(\"Extracted features from %s samples\"% (len(processed_dataset), ))"
129+
"processed_dataset = extractor.run(max_qubits=max_qubits)\n",
130+
"display(\"Extracted features from %s samples\"% (len(processed_dataset.states), ))"
131131
]
132132
},
133133
{
@@ -159,8 +159,6 @@
159159
"HAVE_PASQAL_ACCOUNT = False # If you have a PASQAL Cloud account, fill in the details and set this to `True`.\n",
160160
"\n",
161161
"if HAVE_PASQAL_ACCOUNT:\n",
162-
" processed_dataset = []\n",
163-
"\n",
164162
" # Use the QPU Extractor.\n",
165163
" extractor = qek_extractors.QPUExtractor(\n",
166164
" # Once computing is complete, data will be saved in this file.\n",
@@ -179,12 +177,12 @@
179177
" display(\"Compiled %s sequences\" % (len(compiled), ))\n",
180178
"\n",
181179
" # Launch the execution.\n",
182-
" execution = extractor.run()\n",
183-
" display(\"Work enqueued with ids %s\" % (extractor.batch_ids, ))\n",
180+
" processed_dataset = extractor.run()\n",
181+
" display(\"Work enqueued with ids %s\" % (processed_dataset.batch_ids, ))\n",
184182
"\n",
185183
" # ...and wait for the results.\n",
186-
" processed_dataset = await execution\n",
187-
" display(\"Extracted features from %s samples\"% (len(processed_dataset), ))"
184+
" await processed_dataset\n",
185+
" display(\"Extracted states from %s samples\"% (len(processed_dataset.states), ))"
188186
]
189187
},
190188
{
@@ -221,7 +219,8 @@
221219
"outputs": [],
222220
"source": [
223221
"import qek.data.dataset as qek_dataset\n",
224-
"processed_dataset = qek_dataset.load_dataset(file_path=\"ptcfm_processed_dataset.json\")\n",
222+
"from qek.data.dataset import ProcessedData\n",
223+
"processed_dataset: list[ProcessedData] = qek_dataset.load_dataset(file_path=\"ptcfm_processed_dataset.json\")\n",
225224
"print(f\"Size of the quantum compatible dataset = {len(processed_dataset)}\")"
226225
]
227226
},

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ dependencies = [
3737
"torch",
3838
"torch_geometric",
3939
"matplotlib",
40-
"emu-mps",
40+
"emu-mps~=1.2.0",
4141
]
4242

4343
[tool.hatch.metadata]

qek/data/dataset.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import collections
22
import json
3-
from typing import cast
3+
from typing import Final, cast
44
import matplotlib
55

66
import logging
@@ -21,8 +21,8 @@ class ProcessedData:
2121
executed on the device.
2222
state_dict: A dictionary {bitstring: number of instances}
2323
for this graph.
24-
target: The machine-learning target (in this case, a value
25-
in {0, 1}, as specified by the original graph).
24+
target: If specified, the machine-learning target, as a
25+
value `0` or `1`.
2626
2727
The state dictionary represents an approximation of the quantum
2828
state of the device for this graph after completion of the
@@ -40,14 +40,24 @@ class ProcessedData:
4040
specific graph).
4141
"""
4242

43-
sequence: pl.Sequence
44-
state_dict: dict[str, int]
45-
_dist_excitation: np.ndarray
46-
target: int
43+
sequence: Final[pl.Sequence]
44+
state_dict: Final[dict[str, int]]
45+
_dist_excitation: Final[np.ndarray]
46+
target: Final[int | None]
4747

48-
def __init__(self, sequence: pl.Sequence, state_dict: dict[str, np.int64], target: int):
48+
def __init__(
49+
self, sequence: pl.Sequence, state_dict: dict[str, int | np.int64], target: int | None
50+
):
4951
self.sequence = sequence
50-
self.state_dict = _convert_np_int64_to_int(data=state_dict)
52+
# Some emulators will actually be `dict[str, int64]` instead of `dict[str, int]` and `int64`
53+
# is not JSON-serializable.
54+
#
55+
# The reason for which `int64` is not JSON-serializable is that JSON limits ints to 2^53-1.
56+
# In practice, this should not be a problem, since the `int`/`int64` in our dict is
57+
# limited to the number of runs, and we don't expect to be launching 2^53 consecutive runs
58+
# for a single sequence on a device in any foreseeable future (assuming a run of 1ns,
59+
# this would still take ~4 billion years to execute).
60+
self.state_dict = {k: int(value) for k, value in state_dict.items()}
5161
self._dist_excitation = dist_excitation(self.state_dict)
5262
self.target = target
5363

@@ -156,16 +166,6 @@ def dist_excitation(state_dict: dict[str, int], size: int | None = None) -> np.n
156166
return result
157167

158168

159-
def _convert_np_int64_to_int(data: dict[str, np.int64]) -> dict[str, int]:
160-
"""
161-
Utility function: convert the values of a dict from `np.int64` to `int`,
162-
for serialization purposes.
163-
"""
164-
return {
165-
key: (int(value) if isinstance(value, np.integer) else value) for key, value in data.items()
166-
}
167-
168-
169169
def save_dataset(dataset: list[ProcessedData], file_path: str) -> None:
170170
"""Saves a dataset to a JSON file.
171171

0 commit comments

Comments
 (0)