
Commit 115fffa
Added sharding on ROPE
Parent: 0a7d593

1 file changed: 9 additions, 0 deletions
src/maxdiffusion/models/attention_flax.py
@@ -1083,9 +1083,18 @@ def __call__(
 
     if rotary_emb is not None:
       with self.conditional_named_scope("attn_rope"):
+        axis_names_rope = nn.logical_to_mesh_axes((None, None, LENGTH, None))
+        rotary_emb = jax.lax.with_sharding_constraint(rotary_emb, axis_names_rope)
         query_proj = _unflatten_heads(query_proj, self.heads)
         key_proj = _unflatten_heads(key_proj, self.heads)
         value_proj = _unflatten_heads(value_proj, self.heads)
+
+        # Enforce sequence parallelism on the new axis 2 (LENGTH) before doing the ROPE math
+        axis_names_qkv = nn.logical_to_mesh_axes((BATCH, HEAD, LENGTH, D_KV))
+        query_proj = jax.lax.with_sharding_constraint(query_proj, axis_names_qkv)
+        key_proj = jax.lax.with_sharding_constraint(key_proj, axis_names_qkv)
+        value_proj = jax.lax.with_sharding_constraint(value_proj, axis_names_qkv)
+
         # output of _unflatten_heads Batch, heads, seq_len, head_dim
         query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
 
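For readers of the diff: the pattern added here first translates logical axis names into a device-mesh PartitionSpec with nn.logical_to_mesh_axes, then pins the rotary embedding and the Q/K/V projections to that layout with jax.lax.with_sharding_constraint, so the RoPE computation runs with the sequence (LENGTH) dimension split across devices. The sketch below is a minimal, self-contained illustration of that pattern, not maxdiffusion code: the mesh shape, the explicit rules tuple, and the constrain/rope_block helpers are assumptions made for the example (the real module resolves its logical axis rules from the surrounding configuration and uses maxdiffusion's own BATCH/HEAD/LENGTH/D_KV constants).

import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn
from jax.sharding import Mesh, NamedSharding

# Hypothetical 1x2 device mesh with a "sequence" axis for sequence parallelism
# (assumes at least two local devices are available).
devices = np.array(jax.devices()[:2]).reshape(1, 2)
mesh = Mesh(devices, axis_names=("data", "sequence"))

# Illustrative logical axis names and rules; axes without a rule stay unsharded.
BATCH, HEAD, LENGTH, D_KV = "batch", "heads", "length", "kv"
rules = ((BATCH, "data"), (LENGTH, "sequence"))

def constrain(x, logical_axes):
  # logical_to_mesh_axes maps the logical names to a PartitionSpec via the rules;
  # with_sharding_constraint then forces that layout inside the jitted computation.
  spec = nn.logical_to_mesh_axes(logical_axes, rules)
  return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))

@jax.jit
def rope_block(query_proj, key_proj, rotary_emb):
  # Shard the rotary embedding along its sequence dimension (axis 2) only.
  rotary_emb = constrain(rotary_emb, (None, None, LENGTH, None))
  # Enforce sequence parallelism on Q/K laid out as (batch, heads, seq_len, head_dim).
  query_proj = constrain(query_proj, (BATCH, HEAD, LENGTH, D_KV))
  key_proj = constrain(key_proj, (BATCH, HEAD, LENGTH, D_KV))
  # ...the actual RoPE rotation (self._apply_rope in the real module) would follow here.
  return query_proj, key_proj

# Example call: the 128-long sequence axis is split across the two "sequence" devices.
q = jnp.ones((2, 8, 128, 64))
k = jnp.ones((2, 8, 128, 64))
rope = jnp.ones((1, 1, 128, 64))
q_sharded, k_sharded = rope_block(q, k, rope)

The constraints are layout hints to the XLA partitioner rather than explicit data movement; placing them before the RoPE math, as the commit comment notes, keeps the rotation local to each sequence shard instead of leaving the compiler free to pick a different layout.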