Skip to content

Commit 8b39572

Browse files
fix markers
1 parent 98a6589 commit 8b39572

2 files changed

Lines changed: 3 additions & 6 deletions

File tree

tests/unit/sharding_compare_test.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@
3333

3434
Transformer = models.transformer_as_linen
3535

36-
pytestmark = [pytest.mark.cpu_only, pytest.mark.tpu_backend]
37-
3836

3937
def compute_checksum(d: dict) -> str:
4038
"""Compute a checksum (SHA256) of a dictionary."""
@@ -112,7 +110,7 @@ def compare_sharding_jsons(json1: dict, model1_name: str, json2: dict, model2_na
112110

113111

114112
# Requires JAX TPU support to generate the simulated TPU topology.
115-
@pytest.mark.tpu_only
113+
@pytest.mark.tpu_backend
116114
@pytest.mark.parametrize("model_name, topology, num_slice", TEST_CASES)
117115
def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str) -> None:
118116
"""
@@ -216,11 +214,11 @@ def abstract_state_and_shardings(request):
216214
return model_name, topology, num_slice, abstract_state, state_mesh_shardings, logical_shardings
217215

218216

217+
@pytest.mark.tpu_backend
219218
class TestGetAbstractState:
220219
"""Test class for get_abstract_state function and sharding comparison."""
221220

222221
# Requires JAX TPU support to generate the simulated TPU topology.
223-
@pytest.mark.tpu_only
224222
def test_get_abstract_state_sharding(self, abstract_state_and_shardings): # pylint: disable=redefined-outer-name
225223
"""Tests that get_abstract_state returns a state with the correct abstract structure and compares sharding."""
226224

tests/unit/train_compile_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,8 @@
2828
from maxtext.trainers.pre_train.train_compile import main as train_compile_main
2929
from tests.utils.test_helpers import get_test_config_path
3030

31-
pytestmark = [pytest.mark.tpu_backend]
32-
3331

32+
@pytest.mark.tpu_backend
3433
class TrainCompile(unittest.TestCase):
3534
"""Tests for the Ahead of Time Compilation functionality, train_compile.py"""
3635

0 commit comments

Comments
 (0)