
Commit 755bdb8

fix ring of experts using random routing
1 parent e2f6b0e commit 755bdb8

2 files changed: 25 additions & 0 deletions

src/maxtext/layers/moe.py

Lines changed: 1 addition & 0 deletions
@@ -1127,6 -1127,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
           pre_bias_logits,
           self.config.use_custom_sort_vjp,
           roll_to_expert_id=num_experts_per_shard * expert_shard_id,
+          rngs=rngs,
       )

   # Filter down to the group sizes that apply to only the experts in the
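
The one-line change forwards the rngs argument into the wrapped megablox call on the ring-of-experts path, which is presumably what random routing needs: it samples expert assignments from a PRNG instead of deriving them from router logits, so a key must be threaded down to it. A minimal sketch of that idea (not MaxText's actual routing code; random_route and its signature are assumptions for illustration):

    import jax
    import jax.numpy as jnp

    def random_route(logits, num_experts_per_tok, rng):
      """Hypothetical random router: ignores logits, samples expert ids uniformly.

      Without an explicit `rng` threaded in from the caller there is nothing
      to sample from -- the failure mode a missing `rngs=rngs` would cause.
      """
      num_tokens, num_experts = logits.shape
      return jax.random.randint(
          rng, shape=(num_tokens, num_experts_per_tok), minval=0, maxval=num_experts
      )

    rng = jax.random.PRNGKey(0)
    logits = jnp.zeros((8, 16))                 # 8 tokens, 16 experts
    assignments = random_route(logits, 2, rng)  # (8, 2) random expert ids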

tests/unit/train_compile_test.py

Lines changed: 24 additions & 0 deletions
@@ -426,6 +426,30 @@ def test_moe_megablox_bf16(self):
         )
     )

+  @pytest.mark.cpu_only
+  def test_moe_megablox_ring_ep_random(self):
+    temp_dir = gettempdir()
+    compiled_trainstep_file = os.path.join(temp_dir, "test_moe_megablox_ring_ep_random.pickle")
+    train_compile_main(
+        (
+            "",
+            get_test_config_path(),
+            f"compiled_trainstep_file={compiled_trainstep_file}",
+            "compile_topology=v5p-16",
+            "use_iota_embed=true",
+            "compile_topology_num_slices=1",
+            "model_name=deepseek3-test",
+            "sparse_matmul=True",
+            "megablox=True",
+            "per_device_batch_size=4",
+            "max_target_length=128",
+            "use_ring_of_experts=True",
+            "use_random_routing=True",
+            "attention=flash",
+            "dtype=bfloat16",
+        )
+    )
+
   @pytest.mark.cpu_only
   def test_moe_ragged_dot_bf16(self):
     temp_dir = gettempdir()
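
Assuming a standard pytest setup for the repo (the exact invocation is not part of this commit), the new cpu_only compile test can be run in isolation with:

    python -m pytest tests/unit/train_compile_test.py -k test_moe_megablox_ring_ep_random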
