Skip to content

Commit 791f3aa

Browse files
authored
[Testing] Let's test the local backends (#60)
1 parent 2603e34 commit 791f3aa

4 files changed

Lines changed: 92 additions & 1 deletion

File tree

tests/conftest.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from time import sleep
2+
import torch_geometric.datasets as pyg_dataset
3+
4+
5+
def preload_dataset() -> None:
6+
# Attempt to force-download the PTC_FM dataset, which has difficulties on some
7+
# platforms. We suspect that it's a race condition somewhere in pytorch geometric.
8+
#
9+
# Sometimes, this raises FileNotFoundError, sometimes RuntimeError, reinforcing
10+
# the suggestion that something is happening in the background.
11+
exn: FileNotFoundError | RuntimeError | None = None
12+
for i in range(0, 10):
13+
sleep(i * i)
14+
try:
15+
print(f"Attempt {i+1} to download dataset")
16+
pyg_dataset.TUDataset(root="dataset", name="PTC_FM")
17+
print(f"Attempt {i+1} to download dataset succeeded")
18+
exn = None
19+
break
20+
except (FileNotFoundError, RuntimeError) as e:
21+
print(f"Attempt {i+1} to download failed: {e}")
22+
exn = e
23+
if exn is not None:
24+
raise exn

tests/test_backends.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from typing import cast
2+
3+
import os
4+
import pulser as pl
5+
import pytest
6+
import conftest
7+
import torch_geometric.data as pyg_data
8+
import torch_geometric.datasets as pyg_dataset
9+
from qek.backends import CompilationError, QutipBackend, BaseBackend
10+
import qek.data.graphs as qek_graphs
11+
12+
if os.name == "posix":
13+
# As of this writing, emu-mps only works under Unix.
14+
from qek.backends import EmuMPSBackend
15+
16+
17+
@pytest.mark.asyncio
18+
async def test_async_emulators() -> None:
19+
"""
20+
Test that backends based on emulators can execute without exploding (async).
21+
"""
22+
conftest.preload_dataset()
23+
24+
# Load dataset
25+
original_ptcfm_data = [
26+
cast(pyg_data.Data, d) for d in pyg_dataset.TUDataset(root="dataset", name="PTC_FM")
27+
]
28+
29+
compiled: list[tuple[qek_graphs.BaseGraph, pl.Register, pl.Pulse]] = []
30+
for i, data in enumerate(original_ptcfm_data):
31+
graph = qek_graphs.PTCFMGraph(data=data, device=pl.AnalogDevice, id=i)
32+
try:
33+
register = graph.compile_register()
34+
pulse = graph.compile_pulse()
35+
if len(register.qubits) >= 5:
36+
# This will be too slow to execute, skip.
37+
continue
38+
except CompilationError:
39+
# Let's just skip graphs that cannot be computed.
40+
continue
41+
compiled.append((graph, register, pulse))
42+
if len(compiled) >= 5:
43+
# We only need a few samples.
44+
break
45+
46+
assert len(compiled) >= 5
47+
48+
backends: list[BaseBackend] = [QutipBackend(pl.AnalogDevice)]
49+
if os.name == "posix":
50+
backends.append(EmuMPSBackend(pl.AnalogDevice))
51+
52+
for backend in backends:
53+
for g, register, pulse in compiled:
54+
result = await backend.run(register, pulse)
55+
# The only thing we can test from the result is that the values make _some_ kind of sense.
56+
assert isinstance(result, dict)
57+
for k, v in result.items():
58+
assert isinstance(k, str)
59+
assert v >= 0
60+
for c in k:
61+
assert c in {"0", "1"}

tests/test_extractors.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55
import torch_geometric.data as pyg_data
66
import torch_geometric.datasets as pyg_dataset
7+
import conftest
78
from qek.data.extractors import QutipExtractor, BaseExtracted
89

910
if os.name == "posix":
@@ -16,8 +17,10 @@
1617
@pytest.mark.asyncio
1718
async def test_async_emulators() -> None:
1819
"""
19-
Test that emulators can execute without exploding (both sync and async).
20+
Test that extractors emulators can execute without exploding (both sync and async).
2021
"""
22+
conftest.preload_dataset()
23+
2124
# Load dataset
2225
original_ptcfm_data = [
2326
cast(pyg_data.Data, d) for d in pyg_dataset.TUDataset(root="dataset", name="PTC_FM")

tests/test_graphs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import torch_geometric.datasets as pyg_dataset
77
import torch_geometric.utils as pyg_utils
88
from torch_geometric.data import Data
9+
import conftest
910

1011
from qek.data.graphs import (
1112
BaseGraph,
@@ -18,6 +19,8 @@
1819

1920

2021
def test_graph_init() -> None:
22+
conftest.preload_dataset()
23+
2124
# Load dataset
2225
original_ptcfm_data = pyg_dataset.TUDataset(root="dataset", name="PTC_FM")
2326

0 commit comments

Comments
 (0)