
Commit e2f6b0e

Merge pull request #3351 from AI-Hypercomputer:chengnuojin-fix-ring-tp
PiperOrigin-RevId: 880950918
2 parents: 2f8c473 + e3fa572

2 files changed: 34 additions & 1 deletion


src/maxtext/layers/moe.py

Lines changed: 1 addition & 1 deletion
@@ -1318,7 +1318,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
         )
 
         # Sum up the partial outputs across the expert shards.
-        output = jnp.reshape(output, (-1, sequence_length, self.config.emb_dim))
+        output = jnp.reshape(output, (-1, sequence_length, self.config.emb_dim // self.get_tensor_parallelism_size()))
         output = jax.lax.psum_scatter(output, self._expert_parallelism_name, scatter_dimension=0, tiled=True)
 
       else:
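The one-line fix targets the ring-of-experts path when expert parallelism is combined with tensor parallelism: under tensor parallelism each shard only holds emb_dim // tensor_parallelism columns of the partial output, so a reshape against the full emb_dim either fails or folds the wrong number of tokens into the leading dimension before the psum_scatter across expert shards. A minimal shape-arithmetic sketch (not MaxText code; the sizes below are illustrative, not taken from any config):

import jax.numpy as jnp

# Illustrative sizes (assumed for the sketch).
emb_dim, tp, seq_len, local_batch = 16, 2, 8, 32

# Shard-local partial output after the expert computation: the embedding axis is
# already split across the tp tensor-parallel shards, so only emb_dim // tp
# columns live on this device.
local_output = jnp.zeros((local_batch * seq_len, emb_dim // tp))

# Old reshape (full emb_dim): the 32*8*8 local elements folded into (-1, 8, 16)
# silently yield a leading dimension of 16 instead of 32 (or an error when the
# sizes do not divide evenly).
wrong = jnp.reshape(local_output, (-1, seq_len, emb_dim))
assert wrong.shape[0] == local_batch // tp  # 16, not the true 32

# Fixed reshape (shard-local width), matching the change above; the leading
# dimension that psum_scatter later scatters over stays correct.
fixed = jnp.reshape(local_output, (-1, seq_len, emb_dim // tp))
assert fixed.shape == (local_batch, seq_len, emb_dim // tp)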

tests/unit/moe_test.py

Lines changed: 33 additions & 0 deletions
@@ -569,6 +569,39 @@ def test_megablox_expert_parallelism(self):
       actual_output, _, _ = self.get_moe_output(variables, hidden_states, cfg, mesh)
       self.assertTrue(jax.numpy.allclose(expected_output, actual_output, rtol=1e-02, atol=1e-02, equal_nan=False))
 
+  @pytest.mark.tpu_only
+  def test_ring_of_expert_and_tensor_parallelism(self):
+    cfg = pyconfig.initialize(
+        [None, get_test_config_path()],
+        run_name="moe_block_ring_ep_tp_test",
+        enable_checkpointing=False,
+        model_name="mixtral-8x7b",
+        dtype="bfloat16",
+        megablox=True,
+        sparse_matmul=True,
+        per_device_batch_size=4,  # TODO(b/450900273): sharding error if pdbs=1
+        ici_expert_parallelism=2,
+        use_ring_of_experts=True,
+        ici_tensor_parallelism=2,
+        max_target_length=128,
+    )
+
+    rng = jax.random.PRNGKey(2345)
+    rng_model, rng_hidden_states = jax.random.split(rng)
+    device_count = jax.device_count()
+    hidden_states = jax.random.uniform(
+        rng_hidden_states,
+        (int(cfg.per_device_batch_size) * device_count, cfg.max_target_length, cfg.base_emb_dim),
+        dtype=cfg.dtype,
+    )
+
+    devices_array = maxtext_utils.create_device_mesh(cfg)
+    mesh = Mesh(devices_array, cfg.mesh_axes)
+    with nn_partitioning.axis_rules(cfg.logical_axis_rules):
+      variables, expected_output = self.get_expected_output(rng_model, hidden_states, cfg, mesh)
+      actual_output, _, _ = self.get_moe_output(variables, hidden_states, cfg, mesh)
+      self.assertTrue(jax.numpy.allclose(expected_output, actual_output, rtol=1e-02, atol=1e-02, equal_nan=False))
+
   @pytest.mark.tpu_only
   def test_moe_fsdp_two_stage_parallelism_tpu_only(self):
     cfg = pyconfig.initialize(
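To run only the new case locally (a suggestion rather than part of the change; it assumes pytest is installed and a TPU backend is available, since the test is marked tpu_only), selecting it by name should work:

import pytest

# Run just the new ring-of-experts + tensor-parallelism test from this file.
pytest.main(["tests/unit/moe_test.py", "-k", "test_ring_of_expert_and_tensor_parallelism"])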
