Skip to content

Commit 3f2320c

Browse files
committed
changes to attention, video vae and weight loading
1 parent d4950fb commit 3f2320c

3 files changed

Lines changed: 57 additions & 2 deletions

File tree

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ def __init__(
353353
qkv_sharding_spec: Optional[tuple] = None,
354354
out_sharding_spec: Optional[tuple] = None,
355355
out_bias_sharding_spec: Optional[tuple] = None,
356+
gated_attn: bool = False,
356357
):
357358
self.heads = heads
358359
self.rope_type = rope_type
@@ -444,6 +445,17 @@ def __init__(
444445
else:
445446
self.dropout_layer = None
446447

448+
if gated_attn:
449+
self.to_gate_logits = nnx.Linear(
450+
query_dim,
451+
heads,
452+
use_bias=True,
453+
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads")),
454+
bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",)),
455+
rngs=rngs,
456+
dtype=dtype,
457+
)
458+
447459
self.attention_op = NNXAttentionOp(
448460
mesh=mesh,
449461
attention_kernel=attention_kernel,
@@ -464,6 +476,7 @@ def __call__(
464476
attention_mask: Optional[Array] = None,
465477
rotary_emb: Optional[Tuple[Array, Array]] = None,
466478
k_rotary_emb: Optional[Tuple[Array, Array]] = None,
479+
perturbation_mask: Optional[Array] = None,
467480
) -> Array:
468481
# Determine context (Self or Cross)
469482
context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
@@ -507,6 +520,17 @@ def __call__(
507520
# NNXAttentionOp expects flattened input [B, S, InnerDim] for flash kernel
508521
attn_output = self.attention_op.apply_attention(query=query, key=key, value=value, attention_mask=attention_mask)
509522

523+
if perturbation_mask is not None:
524+
attn_output = value + perturbation_mask * (attn_output - value)
525+
526+
if getattr(self, "to_gate_logits", None) is not None:
527+
gate_logits = self.to_gate_logits(hidden_states)
528+
b, s, _ = attn_output.shape
529+
attn_output = attn_output.reshape(b, s, self.heads, self.dim_head)
530+
gates = 2.0 * jax.nn.sigmoid(gate_logits)
531+
attn_output = attn_output * jnp.expand_dims(gates, axis=-1)
532+
attn_output = attn_output.reshape(b, s, -1)
533+
510534
# 7. Output Projection
511535
hidden_states = self.to_out(attn_output)
512536

src/maxdiffusion/models/ltx2/autoencoder_kl_ltx2.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,7 @@ def __init__(
668668
timestep_conditioning: bool = False,
669669
upsample_residual: bool = False,
670670
upscale_factor: int = 1,
671+
upsample_type: str = "spatiotemporal",
671672
spatial_padding_mode: str = "constant",
672673
rngs: Optional[nnx.Rngs] = None,
673674
mesh: Optional[jax.sharding.Mesh] = None,
@@ -711,9 +712,18 @@ def __init__(
711712
)
712713

713714
if spatio_temporal_scale:
715+
if upsample_type == "spatiotemporal":
716+
stride = (2, 2, 2)
717+
elif upsample_type == "temporal":
718+
stride = (2, 1, 1)
719+
elif upsample_type == "spatial":
720+
stride = (1, 2, 2)
721+
else:
722+
raise ValueError(f"Unknown upsample_type: {upsample_type}")
723+
714724
self.upsampler = LTXVideoUpsampler3d(
715725
in_channels=out_channels * upscale_factor,
716-
stride=(2, 2, 2),
726+
stride=stride,
717727
residual=upsample_residual,
718728
upscale_factor=upscale_factor,
719729
spatial_padding_mode=spatial_padding_mode,
@@ -954,6 +964,7 @@ def __init__(
954964
timestep_conditioning: bool = False,
955965
upsample_residual: Tuple[bool, ...] = (True, True, True),
956966
upsample_factor: Tuple[int, ...] = (2, 2, 2),
967+
upsample_type: Tuple[str, ...] = ("spatiotemporal", "spatiotemporal", "spatiotemporal"),
957968
spatial_padding_mode: str = "reflect",
958969
rngs: Optional[nnx.Rngs] = None,
959970
mesh: Optional[jax.sharding.Mesh] = None,
@@ -1020,6 +1031,7 @@ def __init__(
10201031
timestep_conditioning=timestep_conditioning,
10211032
upsample_residual=upsample_residual[i],
10221033
upscale_factor=upsample_factor[i],
1034+
upsample_type=upsample_type[i],
10231035
spatial_padding_mode=spatial_padding_mode,
10241036
rngs=rngs,
10251037
mesh=mesh,
@@ -1139,6 +1151,7 @@ def __init__(
11391151
downsample_type: Tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
11401152
upsample_residual: Tuple[bool, ...] = (True, True, True),
11411153
upsample_factor: Tuple[int, ...] = (2, 2, 2),
1154+
upsample_type: Tuple[str, ...] = ("spatiotemporal", "spatiotemporal", "spatiotemporal"),
11421155
timestep_conditioning: bool = False,
11431156
patch_size: int = 4,
11441157
patch_size_t: int = 1,
@@ -1184,6 +1197,7 @@ def __init__(
11841197
spatio_temporal_scaling=decoder_spatio_temporal_scaling,
11851198
upsample_factor=upsample_factor,
11861199
upsample_residual=upsample_residual,
1200+
upsample_type=upsample_type,
11871201
patch_size=patch_size,
11881202
patch_size_t=patch_size_t,
11891203
resnet_norm_eps=resnet_norm_eps,

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ def rename_for_ltx2_transformer(key):
4040
"""
4141
Renames Diffusers LTX-2 keys to MaxDiffusion Flax LTX-2 keys.
4242
"""
43+
if "caption_proj" in key and "caption_projection" not in key:
44+
key = key.replace("caption_proj", "caption_projection")
45+
if "audio_caption_proj" in key and "audio_caption_projection" not in key:
46+
key = key.replace("audio_caption_proj", "audio_caption_projection")
47+
4348
key = key.replace("patchify_proj", "proj_in")
4449
key = key.replace("audio_patchify_proj", "audio_proj_in")
4550
key = key.replace("norm_final", "norm_out")
@@ -289,11 +294,21 @@ def load_vocoder_weights(
289294

290295
flax_key = _tuple_str_to_int(parts)
291296

297+
# Skip filter keys, as they are derived in the NNX model
298+
if "filter" in flax_key:
299+
continue
300+
292301
if flax_key[-1] == "kernel":
293302
if "upsamplers" in flax_key:
294-
tensor = tensor.transpose(2, 0, 1)[::-1, :, :]
303+
if "2.3" in pretrained_model_name_or_path:
304+
tensor = tensor.transpose(2, 0, 1)
305+
else:
306+
tensor = tensor.transpose(2, 0, 1)[::-1, :, :]
295307
else:
296308
tensor = tensor.transpose(2, 1, 0)
309+
310+
if "mel_stft" in flax_key and ("forward_basis" in flax_key or "inverse_basis" in flax_key):
311+
tensor = tensor.transpose(2, 1, 0)
297312

298313
flax_state_dict[flax_key] = jax.device_put(tensor, device=cpu)
299314

@@ -305,6 +320,8 @@ def rename_for_ltx2_connector(key):
305320
key = key.replace("video_connector", "video_embeddings_connector")
306321
key = key.replace("audio_connector", "audio_embeddings_connector")
307322
key = key.replace("text_proj_in", "feature_extractor.linear")
323+
key = key.replace("audio_feature_extractor.linear", "audio_text_proj_in")
324+
key = key.replace("video_feature_extractor.linear", "video_text_proj_in")
308325

309326
if "transformer_blocks" in key:
310327
key = key.replace("transformer_blocks", "stacked_blocks")

0 commit comments

Comments
 (0)