Skip to content

Commit ba15412

Browse files
committed
adding support for rope
1 parent e8734f4 commit ba15412

3 files changed

Lines changed: 156 additions & 56 deletions

File tree

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 122 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,85 @@ def apply_rotary_emb(x: Array, freqs: Tuple[Array, Array]) -> Array:
5454
return out.astype(x.dtype)
5555

5656

57+
def apply_split_rotary_emb(x: Array, freqs: Tuple[Array, Array]) -> Array:
    """Applies split-style rotary positional embeddings (RoPE) to `x`.

    Mirrors Diffusers' `apply_split_rotary_emb`: the last dimension of `x` is
    viewed as two stacked halves [x1, x2] of size R each and rotated as
        out1 = x1 * cos - x2 * sin
        out2 = x2 * cos + x1 * sin

    Args:
      x: Input tensor. Either the per-head layout [B, H, S, 2*R] or the
        flattened layout [B, S, H*2*R]; the 3-D form is lifted to the 4-D
        head layout when `cos` is 4-D, and restored before returning.
      freqs: Tuple of (cos, sin). Expected to be [B, H, S, R] when produced
        by LTX2RotaryPosEmbed with rope_type="split" — TODO confirm against
        the caller for other layouts (any shape broadcastable to [..., R]
        also works).

    Returns:
      Tensor with the same shape and dtype as `x`.
    """
    cos, sin = freqs

    x_dtype = x.dtype
    original_shape = x.shape
    needed_reshape = False

    # If x is flattened per-token ([B, S, H*D]) but the freqs carry an
    # explicit head axis ([B, H, S, R]), lift x to the matching head layout.
    if x.ndim != 4 and cos.ndim == 4:
        b = x.shape[0]
        h, s, r = cos.shape[1], cos.shape[2], cos.shape[3]
        # [B, S, H*D] -> [B, S, H, D] -> [B, H, S, D]   (D == 2*R)
        x = x.reshape(b, s, h, -1).transpose(0, 2, 1, 3)
        needed_reshape = True

    # View the feature axis as two stacked halves: [..., 2*R] -> [..., 2, R].
    last_dim = x.shape[-1]
    half = last_dim // 2
    split_x = x.reshape(*x.shape[:-1], 2, half)
    first_x = split_x[..., 0, :]   # x1: [..., R]
    second_x = split_x[..., 1, :]  # x2: [..., R]

    # Standard 2-D rotation applied element-wise. cos/sin broadcast against
    # the [..., R] halves directly, so no expand_dims/squeeze round-trip is
    # needed (the previous version multiplied the stacked tensor by an
    # expanded cos only to slice both halves back out again).
    out_first = first_x * cos - second_x * sin
    out_second = second_x * cos + first_x * sin

    # Re-assemble: [..., 2, R] -> [..., 2*R].
    out = jnp.stack([out_first, out_second], axis=-2)
    out = out.reshape(*out.shape[:-2], last_dim)

    if needed_reshape:
        # Undo the head lift: [B, H, S, D] -> [B, S, H, D] -> [B, S, H*D].
        out = out.transpose(0, 2, 1, 3).reshape(original_shape)

    return out.astype(x_dtype)
133+
134+
135+
57136
class LTX2RotaryPosEmbed(nnx.Module):
58137
"""
59138
Video and audio rotary positional embeddings (RoPE) for the LTX-2.0 model.
@@ -131,26 +210,13 @@ def prepare_video_coords(
131210
latent_coords = jnp.expand_dims(latent_coords, 0) # [1, num_patches, 3, 2]
132211
latent_coords = jnp.tile(latent_coords, (batch_size, 1, 1, 1)) # [B, num_patches, 3, 2]
133212

134-
# Transpose to match desired shape [B, 3, num_patches, 2] if needed,
135-
# BUT Diffusers returns [B, 3, num_patches, 2] from flatten(1,3) on [3, N_F, N_H, N_W, 2]??
136-
# Diffusers:
137-
# latent_coords = torch.stack([grid, patch_ends], dim=-1) # [3, N_F, N_H, N_W, 2]
138-
# latent_coords = latent_coords.flatten(1, 3) # [3, num_patches, 2]
139-
# latent_coords = latent_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1) # [B, 3, num_patches, 2]
140-
# My JAX above:
141-
# latent_coords = latent_coords.reshape(-1, 3, 2) was wrong relative to Diffusers shape
142-
143-
# Correct JAX implementation matching Diffusers:
144213
latent_coords = jnp.stack([grid, patch_ends], axis=-1) # [3, N_F, N_H, N_W, 2]
145214
latent_coords = latent_coords.reshape(3, -1, 2) # [3, num_patches, 2]
146215
latent_coords = jnp.expand_dims(latent_coords, 0) # [1, 3, num_patches, 2]
147216
latent_coords = jnp.tile(latent_coords, (batch_size, 1, 1, 1)) # [B, 3, num_patches, 2]
148217

149218
# 3. Calculate pixel space coords
150219
scale_tensor = jnp.array(self.scale_factors, dtype=latent_coords.dtype)
151-
# Broadcast scale factors: [1, 3, 1, 1] matches [B, 3, num, 2] logic?
152-
# Diffusers: broadcast_shape[1] = -1 (frame, height, width dim)
153-
# Actually scale_factors is (8, 32, 32) corresponding to (F, H, W) i.e. dim 1 of latent_coords
154220
scale_tensor = scale_tensor.reshape(1, 3, 1, 1)
155221
pixel_coords = latent_coords * scale_tensor
156222

@@ -260,11 +326,6 @@ def __call__(self, coords: Array) -> Tuple[Array, Array]:
260326

261327
# Padding if needed
262328
if self.dim % num_rope_elems != 0:
263-
# Diffusers logic: pad with ones/zeros
264-
# JAX requires careful padding
265-
# But here we computed freqs for `steps = self.dim // num_rope_elems`
266-
# So we have `steps * num_rope_elems` elements currently.
267-
# If mismatch, we pad.
268329
curr_dim = cos_freqs.shape[-1]
269330
pad_amt = self.dim - curr_dim
270331
if pad_amt > 0:
@@ -274,42 +335,32 @@ def __call__(self, coords: Array) -> Tuple[Array, Array]:
274335
sin_freqs = jnp.concatenate([sin_padding, sin_freqs], axis=-1)
275336

276337
elif self.rope_type == "split":
277-
# [B, N, D//2] -> [B, N, D//2]
278-
# Padding first? Diffusers:
279-
# expected_freqs = self.dim // 2
280-
# current_freqs = freqs.shape[-1]
281-
# pad_size = expected_freqs - current_freqs
282-
# if pad != 0: pad (before logic?)
283-
# Actually Diffusers code:
284-
# cos_freq = freqs.cos(), sin_freq = freqs.sin()
285-
# if pad: concatenate([pad, cos_freq], axis=-1)
286-
# THEN reshape to multi-head?
287-
288-
curr_dim = cos_freqs.shape[-1]
338+
# Cos/Sin
339+
cos_freq = jnp.cos(freqs)
340+
sin_freq = jnp.sin(freqs)
341+
342+
curr_dim = cos_freq.shape[-1]
289343
expected_dim = self.dim // 2
290344
pad_size = expected_dim - curr_dim
291345

292346
if pad_size > 0:
293-
cos_padding = jnp.ones((*cos_freqs.shape[:-1], pad_size), dtype=cos_freqs.dtype)
294-
sin_padding = jnp.zeros((*sin_freqs.shape[:-1], pad_size), dtype=sin_freqs.dtype)
295-
cos_freqs = jnp.concatenate([cos_padding, cos_freqs], axis=-1)
296-
sin_freqs = jnp.concatenate([sin_padding, sin_freqs], axis=-1)
297-
298-
# Reshape for multi-head?
299-
# Diffusers:
300-
# cos_freq = cos_freq.reshape(b, t, self.num_attention_heads, -1)
301-
# swapaxes(1, 2) -> (B, H, T, D//2)
302-
# Here: input `coords` was flattened tokens (N).
303-
# We assume N = Time?
304-
# Wait, `prepare_video_coords` flattens all patches (T*H*W).
305-
# So N = T*H*W.
306-
# If `rope_type="split"`, does it imply specific Time-Head structure?
307-
# LTX-2 `transformer_ltx2.py` in Diffusers passes `rope_type="interleaved"` by default.
308-
# Split is mostly for specific attention optimizations.
309-
# I will skip the complex reshape logic for now unless requested,
310-
# as standard flow is interleaved.
311-
# But I should keep the frequency generation logic consistent.
312-
pass
347+
cos_padding = jnp.ones((*cos_freq.shape[:-1], pad_size), dtype=cos_freq.dtype)
348+
sin_padding = jnp.zeros((*sin_freq.shape[:-1], pad_size), dtype=sin_freq.dtype)
349+
cos_freq = jnp.concatenate([cos_padding, cos_freq], axis=-1)
350+
sin_freq = jnp.concatenate([sin_padding, sin_freq], axis=-1)
351+
352+
# Reshape freqs to be compatible with multi-head attention
353+
# Diffusers: cos_freq.reshape(b, t, self.num_attention_heads, -1) -> swapaxes(1, 2)
354+
# [B, S, D//2] -> [B, S, H, dim_head//2] -> [B, H, S, dim_head//2]
355+
356+
b = cos_freq.shape[0]
357+
s = cos_freq.shape[1]
358+
359+
# We need to know H. `LTX2RotaryPosEmbed` has `num_attention_heads`.
360+
h = self.num_attention_heads
361+
362+
cos_freqs = cos_freq.reshape(b, s, h, -1).transpose(0, 2, 1, 3)
363+
sin_freqs = sin_freq.reshape(b, s, h, -1).transpose(0, 2, 1, 3)
313364

314365
return cos_freqs, sin_freqs
315366

@@ -330,8 +381,10 @@ def __init__(
330381
eps: float = 1e-6,
331382
dtype: DType = jnp.float32,
332383
attention_kernel: str = "flash",
384+
rope_type: str = "interleaved",
333385
):
334386
self.heads = heads
387+
self.rope_type = rope_type
335388
self.dim_head = dim_head
336389
self.inner_dim = dim_head * heads
337390
self.dropout_rate = dropout
@@ -387,12 +440,26 @@ def __call__(
387440

388441
# 3. Apply RoPE to tensors of shape [B, S, InnerDim]
389442
# Frequencies are shape [B, S, InnerDim]
443+
# 3. Apply RoPE
390444
if rotary_emb is not None:
391-
query = apply_rotary_emb(query, rotary_emb)
392-
if k_rotary_emb is not None:
393-
key = apply_rotary_emb(key, k_rotary_emb)
394-
elif encoder_hidden_states is None:
395-
key = apply_rotary_emb(key, rotary_emb)
445+
if hasattr(self, "rope_type") and self.rope_type == "split":
446+
# Split RoPE: passing full freqs [B, H, S, D//2]
447+
# apply_split_rotary_emb handles reshaping query/key
448+
449+
query = apply_split_rotary_emb(query, rotary_emb)
450+
451+
if k_rotary_emb is not None:
452+
key = apply_split_rotary_emb(key, k_rotary_emb)
453+
elif encoder_hidden_states is None:
454+
key = apply_split_rotary_emb(key, rotary_emb)
455+
456+
else:
457+
# Interleaved (Default)
458+
query = apply_rotary_emb(query, rotary_emb)
459+
if k_rotary_emb is not None:
460+
key = apply_rotary_emb(key, k_rotary_emb)
461+
elif encoder_hidden_states is None:
462+
key = apply_rotary_emb(key, rotary_emb)
396463

397464
# 4. Attention
398465
# NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -700,7 +700,6 @@ def __init__(
700700
num_attention_heads=self.audio_num_attention_heads,
701701
)
702702

703-
# 5. Transformer Blocks
704703
# 5. Transformer Blocks
705704
@nnx.split_rngs(splits=self.num_layers)
706705
@nnx.vmap(in_axes=0, out_axes=0, axis_size=self.num_layers, transform_metadata={nnx.PARTITION_NAME: "layers"})

src/maxdiffusion/tests/ltx_2_transformer_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,40 @@ def test_ltx2_rope(self):
104104
self.assertEqual(cos.shape, (1, 10, dim))
105105
self.assertEqual(sin.shape, (1, 10, dim))
106106

107+
def test_ltx2_rope_split(self):
  """Tests LTX2RotaryPosEmbed with rope_type='split'."""
  # Build a video RoPE module configured for the split layout.
  rope = LTX2RotaryPosEmbed(
      dim=self.dim,
      patch_size=self.patch_size,
      patch_size_t=self.patch_size_t,
      base_num_frames=8,
      base_height=32,
      base_width=32,
      modality="video",
      rope_type="split",
  )

  coords = jnp.ones((1, 3, 10))  # (B, Axes, S)
  cos, sin = rope(coords)

  # Split RoPE concatenates [cos, cos] along the feature axis, so the
  # output still spans the full embedding dim.
  self.assertEqual(cos.shape, (1, 10, self.dim))
  self.assertEqual(sin.shape, (1, 10, self.dim))

  # The two halves of cos must therefore be identical copies.
  first_half, second_half = jnp.split(cos, 2, axis=-1)
  self.assertTrue(jnp.allclose(first_half, second_half))
139+
140+
107141
def test_ltx2_ada_layer_norm_single(self):
108142
"""Tests LTX2AdaLayerNormSingle initialization and execution."""
109143
key = jax.random.key(0)

0 commit comments

Comments
 (0)