Skip to content

Commit 98a6589

Browse files
fix tpu_backend_support checking
1 parent 2f8c473 commit 98a6589

3 files changed

Lines changed: 9 additions & 9 deletions

File tree

tests/conftest.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,11 @@ def _has_tpu_backend_support() -> bool:
5454
except Exception: # pragma: no cover pylint: disable=broad-exception-caught
5555
return False
5656

57-
# Heuristic: TPU-enabled jaxlib exposes a TPU client/extension module.
58-
# This check avoids initializing any backend at collection time.
59-
for mod in ("jaxlib.tpu_client", "jaxlib._src.tpu_client", "jaxlib.tpu_extension"):
60-
try:
61-
if importlib.util.find_spec(mod) is not None:
62-
return True
63-
except Exception: # pragma: no cover pylint: disable=broad-exception-caught
64-
continue
65-
return False
57+
# Heuristic: TPU backend support is provided via the `libtpu` package.
58+
try:
59+
return importlib.util.find_spec("libtpu") is not None
60+
except Exception: # pragma: no cover pylint: disable=broad-exception-caught
61+
return False
6662

6763

6864
def pytest_collection_modifyitems(config, items):

tests/unit/sharding_compare_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333

3434
Transformer = models.transformer_as_linen
3535

36+
pytestmark = [pytest.mark.cpu_only, pytest.mark.tpu_backend]
37+
3638

3739
def compute_checksum(d: dict) -> str:
3840
"""Compute a checksum (SHA256) of a dictionary."""

tests/unit/train_compile_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
from maxtext.trainers.pre_train.train_compile import main as train_compile_main
2929
from tests.utils.test_helpers import get_test_config_path
3030

31+
pytestmark = [pytest.mark.tpu_backend]
32+
3133

3234
class TrainCompile(unittest.TestCase):
3335
"""Tests for the Ahead of Time Compilation functionality, train_compile.py"""

0 commit comments

Comments
 (0)