Skip to content

Commit 6d71303

Browse files
authored
[Hack] Let's try more than once to load the dataset (#83)
1 parent 643e1cb commit 6d71303

8 files changed

+73
-70
lines changed

.github/workflows/test.yml

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,6 @@ jobs:
3333
- name: Install Hatch
3434
run: |
3535
pip install hatch
36-
- name: Pre-download dataset
37-
# On macOS (and only on macOS), we encounter a strange issue:
38-
#
39-
# - loading the dataset causes a FileNotFoundException to be raised
40-
# - checking afterwards, the dataset *is* on the disk
41-
# - rerunning the process, the dataset is found.
42-
#
43-
# The running hypothesis is that pytorch-geometric downloads the
44-
# dataset asynchronously, but fails to wait until the download is
45-
# complete.
46-
#
47-
# With this small script, we force the pipeline to wait until the
48-
# download is complete. So far, this seems to solve the issue.
49-
if: ${{ matrix.os == 'macos-latest' }}
50-
run: |
51-
hatch -v run before_tests
5236
- name: Run tests
5337
run: |
5438
hatch -v run test
@@ -77,22 +61,6 @@ jobs:
7761
- name: Install Hatch
7862
run: |
7963
pip install hatch
80-
- name: Pre-download dataset
81-
if: ${{ matrix.os == 'macos-latest' }}
82-
# On macOS (and only on macOS), we encounter a strange issue:
83-
#
84-
# - loading the dataset causes a FileNotFoundException to be raised
85-
# - checking afterwards, the dataset *is* on the disk
86-
# - rerunning the process, the dataset is found.
87-
#
88-
# The running hypothesis is that pytorch-geometric downloads the
89-
# dataset asynchronously, but fails to wait until the download is
90-
# complete.
91-
#
92-
# With this small script, we force the pipeline to wait until the
93-
# download is complete. So far, this seems to solve the issue.
94-
run: |
95-
hatch -v run before_tests
9664
- name: Copy samples
9765
run: |
9866
cp examples/ptcfm_processed_dataset.json .

examples/tutorial 1 - Using a Quantum Device to Extract Machine-Learning Features.ipynb

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,11 @@
3636
"source": [
3737
"import torch_geometric.datasets as pyg_dataset\n",
3838
"\n",
39-
"# Load the original PTC-FM dataset\n",
40-
"og_ptcfm = pyg_dataset.TUDataset(root=\"dataset\", name=\"PTC_FM\")\n",
39+
"from qek.shared.retrier import PygRetrier\n",
40+
"\n",
41+
"# Load the original PTC-FM dataset.\n",
42+
"# We use PygRetrier to retry the download if it fails.\n",
43+
"og_ptcfm = PygRetrier().insist(pyg_dataset.TUDataset, root=\"dataset\", name=\"PTC_FM\")\n",
4144
"\n",
4245
"display(\"Loaded %s samples\" % (len(og_ptcfm), ))"
4346
]

examples/tutorial 1a - Using a Quantum Device to Extract Machine-Learning Features - low-level.ipynb

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@
3737
"source": [
3838
"# Load the original PTC-FM dataset\n",
3939
"import torch_geometric.datasets as pyg_dataset\n",
40-
"og_ptcfm = pyg_dataset.TUDataset(root=\"dataset\", name=\"PTC_FM\")\n",
40+
"from qek.shared.retrier import PygRetrier\n",
41+
"\n",
42+
"# We use PygRetrier to retry the download if it fails.\n",
43+
"og_ptcfm = PygRetrier().insist(pyg_dataset.TUDataset, root=\"dataset\", name=\"PTC_FM\")\n",
4144
"\n",
4245
"display(\"Loaded %s samples\" % (len(og_ptcfm), ))"
4346
]

qek/shared/retrier.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""
2+
Backoff-and-retry utilities.
3+
"""
4+
5+
import logging
6+
from time import sleep
7+
from typing import Any, Final, Type
8+
9+
from torch_geometric.data import Dataset
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
class PygRetrier:
15+
"""
16+
Our test harness attempts to run tests concurrently, but the pyg dataset loader does not
17+
work well with concurrency.
18+
19+
We work around this by simply retrying the loads a few times, until it succeeds.
20+
"""
21+
22+
def __init__(self, max_attempts: int = 3, name: str = "PygRetrier"):
23+
"""
24+
Create a PygRetrier
25+
26+
Arguments:
27+
max_attempts (optional): The max number of attempts to undertake before
28+
giving up. Defaults to 3.
29+
name (optional): A name to use during logging.
30+
"""
31+
self._max_attempts: Final[int] = max_attempts
32+
self.name: Final[str] = name
33+
34+
def insist(self, callback: Type[Dataset], **kwargs: Any) -> Dataset:
35+
"""
36+
Attempt to call a function or constructor repeatedly until, hopefully,
37+
it works.
38+
"""
39+
exn: FileNotFoundError | RuntimeError | OSError | None = None
40+
result = None
41+
for i in range(self._max_attempts):
42+
sleep(i * i)
43+
try:
44+
logger.debug("%s: attempt %s", self.name, i + 1)
45+
result = callback(**kwargs) # type: ignore
46+
logger.debug("%s: attempt %s succeeded", self.name, i + 1)
47+
exn = None
48+
break
49+
except (FileNotFoundError, RuntimeError, OSError) as e:
50+
logger.warning("%s: attempt %s failed: %s", self.name, i + 1, e)
51+
exn = e
52+
if exn is not None:
53+
logger.warning("%s: all attempts failed, bailing out", self.name)
54+
raise exn
55+
assert result is not None
56+
return result

tests/conftest.py

Lines changed: 0 additions & 24 deletions
This file was deleted.

tests/test_backends.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
import os
44
import pulser as pl
55
import pytest
6-
import conftest
76
import torch_geometric.data as pyg_data
87
import torch_geometric.datasets as pyg_dataset
98
from qek.backends import CompilationError, QutipBackend, BaseBackend
109
import qek.data.graphs as qek_graphs
10+
from qek.shared.retrier import PygRetrier
1111

1212
if os.name == "posix":
1313
# As of this writing, emu-mps only works under Unix.
@@ -19,11 +19,11 @@ async def test_async_emulators() -> None:
1919
"""
2020
Test that backends based on emulators can execute without exploding (async).
2121
"""
22-
conftest.preload_dataset()
2322

2423
# Load dataset
2524
original_ptcfm_data = [
26-
cast(pyg_data.Data, d) for d in pyg_dataset.TUDataset(root="dataset", name="PTC_FM")
25+
cast(pyg_data.Data, d)
26+
for d in PygRetrier().insist(pyg_dataset.TUDataset, root="dataset", name="PTC_FM")
2727
]
2828

2929
compiled: list[tuple[qek_graphs.BaseGraph, pl.Register, pl.Pulse]] = []

tests/test_extractors.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import pytest
55
import torch_geometric.data as pyg_data
66
import torch_geometric.datasets as pyg_dataset
7-
import conftest
87
from qek.data.extractors import QutipExtractor, BaseExtracted
8+
from qek.shared.retrier import PygRetrier
99

1010
if os.name == "posix":
1111
# As of this writing, emu-mps only works under Unix.
@@ -19,11 +19,10 @@ async def test_async_emulators() -> None:
1919
"""
2020
Test that extractors emulators can execute without exploding (both sync and async).
2121
"""
22-
conftest.preload_dataset()
23-
2422
# Load dataset
2523
original_ptcfm_data = [
26-
cast(pyg_data.Data, d) for d in pyg_dataset.TUDataset(root="dataset", name="PTC_FM")
24+
cast(pyg_data.Data, d)
25+
for d in PygRetrier().insist(pyg_dataset.TUDataset, root="dataset", name="PTC_FM")
2726
]
2827
MAX_NUMBER_OF_SAMPLES = 5
2928
MAX_NUMBER_OF_QUBITS = 5

tests/test_graphs.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import torch_geometric.datasets as pyg_dataset
77
import torch_geometric.utils as pyg_utils
88
from torch_geometric.data import Data
9-
import conftest
109

1110
from qek.data.graphs import (
1211
BaseGraph,
@@ -16,13 +15,12 @@
1615
NXGraphCompiler,
1716
NXWithPos,
1817
)
18+
from qek.shared.retrier import PygRetrier
1919

2020

2121
def test_graph_init() -> None:
22-
conftest.preload_dataset()
23-
2422
# Load dataset
25-
original_ptcfm_data = pyg_dataset.TUDataset(root="dataset", name="PTC_FM")
23+
original_ptcfm_data = PygRetrier().insist(pyg_dataset.TUDataset, root="dataset", name="PTC_FM")
2624

2725
# Check that `add_graph_coord` doesn't break with this dataset.
2826

0 commit comments

Comments
 (0)