@@ -71,7 +71,8 @@ def fetch_weights(params, dtype):
7171@jax .named_scope ("deepseek_batchsplit_split" )
7272def split (x , split_factor = 2 ):
7373 """Splits the input into `split_factor` parts along the batch dimension."""
74-
74+ if split_factor == 1 :
75+ return [x ]
7576 if x is None :
7677 return [None ] * split_factor
7778 else :
@@ -80,8 +81,10 @@ def split(x, split_factor=2):
8081
8182
8283@jax .named_scope ("deepseek_batchsplit_merge" )
83- def merge (x ):
84+ def merge (x , split_factor = 2 ):
8485 """Merges the input microbatches back into a single tensor."""
86+ if split_factor == 1 :
87+ return x [0 ]
8588 x = jnp .stack (x , axis = 1 )
8689 return jnp .reshape (x , (- 1 ,) + x .shape [2 :])
8790
@@ -104,13 +107,13 @@ def batch_split_schedule(
104107 None ,
105108 )
106109 xs = jax .shard_map (
107- split ,
110+ functools . partial ( split , split_factor = cfg . batch_split_factor ) ,
108111 mesh = mesh ,
109112 in_specs = activation_pspec ,
110- out_specs = [activation_pspec , activation_pspec ] ,
113+ out_specs = [activation_pspec ] * cfg . batch_split_factor ,
111114 )(inputs )
112- dpos = split (positions )
113- dseg = split (segment_ids )
115+ dpos = split (positions , split_factor = cfg . batch_split_factor )
116+ dseg = split (segment_ids , split_factor = cfg . batch_split_factor )
114117 xs = [with_data_parallel_constraint (x , mesh ) for x in xs ]
115118 xs = jax .ad_checkpoint .checkpoint_name (xs , "decoder_layer_input" )
116119
@@ -186,9 +189,9 @@ def batch_split_schedule(
186189 dtype = cfg .dtype ,
187190 )
188191 xs = jax .shard_map (
189- merge ,
192+ functools . partial ( merge , split_factor = cfg . batch_split_factor ) ,
190193 mesh = mesh ,
191- in_specs = ([activation_pspec , activation_pspec ] ,),
194+ in_specs = ([activation_pspec ] * cfg . batch_split_factor ,),
192195 out_specs = activation_pspec ,
193196 )(xs )
194197 return xs
0 commit comments