@@ -58,15 +58,9 @@ def fetch_weights(params, dtype):
             params["DeepSeekMoeBlock_0"]["MoeBlock_0"]["wo"],
         ),
         (
-            params["DeepSeekMoeBlock_0"]["shared_experts"]["wi_0"][
-                "kernel"
-            ],
-            params["DeepSeekMoeBlock_0"]["shared_experts"]["wi_1"][
-                "kernel"
-            ],
-            params["DeepSeekMoeBlock_0"]["shared_experts"]["wo"][
-                "kernel"
-            ],
+            params["DeepSeekMoeBlock_0"]["shared_experts"]["wi_0"]["kernel"],
+            params["DeepSeekMoeBlock_0"]["shared_experts"]["wi_1"]["kernel"],
+            params["DeepSeekMoeBlock_0"]["shared_experts"]["wo"]["kernel"],
         ),
     ),
 ),
@@ -201,11 +195,11 @@ def batch_split_schedule(
 
 
 def staggered_call(fn, xs):
-    for i in range(len(xs)):
+    for i, x in enumerate(xs):
         if i == len(xs) - 1:
-            xs[i] = fn(xs[i])
+            xs[i] = fn(x)
         else:
-            xs[i], xs[i + 1] = jax.lax.optimization_barrier((fn(xs[i]), xs[i + 1]))
+            xs[i], xs[i + 1] = jax.lax.optimization_barrier((fn(x), xs[i + 1]))
     return xs
 
 
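A minimal, self-contained sketch of how the refactored `staggered_call` is used; the toy microbatches below are invented for illustration. The `jax.lax.optimization_barrier` makes each remaining input data-dependent on the previous result, so XLA cannot hoist or interleave the per-microbatch calls:

```python
import jax
import jax.numpy as jnp

def staggered_call(fn, xs):
    for i, x in enumerate(xs):
        if i == len(xs) - 1:
            xs[i] = fn(x)
        else:
            # Tie xs[i + 1] to fn(x)'s completion, forcing sequential issue.
            xs[i], xs[i + 1] = jax.lax.optimization_barrier((fn(x), xs[i + 1]))
    return xs

# Three toy "microbatches"; in the real model each element is a
# (tokens, segment_ids, positions) tuple fed to the decoder fn.
xs = [jnp.full((4,), float(i)) for i in range(3)]
out = jax.jit(lambda ys: staggered_call(jnp.sin, ys))(xs)
```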
@@ -215,9 +209,7 @@ def with_data_parallel_constraint(x, mesh):
         None,
         None,
     )
-    return jax.lax.with_sharding_constraint(
-        x, jax.NamedSharding(mesh, activation_pspec)
-    )
+    return jax.lax.with_sharding_constraint(x, jax.NamedSharding(mesh, activation_pspec))
 
 
 def dot(x, y, axes=1):
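For context, this is the standard pattern for pinning activations to a mesh axis. A runnable sketch; the mesh layout and shapes here are assumptions, not taken from this change:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

# Assumed 1-D mesh over all visible devices; the model's mesh may differ.
mesh = Mesh(np.array(jax.devices()), ("data",))
activation_pspec = P("data", None, None)  # shard batch, replicate seq/embed

@jax.jit
def constrain(x):
    # Same call shape as with_data_parallel_constraint above.
    return jax.lax.with_sharding_constraint(x, jax.NamedSharding(mesh, activation_pspec))

x = jnp.ones((jax.device_count() * 2, 8, 16))
print(constrain(x).sharding)  # NamedSharding over ('data', None, None)
```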
@@ -290,9 +282,7 @@ def fn(args):
             dtype=dtype,
         )
 
-    return staggered_call(
-        fn, list(zip(inputs, decoder_segment_ids, decoder_positions))
-    )
+    return staggered_call(fn, list(zip(inputs, decoder_segment_ids, decoder_positions)))
 
 
 def mla(
@@ -484,9 +474,7 @@ def kv_projection(
     )
 
 
-def get_key_value(
-    low_rank_main, key_rope, wkv_b_weights, *, qk_nope_head_dim, num_query_heads
-):
+def get_key_value(low_rank_main, key_rope, wkv_b_weights, *, qk_nope_head_dim, num_query_heads):
     """Gets key and value from compressed KV latent vector and key rope."""
     kv_out = dot(low_rank_main, wkv_b_weights)
 
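The rest of `get_key_value` falls outside this hunk. For readers unfamiliar with MLA, a sketch of the usual next steps, under the assumption that `wkv_b` projects to `num_query_heads * (qk_nope_head_dim + v_head_dim)` features and that `key_rope` is shared across heads; the shapes and helper name are illustrative, not this code's:

```python
import jax.numpy as jnp

def get_key_value_sketch(kv_out, key_rope, *, qk_nope_head_dim, num_query_heads):
    # kv_out: [B, S, H * (qk_nope_head_dim + v_head_dim)] from the wkv_b matmul.
    kv_out = kv_out.reshape(*kv_out.shape[:-1], num_query_heads, -1)
    # Split each head into the non-rotary key part and the value.
    key_nope, value = jnp.split(kv_out, [qk_nope_head_dim], axis=-1)
    # key_rope: [B, S, 1, rope_dim]; broadcast across heads and attach.
    key = jnp.concatenate(
        [key_nope, jnp.broadcast_to(key_rope, (*key_nope.shape[:-1], key_rope.shape[-1]))],
        axis=-1,
    )
    return key, value

k, v = get_key_value_sketch(
    jnp.ones((2, 5, 4 * (16 + 32))), jnp.ones((2, 5, 1, 8)),
    qk_nope_head_dim=16, num_query_heads=4,
)
print(k.shape, v.shape)  # (2, 5, 4, 24) (2, 5, 4, 32)
```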
@@ -541,20 +529,13 @@ def yarn(
     half_dim = embedding_dims // 2
     # Compute base frequencies for each (even-indexed) dimension.
     # (Note: We use jnp.arange with float32 for precision.)
-    freqs = 1.0 / (
-        rope_theta
-        ** (2.0 * jnp.arange(0, half_dim, dtype=jnp.float32) / embedding_dims)
-    )
+    freqs = 1.0 / (rope_theta ** (2.0 * jnp.arange(0, half_dim, dtype=jnp.float32) / embedding_dims))
 
     low = (
-        embedding_dims
-        * math.log(original_max_position_embeddings / (beta_fast * 2 * math.pi))
-        / (2 * math.log(rope_theta))
+        embedding_dims * math.log(original_max_position_embeddings / (beta_fast * 2 * math.pi)) / (2 * math.log(rope_theta))
     )
     high = (
-        embedding_dims
-        * math.log(original_max_position_embeddings / (beta_slow * 2 * math.pi))
-        / (2 * math.log(rope_theta))
+        embedding_dims * math.log(original_max_position_embeddings / (beta_slow * 2 * math.pi)) / (2 * math.log(rope_theta))
     )
     low = max(math.floor(low), 0)
     high = min(math.ceil(high), embedding_dims - 1)
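As a sanity check of the YaRN correction band computed above, with illustrative hyperparameters (not taken from this change):

```python
import math

rope_theta, embedding_dims = 10000.0, 64
original_max_position_embeddings = 4096
beta_fast, beta_slow = 32, 1

low = embedding_dims * math.log(original_max_position_embeddings / (beta_fast * 2 * math.pi)) / (2 * math.log(rope_theta))
high = embedding_dims * math.log(original_max_position_embeddings / (beta_slow * 2 * math.pi)) / (2 * math.log(rope_theta))
print(max(math.floor(low), 0), min(math.ceil(high), embedding_dims - 1))  # 10 23
```

Rotary dimensions inside the resulting [low, high] index band get a blended frequency via `smooth` below; dimensions outside the band are either left at the base frequency or fully divided by `rope_factor`.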
@@ -565,9 +546,7 @@ def yarn(
     freqs = freqs / rope_factor * (1 - smooth) + freqs * smooth
 
     # Precompute frequencies for all positions by taking the outer product.
-    t = jnp.arange(
-        max_position_embeddings, dtype=jnp.float32
-    )  # shape [max_position_embeddings]
+    t = jnp.arange(max_position_embeddings, dtype=jnp.float32)  # shape [max_position_embeddings]
     # This gives a [max_position_embeddings, half_dim] tensor with rows as time steps.
     freqs = jnp.outer(t, freqs)
 
@@ -578,9 +557,7 @@ def yarn(
     freqs = freqs[:, :, jnp.newaxis, :]  # shape: [B, S, 1, half_dim]
     freqs = jnp.repeat(freqs, 2, axis=-1)  # shape: [B, S, 1, embedding_dims]
     # inputs @ mask: [B, S, N, embedding_dims] @ [embedding_dims, embedding_dims] -> [B, S, N, embedding_dims]
-    output = inputs * jnp.cos(freqs) + jnp.matmul(
-        inputs, pairwise_swap_and_negate_mask
-    ) * jnp.sin(freqs)
+    output = inputs * jnp.cos(freqs) + jnp.matmul(inputs, pairwise_swap_and_negate_mask) * jnp.sin(freqs)
     return output.astype(fprop_dtype)
 
 
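`pairwise_swap_and_negate_mask` is defined outside this hunk. A plausible construction, shown only as an assumption to make the matmul-based rotate-pair readable:

```python
import jax.numpy as jnp

def make_pairwise_swap_and_negate_mask(d):
    # (x @ m) maps (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...),
    # i.e. the "rotate pair" operand of RoPE, expressed as a matmul.
    m = jnp.zeros((d, d), dtype=jnp.float32)
    even = jnp.arange(0, d, 2)
    m = m.at[even + 1, even].set(-1.0)  # out[2i]   = -x[2i+1]
    m = m.at[even, even + 1].set(1.0)   # out[2i+1] =  x[2i]
    return m

x = jnp.array([1.0, 2.0, 3.0, 4.0])
print(x @ make_pairwise_swap_and_negate_mask(4))  # [-2.  1. -4.  3.]
```

This pairs with the `jnp.repeat(freqs, 2, axis=-1)` above: cos and sin take the pattern [c0, c0, c1, c1, ...], matching the pairwise rotation.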
@@ -671,9 +648,7 @@ def route(
     # Communicate local results across the expert axis.
     x = jax.lax.all_gather(x, axis_name=expert_axis_name, tiled=True)
     weights = jax.lax.all_gather(weights, axis_name=expert_axis_name, tiled=True)
-    selected_experts = jax.lax.all_gather(
-        selected_experts, axis_name=expert_axis_name, tiled=True
-    )
+    selected_experts = jax.lax.all_gather(selected_experts, axis_name=expert_axis_name, tiled=True)
     group_sizes = jax.lax.psum(group_sizes, axis_name=expert_axis_name)
 
     # Sort the gathered tokens and weights.
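The `tiled=True` flag matters for the shapes here. A self-contained check of its semantics; the mesh and axis name are illustrative, and this runs on however many devices are visible:

```python
import functools
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

n_dev = jax.device_count()
mesh = Mesh(np.array(jax.devices()), ("expert",))

@functools.partial(
    shard_map, mesh=mesh, in_specs=P("expert", None), out_specs=P(None, None)
)
def gather_all(x):
    # tiled=True concatenates the per-shard [rows, d] blocks along axis 0
    # instead of stacking them under a new leading axis.
    return jax.lax.all_gather(x, axis_name="expert", tiled=True)

x = jnp.arange(n_dev * 2 * 3, dtype=jnp.float32).reshape(n_dev * 2, 3)
print(gather_all(x).shape)  # (n_dev * 2, 3), now fully replicated
```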
@@ -703,14 +678,10 @@ def unroute(
     )
 
     # Sum across expert shards.
-    return jax.lax.psum_scatter(
-        x, expert_axis_name, scatter_dimension=0, tiled=True
-    )
+    return jax.lax.psum_scatter(x, expert_axis_name, scatter_dimension=0, tiled=True)
 
 
-def compute(
-    x, w0, w1, wo, group_sizes, weights, *, wi_tile_size, wo_tile_size, dtype
-):
+def compute(x, w0, w1, wo, group_sizes, weights, *, wi_tile_size, wo_tile_size, dtype):
     """Processes routed tokens through the MLP."""
     gmm_fn = functools.partial(
         megablox.gmm,
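`megablox.gmm` is a grouped matmul: tokens arrive sorted by expert, and `group_sizes[e]` says how many contiguous rows belong to expert `e`. A plain-JAX reference of the semantics, not the kernel's API, as a sketch for review purposes:

```python
import jax.numpy as jnp

def gmm_reference(x, w, group_sizes):
    # x: [tokens, d] with expert-e rows contiguous; w: [E, d, f]; group_sizes: [E].
    expert_of_token = jnp.repeat(
        jnp.arange(w.shape[0]), group_sizes, total_repeat_length=x.shape[0]
    )
    # Each token row is multiplied by its own expert's weight matrix.
    return jnp.einsum("td,tdf->tf", x, w[expert_of_token])

x = jnp.ones((6, 4))
w = jnp.stack([jnp.eye(4), 2.0 * jnp.eye(4)])  # E=2, d=f=4
print(gmm_reference(x, w, jnp.array([4, 2]))[:, 0])  # [1. 1. 1. 1. 2. 2.]
```

The real kernel presumably tiles these per-group matmuls (hence `wi_tile_size` and `wo_tile_size`) rather than materializing `w[expert_of_token]` as this reference does.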
@@ -747,9 +718,7 @@ def route_compute_unroute(
 
     def route_fn(inputs):
         # Shared expert.
-        y = dot(
-            jax.nn.silu(dot(inputs, shared_w0)) * dot(inputs, shared_w1), shared_wo
-        )
+        y = dot(jax.nn.silu(dot(inputs, shared_w0)) * dot(inputs, shared_w1), shared_wo)
 
         inputs = jnp.reshape(inputs, (-1, inputs.shape[-1]))
         selected_experts, weights, group_sizes = expert_selection(