
Commit 91f29c4

Shard q, k, v across context via a fix in attention_ltx2.py
1 parent bb65834 commit 91f29c4

1 file changed: 19 additions & 2 deletions

File tree

src/maxdiffusion/models/ltx2/attention_ltx2.py
@@ -16,6 +16,7 @@
 
 from typing import Optional, Tuple
 from flax import nnx
+from flax import linen as nn
 import jax
 import jax.numpy as jnp
 from ... import common_types
@@ -485,8 +486,24 @@ def __call__(
       key = apply_rotary_emb(key, rotary_emb)
 
     with jax.named_scope("Attention and Output Project"):
-      # 4. Attention
-      # NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel
+      # Reshape to 4D [B, H, S, D] before passing to avoid All-Gather during transpose
+      b, s, _ = query.shape
+      h = self.heads
+      d = self.dim_head
+
+      query = query.reshape(b, s, h, d).transpose(0, 2, 1, 3)
+      key = key.reshape(b, s, h, d).transpose(0, 2, 1, 3)
+      value = value.reshape(b, s, h, d).transpose(0, 2, 1, 3)
+
+      # Apply explicit sharding constraints on the 4D tensors
+      q_axis_names = nn.logical_to_mesh_axes((common_types.BATCH, common_types.SELF_ATTN_HEAD, common_types.SELF_ATTN_Q_LENGTH, common_types.D_KV))
+      kv_axis_names = nn.logical_to_mesh_axes((common_types.BATCH, common_types.SELF_ATTN_HEAD, common_types.SELF_ATTN_KV_LENGTH, common_types.D_KV))
+
+      query = jax.lax.with_sharding_constraint(query, q_axis_names)
+      key = jax.lax.with_sharding_constraint(key, kv_axis_names)
+      value = jax.lax.with_sharding_constraint(value, kv_axis_names)
+
+      # 4. Attention (passing 4D tensors now)
       attn_output = self.attention_op.apply_attention(query=query, key=key, value=value, attention_mask=attention_mask)
 
       # 7. Output Projection
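
For context, a minimal self-contained sketch of the pattern this commit applies: reshape the fused [B, S, H*D] projections to [B, H, S, D], then pin their layout with logical-to-mesh sharding constraints so the attention op receives context-sharded q/k/v instead of triggering an all-gather on the transpose. The mesh axis name ("context"), the logical axis rule names, and the shard_qkv helper below are illustrative assumptions for this sketch only; the repository's actual logical axis names come from common_types and its mesh from the maxdiffusion config.

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding
from flax import linen as nn

# Hypothetical single-axis device mesh; a real run shapes this to the topology.
mesh = Mesh(jax.devices(), axis_names=("context",))

# Illustrative logical -> mesh rules: shard only the sequence axes over "context".
rules = (
    ("activation_batch", None),
    ("activation_heads", None),
    ("activation_q_length", "context"),
    ("activation_kv_length", "context"),
    ("activation_kv_head_dim", None),
)


def shard_qkv(query, key, value, heads, dim_head):
  """Reshape [B, S, H*D] -> [B, H, S, D] and constrain the sharding of q/k/v."""
  b, s, _ = query.shape

  def to_4d(x):
    return x.reshape(b, s, heads, dim_head).transpose(0, 2, 1, 3)

  query, key, value = to_4d(query), to_4d(key), to_4d(value)

  # logical_to_mesh_axes maps logical axis names to a PartitionSpec via the rules.
  q_spec = nn.logical_to_mesh_axes(
      ("activation_batch", "activation_heads", "activation_q_length", "activation_kv_head_dim"), rules)
  kv_spec = nn.logical_to_mesh_axes(
      ("activation_batch", "activation_heads", "activation_kv_length", "activation_kv_head_dim"), rules)

  query = jax.lax.with_sharding_constraint(query, NamedSharding(mesh, q_spec))
  key = jax.lax.with_sharding_constraint(key, NamedSharding(mesh, kv_spec))
  value = jax.lax.with_sharding_constraint(value, NamedSharding(mesh, kv_spec))
  return query, key, value


# Usage: called from inside a jitted forward pass, as the attention block would be.
@jax.jit
def forward(q, k, v):
  return shard_qkv(q, k, v, heads=8, dim_head=64)

q = k = v = jnp.zeros((2, 256, 8 * 64))
q4, k4, v4 = forward(q, k, v)  # each is [2, 8, 256, 64], sequence sharded over "context"

Keeping q/k/v in [B, H, S, D] with an explicit constraint lets XLA hold the sequence dimension sharded across the context axis rather than gathering it just to perform the transpose.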
