Skip to content

Commit fd90736

Browse files
committed
Reformatted with pyink
1 parent 1059c10 commit fd90736

4 files changed

Lines changed: 59 additions & 48 deletions

File tree

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 42 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -59,49 +59,48 @@ def apply_split_rotary_emb(x: Array, freqs: Tuple[Array, Array]) -> Array:
5959
"""
6060
Applies Split RoPE to input x.
6161
Logic matches Diffusers apply_split_rotary_emb.
62-
62+
6363
Args:
64-
x: Input tensor.
64+
x: Input tensor.
6565
If ndim=3 [B, S, D], it will be reshaped to satisfy cos/sin shapes if needed.
66-
freqs: Tuple of (cos, sin).
66+
freqs: Tuple of (cos, sin).
6767
Expected to be [B, H, S, D//2] if coming from LTX2RotaryPosEmbed(split).
6868
"""
6969
cos, sin = freqs
70-
70+
7171
x_dtype = x.dtype
7272
needed_reshape = False
7373
original_shape = x.shape
74-
74+
7575
if x.ndim != 4 and cos.ndim == 4:
76-
b = x.shape[0]
77-
h, s, r = cos.shape[1], cos.shape[2], cos.shape[3]
78-
x = x.reshape(b, s, h, -1).transpose(0, 2, 1, 3)
79-
needed_reshape = True
80-
76+
b = x.shape[0]
77+
h, s, r = cos.shape[1], cos.shape[2], cos.shape[3]
78+
x = x.reshape(b, s, h, -1).transpose(0, 2, 1, 3)
79+
needed_reshape = True
80+
8181
last_dim = x.shape[-1]
8282
r = last_dim // 2
83-
83+
8484
split_x = x.reshape(*x.shape[:-1], 2, r)
85-
85+
8686
first_x = split_x[..., 0, :]
8787
second_x = split_x[..., 1, :]
88-
88+
8989
cos_u = jnp.expand_dims(cos, axis=-2)
9090
sin_u = jnp.expand_dims(sin, axis=-2)
91-
91+
9292
out = split_x * cos_u
93-
93+
9494
out_first = out[..., 0, :] - second_x * sin_u.squeeze(-2)
9595
out_second = out[..., 1, :] + first_x * sin_u.squeeze(-2)
96-
96+
9797
out = jnp.stack([out_first, out_second], axis=-2)
9898
out = out.reshape(*out.shape[:-2], last_dim)
99-
99+
100100
if needed_reshape:
101-
out = out.transpose(0, 2, 1, 3).reshape(original_shape)
102-
103-
return out.astype(x_dtype)
101+
out = out.transpose(0, 2, 1, 3).reshape(original_shape)
104102

103+
return out.astype(x_dtype)
105104

106105

107106
class LTX2RotaryPosEmbed(nnx.Module):
@@ -308,11 +307,10 @@ def __call__(self, coords: Array) -> Tuple[Array, Array]:
308307
cos_freq = jnp.concatenate([cos_padding, cos_freq], axis=-1)
309308
sin_freq = jnp.concatenate([sin_padding, sin_freq], axis=-1)
310309

311-
312310
b = cos_freq.shape[0]
313311
s = cos_freq.shape[1]
314312
h = self.num_attention_heads
315-
313+
316314
cos_freqs = cos_freq.reshape(b, s, h, -1).transpose(0, 2, 1, 3)
317315
sin_freqs = sin_freq.reshape(b, s, h, -1).transpose(0, 2, 1, 3)
318316

@@ -352,8 +350,12 @@ def __init__(
352350
self.to_v = nnx.Linear(kv_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype)
353351

354352
# 2. Normalization (Applied to full inner_dim, NOT per-head)
355-
self.norm_q = nnx.RMSNorm(self.inner_dim, epsilon=eps, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, rngs=rngs)
356-
self.norm_k = nnx.RMSNorm(self.inner_dim, epsilon=eps, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, rngs=rngs)
353+
self.norm_q = nnx.RMSNorm(
354+
self.inner_dim, epsilon=eps, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, rngs=rngs
355+
)
356+
self.norm_k = nnx.RMSNorm(
357+
self.inner_dim, epsilon=eps, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, rngs=rngs
358+
)
357359

358360
# 3. Output
359361
self.to_out = nnx.Linear(self.inner_dim, query_dim, use_bias=out_bias, rngs=rngs, dtype=dtype)
@@ -397,23 +399,23 @@ def __call__(
397399
# 3. Apply RoPE
398400
if rotary_emb is not None:
399401
if hasattr(self, "rope_type") and self.rope_type == "split":
400-
# Split RoPE: passing full freqs [B, H, S, D//2]
401-
# apply_split_rotary_emb handles reshaping query/key
402-
403-
query = apply_split_rotary_emb(query, rotary_emb)
404-
405-
if k_rotary_emb is not None:
406-
key = apply_split_rotary_emb(key, k_rotary_emb)
407-
elif encoder_hidden_states is None:
408-
key = apply_split_rotary_emb(key, rotary_emb)
409-
402+
# Split RoPE: passing full freqs [B, H, S, D//2]
403+
# apply_split_rotary_emb handles reshaping query/key
404+
405+
query = apply_split_rotary_emb(query, rotary_emb)
406+
407+
if k_rotary_emb is not None:
408+
key = apply_split_rotary_emb(key, k_rotary_emb)
409+
elif encoder_hidden_states is None:
410+
key = apply_split_rotary_emb(key, rotary_emb)
411+
410412
else:
411-
# Interleaved (Default)
412-
query = apply_rotary_emb(query, rotary_emb)
413-
if k_rotary_emb is not None:
414-
key = apply_rotary_emb(key, k_rotary_emb)
415-
elif encoder_hidden_states is None:
416-
key = apply_rotary_emb(key, rotary_emb)
413+
# Interleaved (Default)
414+
query = apply_rotary_emb(query, rotary_emb)
415+
if k_rotary_emb is not None:
416+
key = apply_rotary_emb(key, k_rotary_emb)
417+
elif encoder_hidden_states is None:
418+
key = apply_rotary_emb(key, rotary_emb)
417419

418420
# 4. Attention
419421
# NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -208,7 +208,12 @@ def __init__(
208208

209209
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
210210
self.audio_to_video_norm = nnx.RMSNorm(
211-
dim, epsilon=self.norm_eps, use_scale=self.norm_elementwise_affine, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
211+
dim,
212+
epsilon=self.norm_eps,
213+
use_scale=self.norm_elementwise_affine,
214+
rngs=rngs,
215+
dtype=jnp.float32,
216+
param_dtype=jnp.float32,
212217
)
213218
self.audio_to_video_attn = LTX2Attention(
214219
rngs=rngs,
@@ -252,7 +257,12 @@ def __init__(
252257

253258
# 4. Feed Forward
254259
self.norm3 = nnx.RMSNorm(
255-
dim, epsilon=self.norm_eps, use_scale=self.norm_elementwise_affine, rngs=rngs, dtype=jnp.float32, param_dtype=jnp.float32
260+
dim,
261+
epsilon=self.norm_eps,
262+
use_scale=self.norm_elementwise_affine,
263+
rngs=rngs,
264+
dtype=jnp.float32,
265+
param_dtype=jnp.float32,
256266
)
257267
self.ff = NNXSimpleFeedForward(
258268
rngs=rngs,

src/maxdiffusion/tests/ltx2_parity_test.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -419,7 +419,7 @@ def test_import_parity_comparison(self):
419419
num_layers=1,
420420
mesh=self.mesh,
421421
attention_kernel="dot_product",
422-
rope_type="interleaved"
422+
rope_type="interleaved",
423423
)
424424

425425
# 2. Convert Weights (PyTorch -> Flax NNX)

src/maxdiffusion/tests/ltx2_transformer_test.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -112,7 +112,7 @@ def test_ltx2_rope_split(self):
112112
base_num_frames = 8
113113
base_height = 32
114114
base_width = 32
115-
115+
116116
# Video RoPE Split
117117
rope = LTX2RotaryPosEmbed(
118118
dim=dim,
@@ -122,18 +122,17 @@ def test_ltx2_rope_split(self):
122122
base_height=base_height,
123123
base_width=base_width,
124124
modality="video",
125-
rope_type="split"
125+
rope_type="split",
126126
)
127-
ids = jnp.ones((1, 3, 10)) # (B, Axes, S)
127+
ids = jnp.ones((1, 3, 10)) # (B, Axes, S)
128128
cos, sin = rope(ids)
129-
129+
130130
# Check output shape
131131
# Split RoPE returns [B, H, S, D//2]
132132
# dim=1024, heads=32 => head_dim=32 => D//2 = 16
133133
self.assertEqual(cos.shape, (1, 32, 10, 16))
134134
self.assertEqual(sin.shape, (1, 32, 10, 16))
135135

136-
137136
def test_ltx2_ada_layer_norm_single(self):
138137
"""Tests LTX2AdaLayerNormSingle initialization and execution."""
139138
key = jax.random.key(0)

0 commit comments

Comments
 (0)