1 parent 91f29c4 · commit 50160c6
1 file changed
src/maxdiffusion/models/ltx2/attention_ltx2.py
@@ -487,13 +487,14 @@ def __call__(
     with jax.named_scope("Attention and Output Project"):
       # Reshape to 4D [B, H, S, D] before passing to avoid All-Gather during transpose
-      b, s, _ = query.shape
+      b, s_q, _ = query.shape
+      _, s_kv, _ = key.shape
       h = self.heads
       d = self.dim_head
-      query = query.reshape(b, s, h, d).transpose(0, 2, 1, 3)
-      key = key.reshape(b, s, h, d).transpose(0, 2, 1, 3)
-      value = value.reshape(b, s, h, d).transpose(0, 2, 1, 3)
+      query = query.reshape(b, s_q, h, d).transpose(0, 2, 1, 3)
+      key = key.reshape(b, s_kv, h, d).transpose(0, 2, 1, 3)
+      value = value.reshape(b, s_kv, h, d).transpose(0, 2, 1, 3)

       # Apply explicit sharding constraints on the 4D tensors
       q_axis_names = nn.logical_to_mesh_axes((common_types.BATCH, common_types.SELF_ATTN_HEAD, common_types.SELF_ATTN_Q_LENGTH, common_types.D_KV))
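For context on why the change tracks two sequence lengths, here is a minimal sketch, not taken from the repository: when the key/value stream comes from a different source than the query (cross-attention), its sequence length differs from the query's, so reshaping `key` and `value` with the query length would fail or silently mis-shape the tensors. The shapes and the softmax/einsum attention below are illustrative assumptions, not the attention kernel used in maxdiffusion.

```python
# Minimal sketch, assuming only jax / jax.numpy; all sizes are hypothetical.
import jax
import jax.numpy as jnp

b, s_q, s_kv, h, d = 2, 256, 77, 8, 64  # query and key/value sequence lengths differ

query = jnp.zeros((b, s_q, h * d))   # [B, S_q, H*D] from the latent stream
key = jnp.zeros((b, s_kv, h * d))    # [B, S_kv, H*D] from the conditioning stream
value = jnp.zeros((b, s_kv, h * d))

# Reshape each stream with its own sequence length, then move heads to axis 1
# so the layout is [B, H, S, D], matching the diff above.
query = query.reshape(b, s_q, h, d).transpose(0, 2, 1, 3)
key = key.reshape(b, s_kv, h, d).transpose(0, 2, 1, 3)
value = value.reshape(b, s_kv, h, d).transpose(0, 2, 1, 3)

# Attention logits are [B, H, S_q, S_kv]; reusing s_q for key/value would
# raise a reshape error whenever s_q != s_kv, i.e. in cross-attention.
logits = jnp.einsum("bhqd,bhkd->bhqk", query, key) / jnp.sqrt(d)
out = jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(logits, axis=-1), value)
print(out.shape)  # (2, 8, 256, 64)
```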