2323import pytest
2424from maxtext .common .gcloud_stub import is_decoupled
2525import jax
26+ import importlib .util
2627
2728# Configure JAX to use unsafe_rbg PRNG implementation to match main scripts.
2829if is_decoupled ():
4243GCP_MARKERS = {"external_serving" , "external_training" }
4344
4445
46+ def _has_tpu_backend_support () -> bool :
47+ """Whether JAX has TPU backend support installed (PJRT TPU plugin).
48+
49+ This is intentionally *not* the same as having TPU hardware available.
50+ """
51+ try :
52+ if importlib .util .find_spec ("jaxlib" ) is None :
53+ return False
54+ except Exception : # pragma: no cover pylint: disable=broad-exception-caught
55+ return False
56+
57+ # Heuristic: TPU-enabled jaxlib exposes a TPU client/extension module.
58+ # This check avoids initializing any backend at collection time.
59+ for mod in ("jaxlib.tpu_client" , "jaxlib._src.tpu_client" , "jaxlib.tpu_extension" ):
60+ try :
61+ if importlib .util .find_spec (mod ) is not None :
62+ return True
63+ except Exception : # pragma: no cover pylint: disable=broad-exception-caught
64+ continue
65+ return False
66+
67+
4568def pytest_collection_modifyitems (config , items ):
4669 """Customize pytest collection behavior.
4770
@@ -55,12 +78,21 @@ def pytest_collection_modifyitems(config, items):
5578
5679 skip_no_tpu = None
5780 skip_no_gpu = None
81+ skip_no_tpu_backend = None
5882 if not _HAS_TPU :
5983 skip_no_tpu = pytest .mark .skip (reason = "Skipped: requires TPU hardware, none detected" )
6084
6185 if not _HAS_GPU :
6286 skip_no_gpu = pytest .mark .skip (reason = "Skipped: requires GPU hardware, none detected" )
6387
88+ if not _has_tpu_backend_support ():
89+ skip_no_tpu_backend = pytest .mark .skip (
90+ reason = (
91+ "Skipped: requires a TPU-enabled JAX install (TPU PJRT plugin). "
92+ "Install a TPU-enabled jax/jaxlib build to run this test."
93+ )
94+ )
95+
6496 for item in items :
6597 # Iterate thru the markers of every test.
6698 cur_test_markers = {m .name for m in item .iter_markers ()}
@@ -76,6 +108,11 @@ def pytest_collection_modifyitems(config, items):
76108 remaining .append (item )
77109 continue
78110
111+ if skip_no_tpu_backend and "tpu_backend" in cur_test_markers :
112+ item .add_marker (skip_no_tpu_backend )
113+ remaining .append (item )
114+ continue
115+
79116 if decoupled and (cur_test_markers & GCP_MARKERS ):
80117 # Deselect tests marked as external_serving/training entirely.
81118 deselected .append (item )
@@ -98,6 +135,7 @@ def pytest_configure(config):
98135 for m in [
99136 "gpu_only: tests that require GPU hardware" ,
100137 "tpu_only: tests that require TPU hardware" ,
138+ "tpu_backend: tests that require a TPU-enabled JAX install (TPU PJRT plugin), but not TPU hardware" ,
101139 "external_serving: JetStream / serving / decode server components" ,
102140 "external_training: goodput integrations" ,
103141 "decoupled: marked on tests that are not skipped due to GCP deps, when DECOUPLE_GCLOUD=TRUE" ,
0 commit comments