@@ -569,6 +569,45 @@ def test_megablox_expert_parallelism(self):
       actual_output, _, _ = self.get_moe_output(variables, hidden_states, cfg, mesh)
       self.assertTrue(jax.numpy.allclose(expected_output, actual_output, rtol=1e-02, atol=1e-02, equal_nan=False))
 
+  @pytest.mark.tpu_only
+  def test_ring_of_expert_and_tensor_parallelism(self):
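+    # Exercises the ring-of-experts routing path combined with tensor
+    # parallelism (ici_expert_parallelism=2, ici_tensor_parallelism=2).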
+    cfg = pyconfig.initialize(
+        [None, get_test_config_path()],
+        run_name="moe_block_ring_ep_tp_test",
+        enable_checkpointing=False,
+        model_name="mixtral-8x7b",
+        dtype="bfloat16",
+        megablox=True,
+        sparse_matmul=True,
+        per_device_batch_size=4,  # TODO(b/450900273): sharding error if pdbs=1
+        ici_expert_parallelism=2,
+        use_ring_of_experts=True,
+        ici_tensor_parallelism=2,
+        max_target_length=128,
+    )
+
+    rng = jax.random.PRNGKey(2345)
+    rng_model, rng_hidden_states = jax.random.split(rng)
+    device_count = jax.device_count()
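+    # Global batch size = per-device batch size * device count; activations
+    # have shape (batch, sequence, embedding) in bfloat16.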
+    hidden_states = jax.random.uniform(
+        rng_hidden_states,
+        (int(cfg.per_device_batch_size) * device_count, cfg.max_target_length, cfg.base_emb_dim),
+        dtype=cfg.dtype,
+    )
+
+    devices_array = maxtext_utils.create_device_mesh(cfg)
+    mesh = Mesh(devices_array, cfg.mesh_axes)
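+    # Run the reference implementation and the MoE block under the same
+    # logical axis rules, then compare their outputs.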
+    with nn_partitioning.axis_rules(cfg.logical_axis_rules):
+      variables, expected_output = self.get_expected_output(rng_model, hidden_states, cfg, mesh)
+      actual_output, _, _ = self.get_moe_output(variables, hidden_states, cfg, mesh)
+      self.assertTrue(jax.numpy.allclose(expected_output, actual_output, rtol=1e-02, atol=1e-02, equal_nan=False))
+
   @pytest.mark.tpu_only
   def test_moe_fsdp_two_stage_parallelism_tpu_only(self):
     cfg = pyconfig.initialize(