Skip to content

Commit 556fa61

Browse files
add cpu_only
1 parent 8b39572 commit 556fa61

1 file changed

Lines changed: 2 additions & 0 deletions

File tree

tests/unit/sharding_compare_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def compare_sharding_jsons(json1: dict, model1_name: str, json2: dict, model2_na
110110

111111

112112
# Requires JAX TPU support to generate the simulated TPU topology.
113+
@pytest.mark.cpu_only
113114
@pytest.mark.tpu_backend
114115
@pytest.mark.parametrize("model_name, topology, num_slice", TEST_CASES)
115116
def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str) -> None:
@@ -214,6 +215,7 @@ def abstract_state_and_shardings(request):
214215
return model_name, topology, num_slice, abstract_state, state_mesh_shardings, logical_shardings
215216

216217

218+
@pytest.mark.cpu_only
217219
@pytest.mark.tpu_backend
218220
class TestGetAbstractState:
219221
"""Test class for get_abstract_state function and sharding comparison."""

0 commit comments

Comments
 (0)