Skip to content

Commit b11a92f

Browse files
authored
[FIX] A few fixes to QPUExtractor (#34)
1 parent 45a3677 commit b11a92f

2 files changed

Lines changed: 46 additions & 32 deletions

File tree

examples/tutorial.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@
398398
"metadata": {},
399399
"outputs": [],
400400
"source": [
401-
"train_kernel = kernel.create_train_kernel_matrix(processed_dataset)\n",
401+
"train_kernel = kernel.fit_transform(processed_dataset)\n",
402402
"y_tot = [data.target for data in processed_dataset]"
403403
]
404404
},

qek/data/extractors.py

Lines changed: 45 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
import logging
66
from math import ceil
77
from typing import Any, Callable, Generic, Sequence, TypeVar, cast
8+
from uuid import UUID
89
import emu_mps
9-
from pasqal_cloud import SDK
10-
from pasqal_cloud.batch import Batch
10+
from pasqal_cloud import SDK, Batch, BatchFilters
1111
import pulser as pl
1212
from pulser.devices import Device
1313
from pulser.json.abstract_repr.deserializer import deserialize_device
@@ -341,7 +341,7 @@ def __init__(
341341
username: str,
342342
password: str | None = None,
343343
device_name: str = "FRESNEL",
344-
batch_id: list[str] | None = None,
344+
batch_ids: list[str] | None = None,
345345
):
346346
sdk = SDK(username=username, project_id=project_id, password=password)
347347

@@ -351,65 +351,79 @@ def __init__(
351351

352352
super().__init__(path=path, device=device, compiler=compiler)
353353
self._sdk = sdk
354-
self._batch_id = batch_id
354+
self._batch_ids: list[str] | None = batch_ids
355355

356356
@property
357357
def batch_ids(self) -> list[str] | None:
358-
return self._batch_id
358+
return self._batch_ids
359359

360360
async def run(self) -> list[ProcessedData]:
361361
if len(self.sequences) == 0:
362362
logger.warning("No sequences to run, did you forget to call compile()?")
363363
return []
364364

365-
batches: list[Batch] = []
366-
if self._batch_id is None:
365+
device: pl.devices.Device = self.sequences[0].sequence.device
366+
# The API doesn't support runs longer than 500 jobs.
367+
# If we want to add more runs, we'll need to split them across several jobs.
368+
max_runs = device.max_runs if isinstance(device.max_runs, int) else 500
369+
370+
if self._batch_ids is None:
367371
# Enqueue jobs.
368372
self._batch_ids = []
369373
for compiled in self.sequences:
370-
logger.debug("Executing compiled graph #%s", id)
374+
logger.debug("Enqueuing execution of compiled graph #%s", compiled.graph.id)
371375
batch = self._sdk.create_batch(
372376
compiled.sequence.to_abstract_repr(),
373-
jobs=[{"runs": 1000}],
377+
jobs=[{"runs": max_runs}],
374378
wait=False,
375379
)
376380
logger.info(
377381
"Remote execution of compiled graph #%s starting, batched with id %s",
378-
id,
382+
compiled.graph.id,
379383
batch.id,
380384
)
381-
batches.append(batch)
382385
self._batch_ids.append(batch.id)
383386
logger.info(
384387
"All %s jobs enqueued for remote execution, with ids %s",
385-
len(batches),
388+
len(self._batch_ids),
386389
self._batch_ids,
387390
)
388-
else:
389-
# Get jobs back from the cloud API.
390-
for batch_id in self._batch_id:
391-
batches.append(self._sdk.get_batch(batch_id))
391+
assert len(self._batch_ids) == len(self.sequences)
392392

393393
# Now wait until all batches are complete.
394-
waiting = True
395-
while waiting:
396-
waiting = False
397-
for batch in batches:
398-
if batch.status in {"PENDING", "RUNNING"}:
399-
# At least one job is pending, let's wait.
400-
await sleep(2)
401-
logger.debug("Job %s is still incomplete")
402-
waiting = True
403-
404-
logger.info("All jobs complete, %s sequences executed", len(batches))
394+
pending_batch_ids: set[str] = set(self._batch_ids)
395+
completed_batches: dict[str, Batch] = {}
396+
397+
while len(pending_batch_ids) > 0:
398+
await sleep(delay=2)
399+
# We can check up to 100 batches in a single query with the SDK, so let's do that.
400+
MAX_BATCH_LEN = 100
401+
check_ids: list[str | UUID] = [cast(str | UUID, id) for id in pending_batch_ids][
402+
:MAX_BATCH_LEN
403+
]
404+
# Update their status.
405+
check_batches = self._sdk.get_batches(
406+
filters=BatchFilters(id=check_ids)
407+
) # Ideally, this should be async, see https://github.com/pasqal-io/pasqal-cloud/issues/162.
408+
for batch in check_batches.results:
409+
assert isinstance(batch, Batch)
410+
if batch.status not in {"PENDING", "RUNNING"}:
411+
logger.debug("Job %s is now complete", batch.id)
412+
pending_batch_ids.discard(batch.id)
413+
completed_batches[batch.id] = batch
414+
415+
logger.info("All jobs complete, %s sequences executed", len(completed_batches))
405416

406417
# At this point, all batches are complete.
407-
# Now collect data. We rely upon the fact
408-
# that we enqueued exactly one batch per sequence, in the same order.
418+
# Now, collect data.
419+
#
420+
# We rely upon the fact that for any `i`,
421+
# `self._batch_ids[i]` is the batch for `self.sequences[i]`.
409422
processed_data: list[ProcessedData] = []
410-
for i, batch in enumerate(batches):
411-
# Note: There's only one job per batch.
423+
for i, id in enumerate(self._batch_ids):
424+
batch = completed_batches[id]
412425
assert len(batch.jobs) == 1
426+
413427
for _, job in batch.jobs.items():
414428
if job.status == "DONE":
415429
state_dict = job.result

0 commit comments

Comments
 (0)