|
23 | 23 |
|
24 | 24 | from absl.testing import parameterized |
25 | 25 | from flax import nnx |
26 | | -from flax.linen import partitioning as nn_partitioning |
27 | 26 | import jax |
28 | 27 | import jax.numpy as jnp |
29 | | -from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec as P |
30 | | -from MaxText import max_utils |
| 28 | +from jax.sharding import AxisType, Mesh |
31 | 29 | from MaxText import maxtext_utils |
32 | 30 | from MaxText import pyconfig |
33 | 31 | from MaxText.common_types import ( |
34 | 32 | AttentionType, |
35 | 33 | DECODING_ACTIVE_SEQUENCE_INDICATOR, |
36 | | - EP_AS_CONTEXT, |
37 | 34 | MODEL_MODE_AUTOREGRESSIVE, |
38 | 35 | MODEL_MODE_PREFILL, |
39 | 36 | MODEL_MODE_TRAIN, |
40 | | - ShardMode, |
41 | 37 | ) |
42 | 38 | from MaxText.globals import MAXTEXT_PKG_DIR |
43 | 39 | from MaxText.layers.attention_mla import MLA |
44 | 40 | from MaxText.layers.attention_op import ChunkedCausalMask, _generate_chunk_attention_mask, _make_bidirectional_block_mask |
45 | 41 | from MaxText.layers.attentions import Attention |
46 | | -from MaxText.sharding import maybe_shard_with_name |
47 | 42 | import numpy as np |
48 | 43 | import pytest |
49 | 44 |
|
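After this cleanup the test module keeps only `AxisType` and `Mesh` from `jax.sharding`. A minimal sketch of how those two names are typically used to build a device mesh; the axis name, shape, and axis types below are placeholders rather than MaxText's actual mesh configuration, and it assumes a recent JAX release where `Mesh` accepts an `axis_types` keyword:

```python
# Illustrative only: a 1-D device mesh built from the two retained imports.
# The axis name and AxisType choice are placeholders for this sketch.
import jax
import numpy as np
from jax.sharding import AxisType, Mesh

devices = np.asarray(jax.devices())  # 1-D array of available devices
mesh = Mesh(
    devices,
    axis_names=("context",),
    axis_types=(AxisType.Explicit,),  # assumes axis_types is supported
)
```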
@@ -693,15 +688,13 @@ def test_tpu_flash_attention_context_parallel( |
693 | 688 | ) |
694 | 689 | nnx.update(attention_as_mha_flash_cp, generic_state) |
695 | 690 |
|
696 | | - mha_generic_flash_cp_output = ( |
697 | | - attention_test_util.forward_with_context_expert_parallelism( |
698 | | - cfg_cp, |
699 | | - mesh_cp, |
700 | | - attention_as_mha_flash_cp, |
701 | | - lnx, |
702 | | - decoder_segment_ids, |
703 | | - decoder_positions, |
704 | | - ) |
| 691 | + mha_generic_flash_cp_output = attention_test_util.forward_with_context_expert_parallelism( |
| 692 | + cfg_cp, |
| 693 | + mesh_cp, |
| 694 | + attention_as_mha_flash_cp, |
| 695 | + lnx, |
| 696 | + decoder_segment_ids, |
| 697 | + decoder_positions, |
705 | 698 | ) |
706 | 699 |
|
707 | 700 | # This removes all sharding information and makes them standard NumPy arrays. |
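Both call sites pass the same six positional arguments, so the shared helper presumably looks roughly like the stub below. This is an assumption reconstructed from the call sites, not the actual `attention_test_util` definition:

```python
# Assumed shape of the helper, inferred from its call sites above; the real
# implementation lives in MaxText's attention test utilities.
def forward_with_context_expert_parallelism(
    cfg,                  # test config built via pyconfig
    mesh,                 # jax.sharding.Mesh with context/expert axes
    attention_module,     # Attention or MLA module under test
    lnx,                  # input activations
    decoder_segment_ids,  # segment ids for packed sequences
    decoder_positions,    # token positions
):
  # Presumably shards the inputs over the mesh, runs the module's forward
  # pass under context/expert parallelism, and returns the attention output.
  ...
```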
@@ -1479,15 +1472,13 @@ def test_tpu_flash_attention_context_parallel( |
1479 | 1472 | rngs=self.nnx_rng, |
1480 | 1473 | ) |
1481 | 1474 | nnx.update(attention_as_mla_flash_cp, generic_state) |
1482 | | - mla_generic_flash_cp_output = ( |
1483 | | - attention_test_util.forward_with_context_expert_parallelism( |
1484 | | - cfg_cp, |
1485 | | - mesh_cp, |
1486 | | - attention_as_mla_flash_cp, |
1487 | | - lnx, |
1488 | | - decoder_segment_ids, |
1489 | | - decoder_positions, |
1490 | | - ) |
| 1475 | + mla_generic_flash_cp_output = attention_test_util.forward_with_context_expert_parallelism( |
| 1476 | + cfg_cp, |
| 1477 | + mesh_cp, |
| 1478 | + attention_as_mla_flash_cp, |
| 1479 | + lnx, |
| 1480 | + decoder_segment_ids, |
| 1481 | + decoder_positions, |
1491 | 1482 | ) |
1492 | 1483 |
|
1493 | 1484 | # This removes all sharding information and makes them standard NumPy arrays. |
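As the retained comment notes, the outputs are converted to plain NumPy arrays before being compared. A hedged sketch of that pattern; `mla_generic_output` and the tolerances are placeholders, since the actual assertion is not shown in this hunk:

```python
# Illustrative only: pull both outputs back to host as NumPy arrays, then
# compare them numerically. Variable names and tolerances are placeholders.
mla_generic_output_np = jax.device_get(mla_generic_output)
mla_generic_flash_cp_output_np = jax.device_get(mla_generic_flash_cp_output)
np.testing.assert_allclose(
    mla_generic_output_np,
    mla_generic_flash_cp_output_np,
    rtol=1e-2,
    atol=1e-2,
)
```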
|