
Commit 2f447f1

test
1 parent de14eec commit 2f447f1

1 file changed: 13 additions & 42 deletions


src/maxdiffusion/models/ltx2/attention_ltx2.py

@@ -31,8 +31,8 @@ def apply_rotary_emb(x: Array, freqs: Tuple[Array, Array]) -> Array:
   Logic matches LTX-2 PyTorch: pairs neighbors [-x2, x1].
 
   Args:
-    x: Input tensor [B, S, H, D]
-    freqs: Tuple of (cos, sin), broadcasting to [B, S, 1, D] or [B, S, H, D]
+    x: Input tensor [..., D]
+    freqs: Tuple of (cos, sin), broadcasting to [..., D]
   """
   cos, sin = freqs
 
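As an aside on this hunk: "pairs neighbors [-x2, x1]" refers to the neighbor-pairing form of RoPE, where the rotation acts on adjacent elements of the last dimension. A minimal JAX sketch of that convention, assuming cos/sin are already repeated pairwise so they broadcast against [..., D] (names are illustrative, not the repo's exact code):

import jax.numpy as jnp

def rotate_pairs(x):
  # Pair neighbors along the last dim and map (x1, x2) -> (-x2, x1).
  pairs = x.reshape(*x.shape[:-1], -1, 2)                 # [..., D/2, 2]
  rotated = jnp.stack([-pairs[..., 1], pairs[..., 0]], axis=-1)
  return rotated.reshape(x.shape)                         # back to [..., D]

def apply_rotary_emb_sketch(x, freqs):
  # x: [..., D]; freqs = (cos, sin), each broadcastable to [..., D].
  cos, sin = freqs
  return x * cos + rotate_pairs(x) * sin

Because the rotation is elementwise apart from the neighbor swap, it does not care whether the trailing dimension is a per-head D or the flattened H*D, which is what lets the docstring relax the shapes to [..., D].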

@@ -71,7 +71,7 @@ def __call__(self, ids: Array) -> Tuple[Array, Array]:
     - For Video 3D: Num_Axes=3 (Time, Height, Width)
     - For Audio 1D: Num_Axes=1 (Time)
     Returns:
-      cos, sin: [B, S, 1, Dim] (Ready for broadcasting across heads)
+      cos, sin: [B, S, Dim]
     """
     num_axes = ids.shape[-1]
     dim_per_axis = self.dim // num_axes
@@ -97,8 +97,7 @@ def __call__(self, ids: Array) -> Tuple[Array, Array]:
     cos = jnp.repeat(cos, 2, axis=-1)
     sin = jnp.repeat(sin, 2, axis=-1)
 
-    # Add head dim for broadcasting: [B, S, 1, Inner_Dim]
-    return cos[:, :, None, :], sin[:, :, None, :]
+    return cos, sin
 
 
 class LTX2Attention(nnx.Module):
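For context on the two LTX2RotaryPosEmbed hunks above: each position id carries one coordinate per axis, the rotary dimension is split evenly across the axes, and the per-axis angles are concatenated and repeated pairwise before being returned, now directly as [B, S, Dim]. A rough sketch of that shape flow (the theta base and exact frequency spacing are assumptions, not taken from the repo):

import jax.numpy as jnp

def multi_axis_rope_freqs(ids, dim, theta=10000.0):
  # ids: [B, S, Num_Axes] positions, e.g. 3 axes (t, h, w) for video, 1 for audio.
  num_axes = ids.shape[-1]
  dim_per_axis = dim // num_axes
  half = dim_per_axis // 2
  inv_freq = 1.0 / (theta ** (jnp.arange(half) / half))       # [half]
  angles = ids[..., None].astype(jnp.float32) * inv_freq      # [B, S, Num_Axes, half]
  angles = angles.reshape(*ids.shape[:-1], num_axes * half)   # [B, S, dim // 2]
  # Repeat pairwise so each (cos, sin) value covers one neighbor pair of the input.
  cos = jnp.repeat(jnp.cos(angles), 2, axis=-1)               # [B, S, dim]
  sin = jnp.repeat(jnp.sin(angles), 2, axis=-1)
  return cos, sin

Callers then pass the result straight to apply_rotary_emb on the flattened [B, S, InnerDim] projections; the old [:, :, None, :] head axis is no longer needed.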
@@ -151,18 +150,6 @@ def __init__(
         dtype=dtype,
     )
 
-  def _reshape_rope(self, rope_emb: Tuple[Array, Array]) -> Tuple[Array, Array]:
-    """Reshapes [B, S, 1, InnerDim] -> [B, S, Heads, DimHead] for broadcasting."""
-    cos, sin = rope_emb
-    # If tests pass already shaped tensors, return as is
-    if cos.ndim == 4 and cos.shape[-2] == self.heads and cos.shape[-1] == self.dim_head:
-      return cos, sin
-
-    # Reshape: [B, S, 1, H*D] -> [B, S, H, D]
-    # We assume the last dimension is InnerDim = Heads * DimHead
-    new_shape = cos.shape[:-2] + (self.heads, self.dim_head)
-    return cos.reshape(new_shape), sin.reshape(new_shape)
-
   def __call__(
       self,
       hidden_states: Array,
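Dropping _reshape_rope is safe because the neighbor-pair rotation only touches adjacent elements of the last dimension, so applying it on the flattened [B, S, H*D] layout gives the same result as the old per-head [B, S, H, D] application whenever dim_head is even. A small self-contained check of that equivalence (toy shapes and the pairing helper are illustrative, mirroring the sketch after the first hunk):

import jax.numpy as jnp

B, S, H, D = 2, 4, 3, 8                                    # toy sizes; inner_dim = H * D
x = jnp.arange(B * S * H * D, dtype=jnp.float32).reshape(B, S, H * D)
angles = jnp.linspace(0.0, 1.0, H * D // 2).reshape(1, 1, -1)
cos = jnp.repeat(jnp.cos(angles), 2, axis=-1)              # [1, 1, H*D], broadcasts over B, S
sin = jnp.repeat(jnp.sin(angles), 2, axis=-1)

def rotate_pairs(t):
  p = t.reshape(*t.shape[:-1], -1, 2)
  return jnp.stack([-p[..., 1], p[..., 0]], axis=-1).reshape(t.shape)

def rope(t, c, s):
  return t * c + rotate_pairs(t) * s

flat = rope(x, cos, sin)                                   # new path: flattened [B, S, H*D]
per_head = rope(x.reshape(B, S, H, D),                     # old path: per-head [B, S, H, D]
                cos.reshape(1, 1, H, D),
                sin.reshape(1, 1, H, D)).reshape(B, S, H * D)
assert jnp.allclose(flat, per_head)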
@@ -184,33 +171,17 @@ def __call__(
     query = self.norm_q(query)
     key = self.norm_k(key)
 
-    # 3. Reshape to Heads [B, S, H, D]
-    query = query.reshape(*query.shape[:-1], self.heads, self.dim_head)
-    key = key.reshape(*key.shape[:-1], self.heads, self.dim_head)
-    value = value.reshape(*value.shape[:-1], self.heads, self.dim_head)
-
-    # 4. Apply RoPE
+    # 3. Apply RoPE to tensors of shape [B, S, InnerDim]
+    # Frequencies are shape [B, S, InnerDim]
     if rotary_emb is not None:
-      # Reshape [1, Inner] -> [H, D]
-      q_rope = self._reshape_rope(rotary_emb)
-      query = apply_rotary_emb(query, q_rope)
-
-    # Key RoPE Logic
+      query = apply_rotary_emb(query, rotary_emb)
     if k_rotary_emb is not None:
-      # Explicit Key RoPE (e.g. Cross-Modal)
-      k_rope = self._reshape_rope(k_rotary_emb)
-      key = apply_rotary_emb(key, k_rope)
-    elif encoder_hidden_states is None:
-      # Self-Attention: Re-use q_rope
-      key = apply_rotary_emb(key, q_rope)
-
-    # 5. Flatten back for AttentionOp [B, S, H*D]
-    # NNXAttentionOp expects flattened input for flash kernel
-    query = query.reshape(*query.shape[:-2], self.inner_dim)
-    key = key.reshape(*key.shape[:-2], self.inner_dim)
-    value = value.reshape(*value.shape[:-2], self.inner_dim)
-
-    # 6. Attention
+      key = apply_rotary_emb(key, k_rotary_emb)
+    elif encoder_hidden_states is None:
+      key = apply_rotary_emb(key, rotary_emb)
+
+    # 4. Attention
+    # NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel
     attn_output = self.attention_op.apply_attention(
         query=query, key=key, value=value, attention_mask=attention_mask
     )
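Putting the __call__ changes together: the head reshape round-trip is gone, RoPE is applied directly to the flattened [B, S, InnerDim] query and key, and the key reuses the query frequencies in the self-attention case. A hedged end-to-end sketch of that control flow; the scaled dot-product at the end is only a single-head stand-in for NNXAttentionOp.apply_attention, which does the head splitting and flash kernel itself:

import jax
import jax.numpy as jnp

def rope_then_attention(query, key, value, rotary_emb, k_rotary_emb=None, is_cross_attention=False):
  # query/key/value: [B, S, InnerDim]; rotary_emb / k_rotary_emb: (cos, sin) of shape [B, S, InnerDim].
  def rotate_pairs(t):
    p = t.reshape(*t.shape[:-1], -1, 2)
    return jnp.stack([-p[..., 1], p[..., 0]], axis=-1).reshape(t.shape)

  def rope(t, freqs):
    cos, sin = freqs
    return t * cos + rotate_pairs(t) * sin

  if rotary_emb is not None:
    query = rope(query, rotary_emb)
  if k_rotary_emb is not None:
    key = rope(key, k_rotary_emb)       # explicit key frequencies (e.g. cross-modal)
  elif not is_cross_attention:
    key = rope(key, rotary_emb)         # self-attention: reuse the query frequencies

  # Single-head stand-in for the attention op, acting on the flattened dim.
  scores = jnp.einsum("bqd,bkd->bqk", query, key) / jnp.sqrt(query.shape[-1])
  return jnp.einsum("bqk,bkd->bqd", jax.nn.softmax(scores, axis=-1), value)

As in the diff, the self-attention branch assumes rotary_emb is provided whenever it can be reached.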

0 commit comments
