@@ -368,6 +368,11 @@ def __init__(
     else:
       self._tensor_parallelism_name = "tensor"
 
+    if self.config.attention == "vllm_rpa":
+      self._expert_parallelism_name = "attn_dp_expert"
+    else:
+      self._expert_parallelism_name = "expert"
+
     self.gate = GateLogit(
         in_features_shape=self.config.emb_dim,
         out_features_shape=self.num_experts,
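Context for this first change: the expert-parallel mesh axis name is no longer hard-coded to "expert" but chosen from the config, with the vllm_rpa attention path using "attn_dp_expert" (judging by the name, an axis shared with attention data parallelism; the diff itself only shows the renaming). A stand-alone sketch of the same selection pattern, using a made-up `Config` stand-in rather than the real MaxText config:

```python
from dataclasses import dataclass

@dataclass
class Config:  # illustrative stand-in for the real config object
  attention: str = "flash"

def expert_axis_name(config: Config) -> str:
  # Mirrors the new __init__ logic: the vLLM RPA path uses its own axis name.
  return "attn_dp_expert" if config.attention == "vllm_rpa" else "expert"

print(expert_axis_name(Config()))                      # "expert"
print(expert_axis_name(Config(attention="vllm_rpa")))  # "attn_dp_expert"
```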
@@ -467,7 +472,7 @@ def _logical_to_mesh_axes(self, logical_name):
     return logical_to_mesh_axes(logical_name, mesh=self.mesh, rules=logical_rules)
 
   def get_expert_parallelism_size(self):
-    return self.mesh.shape.get("expert", 1)
+    return self.mesh.shape.get(self._expert_parallelism_name, 1)
 
   def get_tensor_parallelism_size(self):
     if isinstance(self._tensor_parallelism_name, tuple):
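Because `Mesh.shape` behaves like an ordered mapping from axis names to sizes, `get_expert_parallelism_size` keeps its fallback of 1 when the configured axis is absent from the mesh. A quick illustration on a single-device mesh (the mesh layout here is invented for the example):

```python
import numpy as np
import jax
from jax.sharding import Mesh

# A one-axis mesh whose only axis is named "expert".
mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("expert",))

# Mesh.shape maps axis names to sizes, so .get(name, 1) degrades gracefully
# when the configured axis (e.g. "attn_dp_expert") is not part of the mesh.
print(mesh.shape.get("expert", 1))          # size of the expert axis (1 here)
print(mesh.shape.get("attn_dp_expert", 1))  # axis absent -> default of 1
```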
@@ -494,8 +499,8 @@ def get_topk(self, gate_logits, pre_bias_logits, rngs=None):
     if self.config.use_random_routing:
       if rngs is None:
         raise ValueError("The random key cannot be None for random routing.")
-      # Reuse the 'dropout' RNG stream to ensure random routing
-      rng = rngs.dropout()
+      # Reuse the 'params' RNG stream to ensure random routing
+      rng = rngs.params()
       top_k_weights, top_k_indices = random_routing(rng, gate_logits, self.num_experts_per_tok)
       return top_k_weights, top_k_indices
 
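The routing key now comes from the 'params' RNG stream rather than 'dropout'. With Flax NNX `Rngs`, each named stream is callable and yields a fresh key per call; below is a hedged sketch of that usage, where the uniform-noise `random_routing` is an illustrative stand-in, not MaxText's actual implementation:

```python
import jax
import jax.numpy as jnp
from flax import nnx

def random_routing(rng, gate_logits, num_experts_per_tok):
  # Illustrative only: score experts with uniform noise and keep the top-k.
  noise = jax.random.uniform(rng, gate_logits.shape)
  return jax.lax.top_k(noise, num_experts_per_tok)

rngs = nnx.Rngs(params=0)           # only the 'params' stream is defined here
rng = rngs.params()                 # fresh key drawn from the 'params' stream
gate_logits = jnp.zeros((2, 4, 8))  # (batch, seq, num_experts)
weights, indices = random_routing(rng, gate_logits, num_experts_per_tok=2)
print(indices.shape)                # (2, 4, 2)
```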
@@ -1002,7 +1007,7 @@ def gmm(
       # batch_size=1 while decode can have batch_size > 1.
       try:
         is_batch_sharded_by_expert = (
-            "expert"
+            self._expert_parallelism_name
             in tuple(
                 filter(
                     lambda tup: tup[0] == "activation_batch",
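This probe asks whether the logical 'activation_batch' dimension is mapped onto the expert-parallel mesh axis under the model's sharding rules, and now tests against the configured axis name. A simplified, self-contained version of that membership check (the rules list below is invented for illustration):

```python
# Logical-to-physical sharding rules, in (logical_axis, mesh_axes) form.
logical_axis_rules = [
    ("activation_batch", ("data", "expert")),
    ("activation_length", ("sequence",)),
]

def batch_sharded_by(expert_axis_name: str) -> bool:
  # True if any 'activation_batch' rule places the batch on the expert axis.
  return any(
      expert_axis_name in mesh_axes
      for logical, mesh_axes in logical_axis_rules
      if logical == "activation_batch"
  )

print(batch_sharded_by("expert"))          # True
print(batch_sharded_by("attn_dp_expert"))  # False under these example rules
```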
@@ -1094,10 +1099,9 @@ def gmm(
     )
     def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs):
       batch_size, sequence_length, _ = x.shape
-      expert_axis_name = "expert"
       num_expert_parallelism = self.get_expert_parallelism_size()
       if num_expert_parallelism > 1:
-        expert_shard_id = jax.lax.axis_index(expert_axis_name)
+        expert_shard_id = jax.lax.axis_index(self._expert_parallelism_name)
       else:
         expert_shard_id = 0
       num_expert_parallelism = self.get_expert_parallelism_size()
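Inside the shard_map'd `wrapper`, `jax.lax.axis_index` must be called with an axis name that actually exists in the mesh, which is why the local `expert_axis_name = "expert"` binding is dropped in favour of the configured name. A minimal single-host sketch of reading a shard id for a dynamically chosen axis (a one-device mesh, so the index is always 0):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

expert_axis = "expert"  # would be "attn_dp_expert" on the vllm_rpa path
mesh = Mesh(np.array(jax.devices()[:1]), axis_names=(expert_axis,))

def body(x):
  shard_id = jax.lax.axis_index(expert_axis)  # id of this expert shard
  return x + shard_id

f = shard_map(body, mesh=mesh, in_specs=P(expert_axis), out_specs=P(expert_axis))
print(f(jnp.zeros(4)))  # all zeros on a single shard (shard_id == 0)
```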
@@ -1107,7 +1111,8 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
 
       # Duplicate inputs to all expert shards.
       x, logits, pre_bias_logits = tuple(
-          jax.lax.all_gather(z, axis_name=expert_axis_name, tiled=True) for z in (x, logits, pre_bias_logits)
+          jax.lax.all_gather(z, axis_name=self._expert_parallelism_name, tiled=True)
+          for z in (x, logits, pre_bias_logits)
       )
 
       # "Route" tokens within each shard.
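The reformatted generator expression is behaviourally identical: each expert shard all-gathers the full x, logits, and pre_bias_logits. With `tiled=True` the per-shard pieces are concatenated along an existing dimension instead of stacked under a new leading shard axis; a small shape-only sketch on a one-device mesh:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("expert",))

def body(x):
  stacked = jax.lax.all_gather(x, axis_name="expert")            # new leading shard axis
  tiled = jax.lax.all_gather(x, axis_name="expert", tiled=True)  # concatenated along dim 0
  return stacked, tiled

stacked, tiled = shard_map(body, mesh=mesh, in_specs=P("expert"), out_specs=P())(jnp.ones((4, 3)))
print(stacked.shape, tiled.shape)  # (1, 4, 3) and (4, 3) on a one-device mesh
```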
@@ -1131,7 +1136,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
       )
 
       if num_expert_parallelism > 1:
-        batch_axis = "expert" if is_batch_sharded_by_expert else "data"
+        batch_axis = self._expert_parallelism_name if is_batch_sharded_by_expert else "data"
         # get group sizes for all shards
         local_expert_size = self.config.num_experts // num_expert_parallelism
         reshaped_group_sizes = jnp.sum(group_sizes.reshape(-1, local_expert_size), axis=1)
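Only the batch axis name changes here; the group-size bookkeeping is the same. Since each shard owns a contiguous block of `local_expert_size` experts, reshaping the global per-expert token counts to (num_shards, local_expert_size) and summing along the second axis yields how many tokens each expert shard must handle. A small numeric example with made-up counts:

```python
import jax.numpy as jnp

num_experts, num_expert_parallelism = 8, 4
local_expert_size = num_experts // num_expert_parallelism  # 2 experts per shard

# Tokens routed to each of the 8 experts (illustrative numbers).
group_sizes = jnp.array([3, 1, 0, 5, 2, 2, 4, 1])

# Per-shard totals: shard i owns experts [2i, 2i + 1].
reshaped_group_sizes = jnp.sum(group_sizes.reshape(-1, local_expert_size), axis=1)
print(reshaped_group_sizes)  # [4 5 4 5]
```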
@@ -1163,9 +1168,9 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
             send_sizes,
             output_offsets,
             recv_sizes,
-            axis_name=expert_axis_name,
+            axis_name=self._expert_parallelism_name,
         )
-        global_group_sizes = jax.lax.all_gather(group_sizes, axis_name=expert_axis_name)
+        global_group_sizes = jax.lax.all_gather(group_sizes, axis_name=self._expert_parallelism_name)
         x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute(
             x,
             global_group_sizes,
@@ -1310,7 +1315,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
 
       # Sum up the partial outputs across the expert shards.
       output = jnp.reshape(output, (-1, sequence_length, self.config.emb_dim))
-      output = jax.lax.psum_scatter(output, expert_axis_name, scatter_dimension=0, tiled=True)
+      output = jax.lax.psum_scatter(output, self._expert_parallelism_name, scatter_dimension=0, tiled=True)
 
     else:
       if num_expert_parallelism > 1:
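When the batch is sharded by the expert axis, the partial outputs from every expert shard are combined with `psum_scatter`, i.e. a reduce-scatter: all shards contribute to the sum and each keeps only its slice along dimension 0. A tiny sketch of that collective on a one-device mesh (so it degenerates to a pass-through, but the shapes show the intent):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

axis = "expert"
mesh = Mesh(np.array(jax.devices()[:1]), axis_names=(axis,))

def body(partial_output):
  # Reduce-scatter: sum the partials over the expert axis, then keep this
  # shard's slice of dimension 0 (tiled=True keeps the existing dimension).
  return jax.lax.psum_scatter(partial_output, axis, scatter_dimension=0, tiled=True)

out = shard_map(body, mesh=mesh, in_specs=P(), out_specs=P(axis))(jnp.ones((4, 6)))
print(out.shape)  # (4, 6) here; in general each shard keeps 4 // num_shards rows
```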
@@ -1343,7 +1348,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
               send_sizes,
               output_offsets,
               recv_sizes,
-              axis_name=expert_axis_name,
+              axis_name=self._expert_parallelism_name,
           )
         else:
           # If batch is replicated across EP shards then each shard should send
@@ -1363,7 +1368,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
               send_sizes,
               output_offsets,
               recv_sizes,
-              axis_name=expert_axis_name,
+              axis_name=self._expert_parallelism_name,
           )
 
       output = self.unpermute(