Skip to content

Commit e0ef7ed

Browse files
committed
use_cross_timestep change in transformer
1 parent 632d2c7 commit e0ef7ed

3 files changed

Lines changed: 41 additions & 11 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_3_utils.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -369,8 +369,13 @@ def load_vocoder_weights_2_3(
369369
flax_state_dict = {}
370370
cpu = jax.local_devices(backend="cpu")[0]
371371

372+
from flax.traverse_util import flatten_dict
373+
flat_eval = flatten_dict(eval_shapes)
374+
print("Expected vocoder keys:", [k for k in flat_eval.keys() if "mel_stft" in str(k)])
375+
372376
for pt_key, tensor in tensors.items():
373377
# Keys are already filtered and stripped of "vocoder." by load_and_segregate
378+
print("Processing pt_key:", pt_key)
374379
key = rename_for_ltx2_3_vocoder(pt_key)
375380

376381
# Always apply LTX-2.3 specific replacement

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1027,6 +1027,7 @@ def __call__(
10271027
video_coords: Optional[jax.Array] = None,
10281028
audio_coords: Optional[jax.Array] = None,
10291029
attention_kwargs: Optional[Dict[str, Any]] = None,
1030+
use_cross_timestep: bool = False,
10301031
return_dict: bool = True,
10311032
perturbation_mask: Optional[jax.Array] = None,
10321033
) -> Any:
@@ -1096,12 +1097,20 @@ def __call__(
10961097
temb_prompt = None
10971098
temb_prompt_audio = None
10981099

1100+
if use_cross_timestep:
1101+
assert sigma is not None and audio_sigma is not None, "sigma and audio_sigma must be provided when use_cross_timestep is True"
1102+
video_ca_timestep = audio_sigma.flatten()
1103+
audio_ca_timestep = sigma.flatten()
1104+
else:
1105+
video_ca_timestep = timestep.flatten()
1106+
audio_ca_timestep = audio_timestep.flatten() if audio_timestep is not None else timestep.flatten()
1107+
10991108
video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift(
1100-
timestep.flatten(),
1109+
video_ca_timestep,
11011110
hidden_dtype=hidden_states.dtype,
11021111
)
11031112
video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate(
1104-
timestep.flatten() * timestep_cross_attn_gate_scale_factor,
1113+
video_ca_timestep * timestep_cross_attn_gate_scale_factor,
11051114
hidden_dtype=hidden_states.dtype,
11061115
)
11071116
video_cross_attn_scale_shift = video_cross_attn_scale_shift.reshape(
@@ -1110,11 +1119,11 @@ def __call__(
11101119
video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.reshape(batch_size, -1, video_cross_attn_a2v_gate.shape[-1])
11111120

11121121
audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift(
1113-
audio_timestep.flatten(),
1122+
audio_ca_timestep,
11141123
hidden_dtype=audio_hidden_states.dtype,
11151124
)
11161125
audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate(
1117-
audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor,
1126+
audio_ca_timestep * timestep_cross_attn_gate_scale_factor,
11181127
hidden_dtype=audio_hidden_states.dtype,
11191128
)
11201129
audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.reshape(

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 23 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -125,7 +125,7 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
125125
"audio_attention_head_dim": 64,
126126
"audio_cross_attention_dim": 2048,
127127
"num_layers": 48,
128-
"caption_channels": 4096,
128+
"caption_channels": 3840,
129129
"audio_caption_channels": 2048,
130130
"use_prompt_embeddings": False,
131131
}
@@ -365,7 +365,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
365365
{
366366
"video_connector_num_layers": 8,
367367
"audio_connector_num_layers": 8,
368-
"caption_channels": 2048,
368+
"caption_channels": 3840,
369369
"video_caption_channels": 4096,
370370
"audio_caption_channels": 2048,
371371
"video_connector_num_attention_heads": 32,
@@ -1264,6 +1264,11 @@ def __call__(
12641264
timesteps: List[int] = None,
12651265
guidance_scale: float = 3.0,
12661266
guidance_rescale: float = 0.0,
1267+
stg_scale: float = 0.0,
1268+
modality_scale: float = 1.0,
1269+
audio_guidance_scale: Optional[float] = None,
1270+
audio_stg_scale: Optional[float] = None,
1271+
audio_modality_scale: Optional[float] = None,
12671272
noise_scale: float = 1.0,
12681273
num_videos_per_prompt: Optional[int] = 1,
12691274
generator: Optional[jax.Array] = None,
@@ -1279,6 +1284,7 @@ def __call__(
12791284
dtype: Optional[jnp.dtype] = None,
12801285
output_type: str = "pil",
12811286
return_dict: bool = True,
1287+
use_cross_timestep: bool = False,
12821288
):
12831289
# 1. Check inputs
12841290
self.check_inputs(
@@ -1499,23 +1505,24 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
14991505
audio_num_frames,
15001506
frame_rate,
15011507
perturbation_mask=perturbation_mask,
1508+
use_cross_timestep=use_cross_timestep,
15021509
)
15031510

1504-
do_stg = getattr(self.config, "stg_scale", 0.0) > 0.0
1511+
do_stg = stg_scale > 0.0
15051512

15061513
if guidance_scale > 1.0 and do_stg:
15071514
noise_pred_uncond, noise_pred_text, noise_pred_perturb = jnp.split(noise_pred, 3, axis=0)
15081515
noise_pred = (
15091516
noise_pred_uncond
15101517
+ guidance_scale * (noise_pred_text - noise_pred_uncond)
1511-
+ self.config.stg_scale * (noise_pred_text - noise_pred_perturb)
1518+
+ stg_scale * (noise_pred_text - noise_pred_perturb)
15121519
)
15131520
# Audio guidance
15141521
noise_pred_audio_uncond, noise_pred_audio_text, noise_pred_audio_perturb = jnp.split(noise_pred_audio, 3, axis=0)
15151522
noise_pred_audio = (
15161523
noise_pred_audio_uncond
15171524
+ guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond)
1518-
+ self.config.stg_scale * (noise_pred_audio_text - noise_pred_audio_perturb)
1525+
+ stg_scale * (noise_pred_audio_text - noise_pred_audio_perturb)
15191526
)
15201527
elif guidance_scale > 1.0:
15211528
noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0)
@@ -1525,10 +1532,10 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
15251532
noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond)
15261533
elif do_stg:
15271534
noise_pred_text, noise_pred_perturb = jnp.split(noise_pred, 2, axis=0)
1528-
noise_pred = noise_pred_text + self.config.stg_scale * (noise_pred_text - noise_pred_perturb)
1535+
noise_pred = noise_pred_text + stg_scale * (noise_pred_text - noise_pred_perturb)
15291536

15301537
noise_pred_audio_text, noise_pred_audio_perturb = jnp.split(noise_pred_audio, 2, axis=0)
1531-
noise_pred_audio = noise_pred_audio_text + self.config.stg_scale * (noise_pred_audio_text - noise_pred_audio_perturb)
1538+
noise_pred_audio = noise_pred_audio_text + stg_scale * (noise_pred_audio_text - noise_pred_audio_perturb)
15321539

15331540
# Extract latents_step based on stacking strategy
15341541
if do_cfg and do_stg:
@@ -1693,6 +1700,8 @@ def transformer_forward_pass(
16931700
fps,
16941701
perturbation_mask=None,
16951702
sigma=None,
1703+
audio_sigma=None,
1704+
use_cross_timestep=False,
16961705
):
16971706
transformer = nnx.merge(graphdef, state)
16981707

@@ -1704,11 +1713,17 @@ def transformer_forward_pass(
17041713
else:
17051714
sigma = jnp.expand_dims(sigma, 0).repeat(latents.shape[0])
17061715

1716+
if audio_sigma is None:
1717+
audio_sigma = timestep
1718+
else:
1719+
audio_sigma = jnp.expand_dims(audio_sigma, 0).repeat(latents.shape[0])
1720+
17071721
noise_pred, noise_pred_audio = transformer(
17081722
hidden_states=latents,
17091723
encoder_hidden_states=encoder_hidden_states,
17101724
timestep=timestep,
17111725
sigma=sigma,
1726+
audio_sigma=audio_sigma,
17121727
encoder_attention_mask=encoder_attention_mask,
17131728
num_frames=latent_num_frames,
17141729
height=latent_height,
@@ -1720,6 +1735,7 @@ def transformer_forward_pass(
17201735
audio_num_frames=audio_num_frames,
17211736
return_dict=False,
17221737
perturbation_mask=perturbation_mask,
1738+
use_cross_timestep=use_cross_timestep,
17231739
)
17241740

17251741
return noise_pred, noise_pred_audio

0 commit comments

Comments
 (0)