Skip to content

Commit 018b34d

Browse files
Merge pull request #3242 from ROCm:tpu-backend-skips
PiperOrigin-RevId: 875996905
2 parents fcaecd2 + cf9f9a3 commit 018b34d

4 files changed

Lines changed: 42 additions & 1 deletion

File tree

pytest.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ addopts =
2424
--ignore=tests/unit/engram_vs_reference_test.py
2525
markers =
2626
tpu_only: marks tests to be run on TPUs only
27+
tpu_backend: marks tests that require a TPU-enabled JAX install (TPU PJRT plugin), but not TPU hardware
2728
gpu_only: marks tests to be run on GPUs only
2829
cpu_only: marks tests to be run on CPUs only
2930
decoupled: tests that validate offline / DECOUPLE_GCLOUD=TRUE mode.

tests/conftest.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import pytest
2424
from maxtext.common.gcloud_stub import is_decoupled
2525
import jax
26+
import importlib.util
2627

2728
# Configure JAX to use unsafe_rbg PRNG implementation to match main scripts.
2829
if is_decoupled():
@@ -42,6 +43,28 @@
4243
GCP_MARKERS = {"external_serving", "external_training"}
4344

4445

46+
def _has_tpu_backend_support() -> bool:
47+
"""Whether JAX has TPU backend support installed (PJRT TPU plugin).
48+
49+
This is intentionally *not* the same as having TPU hardware available.
50+
"""
51+
try:
52+
if importlib.util.find_spec("jaxlib") is None:
53+
return False
54+
except Exception: # pragma: no cover pylint: disable=broad-exception-caught
55+
return False
56+
57+
# Heuristic: TPU-enabled jaxlib exposes a TPU client/extension module.
58+
# This check avoids initializing any backend at collection time.
59+
for mod in ("jaxlib.tpu_client", "jaxlib._src.tpu_client", "jaxlib.tpu_extension"):
60+
try:
61+
if importlib.util.find_spec(mod) is not None:
62+
return True
63+
except Exception: # pragma: no cover pylint: disable=broad-exception-caught
64+
continue
65+
return False
66+
67+
4568
def pytest_collection_modifyitems(config, items):
4669
"""Customize pytest collection behavior.
4770
@@ -55,12 +78,21 @@ def pytest_collection_modifyitems(config, items):
5578

5679
skip_no_tpu = None
5780
skip_no_gpu = None
81+
skip_no_tpu_backend = None
5882
if not _HAS_TPU:
5983
skip_no_tpu = pytest.mark.skip(reason="Skipped: requires TPU hardware, none detected")
6084

6185
if not _HAS_GPU:
6286
skip_no_gpu = pytest.mark.skip(reason="Skipped: requires GPU hardware, none detected")
6387

88+
if not _has_tpu_backend_support():
89+
skip_no_tpu_backend = pytest.mark.skip(
90+
reason=(
91+
"Skipped: requires a TPU-enabled JAX install (TPU PJRT plugin). "
92+
"Install a TPU-enabled jax/jaxlib build to run this test."
93+
)
94+
)
95+
6496
for item in items:
6597
# Iterate thru the markers of every test.
6698
cur_test_markers = {m.name for m in item.iter_markers()}
@@ -76,6 +108,11 @@ def pytest_collection_modifyitems(config, items):
76108
remaining.append(item)
77109
continue
78110

111+
if skip_no_tpu_backend and "tpu_backend" in cur_test_markers:
112+
item.add_marker(skip_no_tpu_backend)
113+
remaining.append(item)
114+
continue
115+
79116
if decoupled and (cur_test_markers & GCP_MARKERS):
80117
# Deselect tests marked as external_serving/training entirely.
81118
deselected.append(item)
@@ -98,6 +135,7 @@ def pytest_configure(config):
98135
for m in [
99136
"gpu_only: tests that require GPU hardware",
100137
"tpu_only: tests that require TPU hardware",
138+
"tpu_backend: tests that require a TPU-enabled JAX install (TPU PJRT plugin), but not TPU hardware",
101139
"external_serving: JetStream / serving / decode server components",
102140
"external_training: goodput integrations",
103141
"decoupled: marked on tests that are not skipped due to GCP deps, when DECOUPLE_GCLOUD=TRUE",

tests/unit/sharding_compare_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333

3434
Transformer = models.transformer_as_linen
3535

36+
pytestmark = [pytest.mark.cpu_only, pytest.mark.tpu_backend]
37+
3638

3739
def compute_checksum(d: dict) -> str:
3840
"""Compute a checksum (SHA256) of a dictionary."""

tests/unit/train_compile_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from maxtext.trainers.pre_train.train_compile import main as train_compile_main
2929
from tests.utils.test_helpers import get_test_config_path
3030

31-
pytestmark = [pytest.mark.external_training]
31+
pytestmark = [pytest.mark.external_training, pytest.mark.tpu_backend]
3232

3333

3434
class TrainCompile(unittest.TestCase):

0 commit comments

Comments
 (0)