@@ -168,6 +168,126 @@ def merge(x, split_factor=2):
168168 return jnp .reshape (x , (- 1 ,) + x .shape [2 :])
169169
170170
171+ def gather_weights (weights , mesh ):
172+ """all-gathers FSDP sharded weights."""
173+
174+ def fn (weights ):
175+ (
176+ (pre_attn_norm , post_attn_norm ),
177+ (wq_a , wq_b , q_norm , wkv_a , wkv_b , kv_norm , out ),
178+ ), (
179+ (gate , bias ),
180+ (routed_wi_0 , routed_wi_1 , routed_wo ),
181+ (shared_wi_0 , shared_wi_1 , shared_wo ),
182+ ) = weights
183+ # All-gather across FSDP axis. Expert axis is used for FSDP in attention.
184+ wq_a = jax .lax .all_gather (wq_a , axis_name = "expert" , tiled = True , axis = 1 )
185+ wq_a = jax .lax .all_gather (wq_a , axis_name = "fsdp" , tiled = True )
186+ wq_b = jax .lax .all_gather (wq_b , axis_name = "expert" , tiled = True , axis = 1 )
187+ wq_b = jax .lax .all_gather (wq_b , axis_name = "fsdp" , tiled = True )
188+ wkv_a = jax .lax .all_gather (wkv_a , axis_name = "expert" , tiled = True , axis = 1 )
189+ wkv_a = jax .lax .all_gather (wkv_a , axis_name = "fsdp" , tiled = True )
190+ wkv_b = jax .lax .all_gather (wkv_b , axis_name = "expert" , tiled = True , axis = 1 )
191+ wkv_b = jax .lax .all_gather (wkv_b , axis_name = "fsdp" , tiled = True )
192+ out = jax .lax .all_gather (out , axis_name = "expert" , tiled = True )
193+ out = jax .lax .all_gather (out , axis_name = "fsdp" , tiled = True , axis = 2 )
194+ gate = jax .lax .all_gather (gate , axis_name = "fsdp" , tiled = True )
195+ routed_wi_0 = jax .lax .all_gather (routed_wi_0 , axis_name = "fsdp" , tiled = True )
196+ routed_wi_1 = jax .lax .all_gather (routed_wi_1 , axis_name = "fsdp" , tiled = True )
197+ routed_wo = jax .lax .all_gather (routed_wo , axis_name = "fsdp" , tiled = True )
198+ shared_wi_0 = jax .lax .all_gather (shared_wi_0 , axis_name = "expert" , tiled = True , axis = 1 )
199+ shared_wi_0 = jax .lax .all_gather (shared_wi_0 , axis_name = "fsdp" , tiled = True )
200+ shared_wi_1 = jax .lax .all_gather (shared_wi_1 , axis_name = "expert" , tiled = True , axis = 1 )
201+ shared_wi_1 = jax .lax .all_gather (shared_wi_1 , axis_name = "fsdp" , tiled = True )
202+ shared_wo = jax .lax .all_gather (shared_wo , axis_name = "expert" , tiled = True )
203+ shared_wo = jax .lax .all_gather (shared_wo , axis_name = "fsdp" , tiled = True , axis = 1 )
204+ return (
205+ (
206+ (pre_attn_norm , post_attn_norm ),
207+ (wq_a , wq_b , q_norm , wkv_a , wkv_b , kv_norm , out ),
208+ ),
209+ (
210+ (gate , bias ),
211+ (routed_wi_0 , routed_wi_1 , routed_wo ),
212+ (shared_wi_0 , shared_wi_1 , shared_wo ),
213+ ),
214+ )
215+
216+ return jax .shard_map (
217+ fn ,
218+ mesh = mesh ,
219+ in_specs = (
220+ (
221+ (
222+ (
223+ jax .sharding .PartitionSpec (None ),
224+ jax .sharding .PartitionSpec (None ),
225+ ),
226+ (
227+ jax .sharding .PartitionSpec ("fsdp" , "expert" ),
228+ jax .sharding .PartitionSpec ("fsdp" , "expert" , None ),
229+ jax .sharding .PartitionSpec (None ),
230+ jax .sharding .PartitionSpec ("fsdp" , "expert" ),
231+ jax .sharding .PartitionSpec ("fsdp" , "expert" , None ),
232+ jax .sharding .PartitionSpec (None ),
233+ jax .sharding .PartitionSpec ("expert" , None , "fsdp" ),
234+ ),
235+ ),
236+ (
237+ (
238+ jax .sharding .PartitionSpec ("fsdp" , None ),
239+ jax .sharding .PartitionSpec (None ),
240+ ),
241+ (
242+ jax .sharding .PartitionSpec ("fsdp" , None , "expert" ),
243+ jax .sharding .PartitionSpec ("fsdp" , None , "expert" ),
244+ jax .sharding .PartitionSpec ("fsdp" , "expert" , None ),
245+ ),
246+ (
247+ jax .sharding .PartitionSpec ("fsdp" , "expert" ),
248+ jax .sharding .PartitionSpec ("fsdp" , "expert" ),
249+ jax .sharding .PartitionSpec ("expert" , "fsdp" ),
250+ ),
251+ ),
252+ ),
253+ ),
254+ out_specs = (
255+ (
256+ (
257+ jax .sharding .PartitionSpec (None ),
258+ jax .sharding .PartitionSpec (None ),
259+ ),
260+ (
261+ jax .sharding .PartitionSpec (None , None ),
262+ jax .sharding .PartitionSpec (None , None , None ),
263+ jax .sharding .PartitionSpec (None ),
264+ jax .sharding .PartitionSpec (None , None ),
265+ jax .sharding .PartitionSpec (None , None , None ),
266+ jax .sharding .PartitionSpec (None ),
267+ jax .sharding .PartitionSpec (None , None , None ),
268+ ),
269+ ),
270+ (
271+ (
272+ jax .sharding .PartitionSpec (None , None ),
273+ jax .sharding .PartitionSpec (None ),
274+ ),
275+ (
276+ jax .sharding .PartitionSpec (None , None , "expert" ),
277+ jax .sharding .PartitionSpec (None , None , "expert" ),
278+ jax .sharding .PartitionSpec (None , "expert" , None ),
279+ ),
280+ (
281+ jax .sharding .PartitionSpec (None , None ),
282+ jax .sharding .PartitionSpec (None , None ),
283+ jax .sharding .PartitionSpec (None , None ),
284+ ),
285+ ),
286+ ),
287+ check_vma = False ,
288+ )(weights )
289+
290+
171291def scan_batch_split_layers (
172292 inputs ,
173293 params ,
@@ -183,6 +303,7 @@ def scan_batch_split_layers(
183303 """Scans the layers with batch-split schedule."""
184304
185305 def batch_split_scan_fn (inputs , weights , dpos , dseg ):
306+ weights = gather_weights (weights , mesh )
186307 xs = batch_split_schedule (
187308 inputs ,
188309 weights ,
0 commit comments