Commit 50160c6

committed
Fix sharding of k, v, q across context in attention_ltx2.py
1 parent 91f29c4 commit 50160c6

1 file changed


File tree

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 5 additions & 4 deletions
@@ -487,13 +487,14 @@ def __call__(
 
     with jax.named_scope("Attention and Output Project"):
       # Reshape to 4D [B, H, S, D] before passing to avoid All-Gather during transpose
-      b, s, _ = query.shape
+      b, s_q, _ = query.shape
+      _, s_kv, _ = key.shape
       h = self.heads
       d = self.dim_head
 
-      query = query.reshape(b, s, h, d).transpose(0, 2, 1, 3)
-      key = key.reshape(b, s, h, d).transpose(0, 2, 1, 3)
-      value = value.reshape(b, s, h, d).transpose(0, 2, 1, 3)
+      query = query.reshape(b, s_q, h, d).transpose(0, 2, 1, 3)
+      key = key.reshape(b, s_kv, h, d).transpose(0, 2, 1, 3)
+      value = value.reshape(b, s_kv, h, d).transpose(0, 2, 1, 3)
 
       # Apply explicit sharding constraints on the 4D tensors
       q_axis_names = nn.logical_to_mesh_axes((common_types.BATCH, common_types.SELF_ATTN_HEAD, common_types.SELF_ATTN_Q_LENGTH, common_types.D_KV))
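
The substance of the change: the old code read a single sequence length s from query.shape and reused it for key and value, which breaks whenever the key/value length differs from the query length (cross-attention, or q and k/v sharded differently along the context axis). Below is a minimal, runnable sketch of the fixed reshape; split_heads and the example shapes are illustrative, not names from the repository.

import jax.numpy as jnp

def split_heads(query, key, value, heads, dim_head):
  """Reshape [B, S, H*D] projections to [B, H, S, D].

  The query and key/value sequence lengths are read separately, so the
  reshape stays correct when s_q != s_kv.
  """
  b, s_q, _ = query.shape
  _, s_kv, _ = key.shape  # the fix: do not assume key shares query's length

  def to_bhsd(x, s):
    return x.reshape(b, s, heads, dim_head).transpose(0, 2, 1, 3)

  return to_bhsd(query, s_q), to_bhsd(key, s_kv), to_bhsd(value, s_kv)

# Cross-attention-style shapes: 128 query tokens attending to 77 k/v tokens.
q = jnp.zeros((2, 128, 8 * 64))
k = jnp.zeros((2, 77, 8 * 64))
v = jnp.zeros((2, 77, 8 * 64))
q4, k4, v4 = split_heads(q, k, v, heads=8, dim_head=64)
print(q4.shape, k4.shape)  # (2, 8, 128, 64) (2, 8, 77, 64)

The last line of the diff then resolves logical axis names (common_types.BATCH, SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH, D_KV) to mesh axes before the sharding constraint is applied. Here is a hedged sketch of that resolution step using Flax's logical-axis API; the rule table and mesh axis names are made up for illustration and are not MaxDiffusion's actual configuration.

import flax.linen as nn

# Hypothetical rules mapping logical names to mesh axes: batch is
# data-parallel, the query length is context-sharded, heads and head_dim
# stay replicated.
rules = (
    ("activation_batch", "data"),
    ("activation_heads", None),
    ("activation_q_length", "context"),
    ("activation_kv", None),
)
with nn.logical_axis_rules(rules):
  spec = nn.logical_to_mesh_axes(
      ("activation_batch", "activation_heads", "activation_q_length", "activation_kv"))
print(spec)  # PartitionSpec('data', None, 'context', None)

Tracking s_q and s_kv separately is what makes a dedicated SELF_ATTN_Q_LENGTH axis usable here: the query's sequence axis can be named for context sharding without forcing k and v onto a length they may not have.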
