
Commit 2c2b92e

make gpu zero1 test scheduled
1 parent aee1a74 commit 2c2b92e

Showing 2 changed files with 16 additions and 24 deletions.


tests/attention_test.py (15 additions, 24 deletions)
@@ -23,27 +23,22 @@
 
 from absl.testing import parameterized
 from flax import nnx
-from flax.linen import partitioning as nn_partitioning
 import jax
 import jax.numpy as jnp
-from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec as P
-from MaxText import max_utils
+from jax.sharding import AxisType, Mesh
 from MaxText import maxtext_utils
 from MaxText import pyconfig
 from MaxText.common_types import (
     AttentionType,
     DECODING_ACTIVE_SEQUENCE_INDICATOR,
-    EP_AS_CONTEXT,
     MODEL_MODE_AUTOREGRESSIVE,
     MODEL_MODE_PREFILL,
     MODEL_MODE_TRAIN,
-    ShardMode,
 )
 from MaxText.globals import MAXTEXT_PKG_DIR
 from MaxText.layers.attention_mla import MLA
 from MaxText.layers.attention_op import ChunkedCausalMask, _generate_chunk_attention_mask, _make_bidirectional_block_mask
 from MaxText.layers.attentions import Attention
-from MaxText.sharding import maybe_shard_with_name
 import numpy as np
 import pytest

@@ -693,15 +688,13 @@ def test_tpu_flash_attention_context_parallel(
     )
     nnx.update(attention_as_mha_flash_cp, generic_state)
 
-    mha_generic_flash_cp_output = (
-        attention_test_util.forward_with_context_expert_parallelism(
-            cfg_cp,
-            mesh_cp,
-            attention_as_mha_flash_cp,
-            lnx,
-            decoder_segment_ids,
-            decoder_positions,
-        )
+    mha_generic_flash_cp_output = attention_test_util.forward_with_context_expert_parallelism(
+        cfg_cp,
+        mesh_cp,
+        attention_as_mha_flash_cp,
+        lnx,
+        decoder_segment_ids,
+        decoder_positions,
     )
 
     # This removes all sharding information and makes them standard NumPy arrays.
@@ -1479,15 +1472,13 @@ def test_tpu_flash_attention_context_parallel(
         rngs=self.nnx_rng,
     )
     nnx.update(attention_as_mla_flash_cp, generic_state)
-    mla_generic_flash_cp_output = (
-        attention_test_util.forward_with_context_expert_parallelism(
-            cfg_cp,
-            mesh_cp,
-            attention_as_mla_flash_cp,
-            lnx,
-            decoder_segment_ids,
-            decoder_positions,
-        )
+    mla_generic_flash_cp_output = attention_test_util.forward_with_context_expert_parallelism(
+        cfg_cp,
+        mesh_cp,
+        attention_as_mla_flash_cp,
+        lnx,
+        decoder_segment_ids,
+        decoder_positions,
     )
 
     # This removes all sharding information and makes them standard NumPy arrays.
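For context on the `nnx.update(...)` lines kept unchanged above: these tests build a second attention module and copy the reference module's weights into it before comparing outputs. Below is a minimal, self-contained sketch of that Flax NNX state-copy pattern, not taken from this repository; the `nnx.Linear` modules and shapes are illustrative only.

# Illustrative sketch of the nnx.state / nnx.update pattern used by these tests.
import jax.numpy as jnp
from flax import nnx

src = nnx.Linear(4, 8, rngs=nnx.Rngs(0))  # reference module
dst = nnx.Linear(4, 8, rngs=nnx.Rngs(1))  # independently initialized copy

params = nnx.state(src)  # extract the source parameters as a State pytree
nnx.update(dst, params)  # copy them into the destination module in place

x = jnp.ones((2, 4))
assert jnp.allclose(src(x), dst(x))  # both modules now agree on the forward pass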

tests/integration_tests/train_tests.py (1 addition, 0 deletions)
@@ -417,6 +417,7 @@ def test_gpu_synthetic_model_ag_once(self):
 
   @pytest.mark.integration_test
   @pytest.mark.gpu_only
+  @pytest.mark.scheduled_only
  def test_gpu_zero1_gradient_accumulation(self):
    os.environ["NVTE_FUSED_ATTN"] = "1"  # Enable fused attention
    zero1_ga = [  # tests Zero-1 optimizer sharding with gradient accumulation
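The only change in this file is the new `@pytest.mark.scheduled_only` marker, which matches the commit message: the GPU ZeRO-1 gradient-accumulation test should now run only in scheduled CI jobs rather than on every run of the gpu_only suite. How MaxText registers and filters this marker is not shown in this diff; the snippet below is a generic, hypothetical sketch of the usual pytest mechanics.

# conftest.py (hypothetical sketch, not MaxText's actual configuration).
def pytest_configure(config):
  # Register the custom marker so pytest does not emit "unknown marker" warnings.
  config.addinivalue_line(
      "markers",
      "scheduled_only: run only in scheduled CI jobs, not on every pull request",
  )

# Typical command-line selection (illustrative):
#   presubmit run:  pytest -m "gpu_only and not scheduled_only" tests/integration_tests/train_tests.py
#   scheduled run:  pytest -m "gpu_only and scheduled_only" tests/integration_tests/train_tests.py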
