Skip to content

Commit f2d2ec8

Browse files
Merge pull request #3355 from ROCm:fix-tpu-backend-skips
PiperOrigin-RevId: 881535168
2 parents f1e2f02 + d3c1977 commit f2d2ec8

3 files changed

Lines changed: 10 additions & 11 deletions

File tree

tests/conftest.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,11 @@ def _has_tpu_backend_support() -> bool:
5454
except Exception: # pragma: no cover pylint: disable=broad-exception-caught
5555
return False
5656

57-
# Heuristic: TPU-enabled jaxlib exposes a TPU client/extension module.
58-
# This check avoids initializing any backend at collection time.
59-
for mod in ("jaxlib.tpu_client", "jaxlib._src.tpu_client", "jaxlib.tpu_extension"):
60-
try:
61-
if importlib.util.find_spec(mod) is not None:
62-
return True
63-
except Exception: # pragma: no cover pylint: disable=broad-exception-caught
64-
continue
65-
return False
57+
# Heuristic: TPU backend support is provided via the `libtpu` package.
58+
try:
59+
return importlib.util.find_spec("libtpu") is not None
60+
except Exception: # pragma: no cover pylint: disable=broad-exception-caught
61+
return False
6662

6763

6864
def pytest_collection_modifyitems(config, items):

tests/unit/sharding_compare_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ def compare_sharding_jsons(json1: dict, model1_name: str, json2: dict, model2_na
110110

111111

112112
# Requires JAX TPU support to generate the simulated TPU topology.
113-
@pytest.mark.tpu_only
113+
@pytest.mark.cpu_only
114+
@pytest.mark.tpu_backend
114115
@pytest.mark.parametrize("model_name, topology, num_slice", TEST_CASES)
115116
def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str) -> None:
116117
"""
@@ -214,11 +215,12 @@ def abstract_state_and_shardings(request):
214215
return model_name, topology, num_slice, abstract_state, state_mesh_shardings, logical_shardings
215216

216217

218+
@pytest.mark.cpu_only
219+
@pytest.mark.tpu_backend
217220
class TestGetAbstractState:
218221
"""Test class for get_abstract_state function and sharding comparison."""
219222

220223
# Requires JAX TPU support to generate the simulated TPU topology.
221-
@pytest.mark.tpu_only
222224
def test_get_abstract_state_sharding(self, abstract_state_and_shardings): # pylint: disable=redefined-outer-name
223225
"""Tests that get_abstract_state returns a state with the correct abstract structure and compares sharding."""
224226

tests/unit/train_compile_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from tests.utils.test_helpers import get_test_config_path
3030

3131

32+
@pytest.mark.tpu_backend
3233
class TrainCompile(unittest.TestCase):
3334
"""Tests for the Ahead of Time Compilation functionality, train_compile.py"""
3435

0 commit comments

Comments
 (0)