Skip to content

Commit c819a79

Browse files
[TEST] Add some tests for remote extractors (#95)
* [TEST] Add some tests for remote extractors * [FEAT] Improve testing logic with fake progress and results * review * [DOC] Add some comments and readme * lint * review --------- Co-authored-by: Matthieu Moreau <matthieu.moreau@pasqal.com>
1 parent b31b9db commit c819a79

File tree

7 files changed

+304
-92
lines changed

7 files changed

+304
-92
lines changed

qek/data/extractors.py

Lines changed: 86 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,9 @@
1515
from typing import Any, Callable, Generator, Generic, Sequence, TypeVar, cast
1616
from numpy.typing import NDArray
1717
from pasqal_cloud import SDK
18-
from pasqal_cloud.batch import Batch
1918
from pasqal_cloud.device import BaseConfig, EmuTNConfig, EmulatorType
2019
from pasqal_cloud.job import Job
21-
from pasqal_cloud.utils.filters import BatchFilters
20+
from pasqal_cloud.utils.filters import JobFilters
2221
from pathlib import Path
2322
import numpy as np
2423
import os
@@ -561,7 +560,7 @@ class PasqalCloudExtracted(BaseExtracted):
561560
def __init__(
562561
self,
563562
compiled: list[Compiled],
564-
batch_ids: list[str],
563+
job_ids: list[str],
565564
sdk: SDK,
566565
state_extractor: Callable[[Job, pl.Sequence], dict[str, int] | None],
567566
path: Path | None = None,
@@ -571,13 +570,13 @@ def __init__(
571570
572571
Arguments:
573572
compiled: The result of compiling a set of graphs.
574-
batch_ids: The ids of the batches on the cloud API, in the same order as `compiled`.
573+
job_ids: The ids of the jobs on the cloud API, in the same order as `compiled`.
575574
state_extractor: A callback used to extract the counter from a job.
576575
Used as various cloud back-ends return different formats.
577576
path: If provided, a path at which to save the results once they're available.
578577
"""
579578
self._compiled = compiled
580-
self._batch_ids = batch_ids
579+
self._job_ids = job_ids
581580
self._results: SyncExtracted | None = None
582581
self._path = path
583582
self._sdk = sdk
@@ -592,28 +591,28 @@ def _wait(self) -> None:
592591
if self._results is not None:
593592
# Results are already available.
594593
return
595-
pending_batch_ids: set[str] = set(self._batch_ids)
596-
completed_batches: dict[str, Batch] = {}
597-
while len(pending_batch_ids) > 0:
594+
pending_job_ids: set[str] = set(self._job_ids)
595+
completed_jobs: dict[str, Job] = {}
596+
while len(pending_job_ids) > 0:
598597
time.sleep(SLEEP_DELAY_S)
599598

600-
# Fetch up to 100 pending batches (upstream limits).
601-
MAX_BATCH_LEN = 100
602-
check_ids: list[str | UUID] = [cast(str | UUID, id) for id in pending_batch_ids][
603-
:MAX_BATCH_LEN
599+
# Fetch up to 100 pending jobs (upstream limits).
600+
MAX_JOB_LEN = 100
601+
check_ids: list[str | UUID] = [cast(str | UUID, id) for id in pending_job_ids][
602+
:MAX_JOB_LEN
604603
]
605604

606605
# Update their status.
607-
check_batches = self._sdk.get_batches(filters=BatchFilters(id=check_ids))
608-
for batch in check_batches.results:
609-
assert isinstance(batch, Batch)
610-
if batch.status not in {"PENDING", "RUNNING"}:
611-
logger.debug("Job %s is now complete", batch.id)
612-
pending_batch_ids.discard(batch.id)
613-
completed_batches[batch.id] = batch
606+
check_jobs = self._sdk.get_jobs(filters=JobFilters(id=check_ids))
607+
for job in check_jobs.results:
608+
assert isinstance(job, Job)
609+
if job.status not in {"PENDING", "RUNNING"}:
610+
logger.debug("Job %s is now complete", job.id)
611+
pending_job_ids.discard(job.id)
612+
completed_jobs[job.id] = job
614613

615-
# At this point, all batches are complete.
616-
self._ingest(completed_batches)
614+
# At this point, all jobs are complete.
615+
self._ingest(completed_jobs)
617616

618617
def __await__(self) -> Generator[Any, Any, None]:
619618
"""
@@ -628,72 +627,69 @@ def __await__(self) -> Generator[Any, Any, None]:
628627
if self._results is not None:
629628
# Results are already available.
630629
return
631-
pending_batch_ids: set[str] = set(self._batch_ids)
632-
completed_batches: dict[str, Batch] = {}
633-
while len(pending_batch_ids) > 0:
630+
pending_job_ids: set[str] = set(self._job_ids)
631+
completed_jobs: dict[str, Job] = {}
632+
while len(pending_job_ids) > 0:
634633
yield from asyncio.sleep(SLEEP_DELAY_S).__await__()
635634

636-
# Fetch up to 100 pending batches (upstream limits).
637-
MAX_BATCH_LEN = 100
638-
check_ids: list[str | UUID] = [cast(str | UUID, id) for id in pending_batch_ids][
639-
:MAX_BATCH_LEN
635+
# Fetch up to 100 pending jobs (upstream limits).
636+
MAX_JOB_LEN = 100
637+
check_ids: list[str | UUID] = [cast(str | UUID, id) for id in pending_job_ids][
638+
:MAX_JOB_LEN
640639
]
641640

642641
# Update their status.
643-
check_batches = self._sdk.get_batches(
644-
filters=BatchFilters(id=check_ids)
642+
check_jobs = self._sdk.get_jobs(
643+
filters=JobFilters(id=check_ids)
645644
) # Ideally, this should be async, see https://github.com/pasqal-io/pasqal-cloud/issues/162.
646-
for batch in check_batches.results:
647-
assert isinstance(batch, Batch)
648-
if batch.status not in {"PENDING", "RUNNING"}:
649-
logger.debug("Job %s is now complete", batch.id)
650-
pending_batch_ids.discard(batch.id)
651-
completed_batches[batch.id] = batch
645+
for job in check_jobs.results:
646+
assert isinstance(job, Job)
647+
if job.status not in {"PENDING", "RUNNING"}:
648+
logger.debug("Job %s is now complete", job.id)
649+
pending_job_ids.discard(job.id)
650+
completed_jobs[job.id] = job
652651

653-
# At this point, all batches are complete.
654-
self._ingest(completed_batches)
652+
# At this point, all jobs are complete.
653+
self._ingest(completed_jobs)
655654

656-
def _ingest(self, batches: dict[str, Batch]) -> None:
655+
def _ingest(self, jobs: dict[str, Job]) -> None:
657656
"""
658657
Ingest data received from the remote server.
659658
660659
No I/O.
661660
"""
662-
assert len(batches) == len(self._batch_ids)
661+
assert len(jobs) == len(self._job_ids)
663662

664663
raw_data = []
665664
targets: list[int] = []
666665
sequences = []
667666
states = []
668-
for i, id in enumerate(self._batch_ids):
669-
batch = batches[id]
667+
for i, id in enumerate(self._job_ids):
668+
job = jobs[id]
670669
compiled = self._compiled[i]
671-
# Note: There's only one job per batch.
672-
assert len(batch.jobs) == 1
673-
for job in batch.jobs.values():
674-
if job.status == "DONE":
675-
state_dict = self._state_extractor(job, compiled.sequence)
676-
if state_dict is None:
677-
logger.warning(
678-
"Batch %s (graph %s) did not return a usable state, skipping",
679-
i,
680-
compiled.graph.id,
681-
)
682-
continue
683-
raw_data.append(compiled.graph)
684-
if compiled.graph.target is not None:
685-
targets.append(compiled.graph.target)
686-
sequences.append(compiled.sequence)
687-
states.append(state_dict)
688-
else:
689-
# If some sequences failed, let's skip them and proceed as well as we can.
670+
if job.status == "DONE":
671+
state_dict = self._state_extractor(job, compiled.sequence)
672+
if state_dict is None:
690673
logger.warning(
691-
"Batch %s (graph %s) failed with errors %s, skipping",
674+
"Job %s (graph %s) did not return a usable state, skipping",
692675
i,
693676
compiled.graph.id,
694-
job.status,
695-
job.errors,
696677
)
678+
continue
679+
raw_data.append(compiled.graph)
680+
if compiled.graph.target is not None:
681+
targets.append(compiled.graph.target)
682+
sequences.append(compiled.sequence)
683+
states.append(state_dict)
684+
else:
685+
# If some sequences failed, let's skip them and proceed as well as we can.
686+
logger.warning(
687+
"Job %s (graph %s) failed with status %s and errors %s, skipping",
688+
i,
689+
compiled.graph.id,
690+
job.status,
691+
job.errors,
692+
)
697693
self._results = SyncExtracted(
698694
raw_data=raw_data, targets=targets, sequences=sequences, states=states
699695
)
@@ -754,9 +750,9 @@ class BaseRemoteExtractor(BaseExtractor[GraphType], Generic[GraphType]):
754750
device_name: The name of the device to use. As of this writing,
755751
the default value of "FRESNEL" represents the latest QPU
756752
available through the Pasqal Cloud API.
757-
batch_id: Use this to resume a workflow e.g. after turning off
753+
job_id: Use this to resume a workflow e.g. after turning off
758754
your computer while the QPU was executing your sequences.
759-
Warning: A batch started with one executor MUST NOT be resumed
755+
Warning: A job started with one executor MUST NOT be resumed
760756
with a different executor.
761757
"""
762758

@@ -767,7 +763,7 @@ def __init__(
767763
username: str,
768764
device_name: str,
769765
password: str | None = None,
770-
batch_ids: list[str] | None = None,
766+
job_ids: list[str] | None = None,
771767
path: Path | None = None,
772768
):
773769
sdk = SDK(username=username, project_id=project_id, password=password)
@@ -778,11 +774,11 @@ def __init__(
778774

779775
super().__init__(device=device, compiler=compiler, path=path)
780776
self._sdk = sdk
781-
self._batch_ids: list[str] | None = batch_ids
777+
self._job_ids: list[str] | None = job_ids
782778

783779
@property
784-
def batch_ids(self) -> list[str] | None:
785-
return self._batch_ids
780+
def job_ids(self) -> list[str] | None:
781+
return self._job_ids
786782

787783
@abc.abstractmethod
788784
def run(
@@ -803,7 +799,7 @@ def _run(
803799
logger.warning("No sequences to run, did you forget to call compile()?")
804800
return PasqalCloudExtracted(
805801
compiled=[],
806-
batch_ids=[],
802+
job_ids=[],
807803
sdk=self._sdk,
808804
path=self.path,
809805
state_extractor=state_extractor,
@@ -814,34 +810,36 @@ def _run(
814810
# If we want to add more runs, we'll need to split them across several jobs.
815811
max_runs = device.max_runs if isinstance(device.max_runs, int) else 500
816812

817-
if self._batch_ids is None:
813+
if self._job_ids is None:
818814
# Enqueue jobs.
819-
self._batch_ids = []
815+
self._job_ids = []
820816
for compiled in self.sequences:
821817
logger.debug("Enqueuing execution of compiled graph #%s", compiled.graph.id)
822-
batch = self._sdk.create_batch(
818+
job = self._sdk.create_batch(
823819
compiled.sequence.to_abstract_repr(),
824820
jobs=[{"runs": max_runs}],
825821
wait=False,
826822
emulator=emulator,
827823
configuration=config,
828824
)
825+
assert len(job.ordered_jobs) == 1
826+
job_id = job.ordered_jobs[0].id
829827
logger.info(
830-
"Remote execution of compiled graph #%s starting, batched with id %s",
828+
"Remote execution of compiled graph #%s starting, job with id %s",
831829
compiled.graph.id,
832-
batch.id,
830+
job_id,
833831
)
834-
self._batch_ids.append(batch.id)
832+
self._job_ids.append(job_id)
835833
logger.info(
836834
"All %s jobs enqueued for remote execution, with ids %s",
837-
len(self._batch_ids),
838-
self._batch_ids,
835+
len(self._job_ids),
836+
self._job_ids,
839837
)
840-
assert len(self._batch_ids) == len(self.sequences)
838+
assert len(self._job_ids) == len(self.sequences)
841839

842840
return PasqalCloudExtracted(
843841
compiled=self.sequences,
844-
batch_ids=self._batch_ids,
842+
job_ids=self._job_ids,
845843
sdk=self._sdk,
846844
path=self.path,
847845
state_extractor=state_extractor,
@@ -876,7 +874,7 @@ class RemoteQPUExtractor(BaseRemoteExtractor[GraphType]):
876874
device_name: The name of the device to use. As of this writing,
877875
the default value of "FRESNEL" represents the latest QPU
878876
available through the Pasqal Cloud API.
879-
batch_id: Use this to resume a workflow e.g. after turning off
877+
job_id: Use this to resume a workflow e.g. after turning off
880878
your computer while the QPU was executing your sequences.
881879
"""
882880

@@ -887,7 +885,7 @@ def __init__(
887885
username: str,
888886
device_name: str = "FRESNEL",
889887
password: str | None = None,
890-
batch_ids: list[str] | None = None,
888+
job_ids: list[str] | None = None,
891889
path: Path | None = None,
892890
):
893891
super().__init__(
@@ -896,7 +894,7 @@ def __init__(
896894
username=username,
897895
device_name=device_name,
898896
password=password,
899-
batch_ids=batch_ids,
897+
job_ids=job_ids,
900898
path=path,
901899
)
902900

@@ -927,7 +925,7 @@ class RemoteEmuMPSExtractor(BaseRemoteExtractor[GraphType]):
927925
device_name: The name of the device to use. As of this writing,
928926
the default value of "FRESNEL" represents the latest QPU
929927
available through the Pasqal Cloud API.
930-
batch_id: Use this to resume a workflow e.g. after turning off
928+
job_id: Use this to resume a workflow e.g. after turning off
931929
your computer while the QPU was executing your sequences.
932930
"""
933931

@@ -938,7 +936,7 @@ def __init__(
938936
username: str,
939937
device_name: str = "FRESNEL",
940938
password: str | None = None,
941-
batch_ids: list[str] | None = None,
939+
job_ids: list[str] | None = None,
942940
path: Path | None = None,
943941
):
944942
super().__init__(
@@ -947,17 +945,16 @@ def __init__(
947945
username=username,
948946
device_name=device_name,
949947
password=password,
950-
batch_ids=batch_ids,
948+
job_ids=job_ids,
951949
path=path,
952950
)
953951

954952
def run(self, dt: int = 10) -> PasqalCloudExtracted:
955953
def extractor(job: Job, sequence: pl.Sequence) -> dict[str, int] | None:
956-
cutoff_duration = int(ceil(sequence.get_duration() / dt) * dt)
957954
full_result = job.full_result
958955
if full_result is None:
959956
return None
960-
result = full_result["bitstring"][cutoff_duration]
957+
result = full_result["counter"]
961958
if result is None:
962959
return None
963960
assert isinstance(result, dict)

qek/target/backends.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,7 @@ async def run(
218218
bag = cast(dict[str, dict[int, Counter[str]]], job.result)
219219

220220
assert self._sequence is not None
221-
cutoff_duration = int(ceil(self._sequence.get_duration() / dt) * dt)
222-
return bag["bitstring"][cutoff_duration]
221+
return bag["counter"]
223222

224223

225224
if os.name == "posix":

tests/__init__.py

Whitespace-only changes.

tests/cloud_fixtures/README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Fixtures
2+
3+
This folder contains fixture data that can be used to mock cloud responses in tests.
4+
5+
## Files
6+
7+
- `device_specs.json`: An extract from the response of the GET /devices/public-specs endpoint. It can be used to retrieve the Fresnel device specs.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"FRESNEL": "{\"name\": \"Fresnel\", \"dimensions\": 2, \"rydberg_level\": 60, \"min_atom_distance\": 5, \"max_atom_num\": 80, \"max_radial_distance\": 38, \"interaction_coeff_xy\": null, \"supports_slm_mask\": false, \"max_layout_filling\": 0.5, \"optimal_layout_filling\": 0.45, \"min_layout_traps\": 10, \"max_layout_traps\": 200, \"max_sequence_duration\": 6000, \"max_runs\": 500, \"reusable_channels\": false, \"pre_calibrated_layouts\": [], \"version\": \"1\", \"channels\": [{\"id\": \"rydberg_global\", \"basis\": \"ground-rydberg\", \"addressing\": \"Global\", \"max_abs_detuning\": 48.69468613064179, \"max_amp\": 12.566370614359172, \"min_retarget_interval\": null, \"fixed_retarget_t\": null, \"max_targets\": null, \"clock_period\": 4, \"min_duration\": 16, \"max_duration\": 6000, \"min_avg_amp\": 0.5654866776461628, \"mod_bandwidth\": 8, \"eom_config\": {\"limiting_beam\": \"RED\", \"max_limiting_amp\": 138.23007675795088, \"intermediate_detuning\": 2513.2741228718346, \"controlled_beams\": [\"BLUE\"], \"mod_bandwidth\": 40, \"custom_buffer_time\": 240, \"multiple_beam_control\": false, \"red_shift_coeff\": 1.656}}], \"is_virtual\": false}"
3+
}

0 commit comments

Comments
 (0)