Skip to content

Commit c6b84c1

Browse files
Make weight all-gathers explicit for DSv3 batch-split
PiperOrigin-RevId: 885217977
1 parent 6e47e57 commit c6b84c1

1 file changed

Lines changed: 121 additions & 0 deletions

File tree

src/maxtext/models/deepseek_batchsplit.py

Lines changed: 121 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -168,6 +168,126 @@ def merge(x, split_factor=2):
168168
return jnp.reshape(x, (-1,) + x.shape[2:])
169169

170170

171+
def gather_weights(weights, mesh):
172+
"""all-gathers FSDP sharded weights."""
173+
174+
def fn(weights):
175+
(
176+
(pre_attn_norm, post_attn_norm),
177+
(wq_a, wq_b, q_norm, wkv_a, wkv_b, kv_norm, out),
178+
), (
179+
(gate, bias),
180+
(routed_wi_0, routed_wi_1, routed_wo),
181+
(shared_wi_0, shared_wi_1, shared_wo),
182+
) = weights
183+
# All-gather across FSDP axis. Expert axis is used for FSDP in attention.
184+
wq_a = jax.lax.all_gather(wq_a, axis_name="expert", tiled=True, axis=1)
185+
wq_a = jax.lax.all_gather(wq_a, axis_name="fsdp", tiled=True)
186+
wq_b = jax.lax.all_gather(wq_b, axis_name="expert", tiled=True, axis=1)
187+
wq_b = jax.lax.all_gather(wq_b, axis_name="fsdp", tiled=True)
188+
wkv_a = jax.lax.all_gather(wkv_a, axis_name="expert", tiled=True, axis=1)
189+
wkv_a = jax.lax.all_gather(wkv_a, axis_name="fsdp", tiled=True)
190+
wkv_b = jax.lax.all_gather(wkv_b, axis_name="expert", tiled=True, axis=1)
191+
wkv_b = jax.lax.all_gather(wkv_b, axis_name="fsdp", tiled=True)
192+
out = jax.lax.all_gather(out, axis_name="expert", tiled=True)
193+
out = jax.lax.all_gather(out, axis_name="fsdp", tiled=True, axis=2)
194+
gate = jax.lax.all_gather(gate, axis_name="fsdp", tiled=True)
195+
routed_wi_0 = jax.lax.all_gather(routed_wi_0, axis_name="fsdp", tiled=True)
196+
routed_wi_1 = jax.lax.all_gather(routed_wi_1, axis_name="fsdp", tiled=True)
197+
routed_wo = jax.lax.all_gather(routed_wo, axis_name="fsdp", tiled=True)
198+
shared_wi_0 = jax.lax.all_gather(shared_wi_0, axis_name="expert", tiled=True, axis=1)
199+
shared_wi_0 = jax.lax.all_gather(shared_wi_0, axis_name="fsdp", tiled=True)
200+
shared_wi_1 = jax.lax.all_gather(shared_wi_1, axis_name="expert", tiled=True, axis=1)
201+
shared_wi_1 = jax.lax.all_gather(shared_wi_1, axis_name="fsdp", tiled=True)
202+
shared_wo = jax.lax.all_gather(shared_wo, axis_name="expert", tiled=True)
203+
shared_wo = jax.lax.all_gather(shared_wo, axis_name="fsdp", tiled=True, axis=1)
204+
return (
205+
(
206+
(pre_attn_norm, post_attn_norm),
207+
(wq_a, wq_b, q_norm, wkv_a, wkv_b, kv_norm, out),
208+
),
209+
(
210+
(gate, bias),
211+
(routed_wi_0, routed_wi_1, routed_wo),
212+
(shared_wi_0, shared_wi_1, shared_wo),
213+
),
214+
)
215+
216+
return jax.shard_map(
217+
fn,
218+
mesh=mesh,
219+
in_specs=(
220+
(
221+
(
222+
(
223+
jax.sharding.PartitionSpec(None),
224+
jax.sharding.PartitionSpec(None),
225+
),
226+
(
227+
jax.sharding.PartitionSpec("fsdp", "expert"),
228+
jax.sharding.PartitionSpec("fsdp", "expert", None),
229+
jax.sharding.PartitionSpec(None),
230+
jax.sharding.PartitionSpec("fsdp", "expert"),
231+
jax.sharding.PartitionSpec("fsdp", "expert", None),
232+
jax.sharding.PartitionSpec(None),
233+
jax.sharding.PartitionSpec("expert", None, "fsdp"),
234+
),
235+
),
236+
(
237+
(
238+
jax.sharding.PartitionSpec("fsdp", None),
239+
jax.sharding.PartitionSpec(None),
240+
),
241+
(
242+
jax.sharding.PartitionSpec("fsdp", None, "expert"),
243+
jax.sharding.PartitionSpec("fsdp", None, "expert"),
244+
jax.sharding.PartitionSpec("fsdp", "expert", None),
245+
),
246+
(
247+
jax.sharding.PartitionSpec("fsdp", "expert"),
248+
jax.sharding.PartitionSpec("fsdp", "expert"),
249+
jax.sharding.PartitionSpec("expert", "fsdp"),
250+
),
251+
),
252+
),
253+
),
254+
out_specs=(
255+
(
256+
(
257+
jax.sharding.PartitionSpec(None),
258+
jax.sharding.PartitionSpec(None),
259+
),
260+
(
261+
jax.sharding.PartitionSpec(None, None),
262+
jax.sharding.PartitionSpec(None, None, None),
263+
jax.sharding.PartitionSpec(None),
264+
jax.sharding.PartitionSpec(None, None),
265+
jax.sharding.PartitionSpec(None, None, None),
266+
jax.sharding.PartitionSpec(None),
267+
jax.sharding.PartitionSpec(None, None, None),
268+
),
269+
),
270+
(
271+
(
272+
jax.sharding.PartitionSpec(None, None),
273+
jax.sharding.PartitionSpec(None),
274+
),
275+
(
276+
jax.sharding.PartitionSpec(None, None, "expert"),
277+
jax.sharding.PartitionSpec(None, None, "expert"),
278+
jax.sharding.PartitionSpec(None, "expert", None),
279+
),
280+
(
281+
jax.sharding.PartitionSpec(None, None),
282+
jax.sharding.PartitionSpec(None, None),
283+
jax.sharding.PartitionSpec(None, None),
284+
),
285+
),
286+
),
287+
check_vma=False,
288+
)(weights)
289+
290+
171291
def scan_batch_split_layers(
172292
inputs,
173293
params,
@@ -183,6 +303,7 @@ def scan_batch_split_layers(
183303
"""Scans the layers with batch-split schedule."""
184304

185305
def batch_split_scan_fn(inputs, weights, dpos, dseg):
306+
weights = gather_weights(weights, mesh)
186307
xs = batch_split_schedule(
187308
inputs,
188309
weights,

0 commit comments

Comments
 (0)