Skip to content

Commit 27c5d4b

Browse files
committed
reformatted
1 parent daf9b6c commit 27c5d4b

6 files changed

Lines changed: 26 additions & 37 deletions

File tree

src/maxdiffusion/generate_ltx2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
132132
weights = lora_config.get("weight_name", [None] * len(paths))
133133
scales = lora_config.get("scale", [1.0] * len(paths))
134134
ranks = lora_config.get("rank", [64] * len(paths))
135-
135+
136136
for i in range(len(paths)):
137137
pipeline = lora_loader.load_lora_weights(
138138
pipeline,

src/maxdiffusion/loaders/lora_conversion_utils.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -716,41 +716,34 @@ def translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
716716
"attn1.to_k": "attn1.to_k",
717717
"attn1.to_v": "attn1.to_v",
718718
"attn1.to_out": "attn1.to_out.0",
719-
720719
# Audio Self Attention (audio_attn1)
721720
"audio_attn1.to_q": "audio_attn1.to_q",
722721
"audio_attn1.to_k": "audio_attn1.to_k",
723722
"audio_attn1.to_v": "audio_attn1.to_v",
724723
"audio_attn1.to_out": "audio_attn1.to_out.0",
725-
726724
# Audio Cross Attention (audio_attn2)
727725
"audio_attn2.to_q": "audio_attn2.to_q",
728726
"audio_attn2.to_k": "audio_attn2.to_k",
729727
"audio_attn2.to_v": "audio_attn2.to_v",
730728
"audio_attn2.to_out": "audio_attn2.to_out.0",
731-
732729
# Cross Attention (attn2)
733730
"attn2.to_q": "attn2.to_q",
734731
"attn2.to_k": "attn2.to_k",
735732
"attn2.to_v": "attn2.to_v",
736733
"attn2.to_out": "attn2.to_out.0",
737-
738734
# Audio to Video Cross Attention
739735
"audio_to_video_attn.to_q": "audio_to_video_attn.to_q",
740736
"audio_to_video_attn.to_k": "audio_to_video_attn.to_k",
741737
"audio_to_video_attn.to_v": "audio_to_video_attn.to_v",
742738
"audio_to_video_attn.to_out": "audio_to_video_attn.to_out.0",
743-
744739
# Video to Audio Cross Attention
745740
"video_to_audio_attn.to_q": "video_to_audio_attn.to_q",
746741
"video_to_audio_attn.to_k": "video_to_audio_attn.to_k",
747742
"video_to_audio_attn.to_v": "video_to_audio_attn.to_v",
748743
"video_to_audio_attn.to_out": "video_to_audio_attn.to_out.0",
749-
750744
# Feed Forward
751745
"ff.net_0": "ff.net.0.proj",
752746
"ff.net_2": "ff.net.2",
753-
754747
# Audio Feed Forward
755748
"audio_ff.net_0": "audio_ff.net.0.proj",
756749
"audio_ff.net_2": "audio_ff.net.2",
@@ -768,7 +761,6 @@ def translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
768761
"av_cross_attn_audio_v2a_gate.linear": "diffusion_model.av_ca_v2a_gate_adaln_single.linear",
769762
"av_cross_attn_audio_scale_shift.linear": "diffusion_model.av_ca_audio_scale_shift_adaln_single.linear",
770763
"av_cross_attn_video_scale_shift.linear": "diffusion_model.av_ca_video_scale_shift_adaln_single.linear",
771-
772764
# Nested conditioning layers
773765
"time_embed.emb.timestep_embedder.linear_1": "diffusion_model.adaln_single.emb.timestep_embedder.linear_1",
774766
"time_embed.emb.timestep_embedder.linear_2": "diffusion_model.adaln_single.emb.timestep_embedder.linear_2",
@@ -786,11 +778,10 @@ def translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
786778
"caption_projection.linear_2": "diffusion_model.caption_projection.linear_2",
787779
"audio_caption_projection.linear_1": "diffusion_model.audio_caption_projection.linear_1",
788780
"audio_caption_projection.linear_2": "diffusion_model.audio_caption_projection.linear_2",
789-
790781
# Connectors
791782
"feature_extractor.linear": "text_embedding_projection.aggregate_embed",
792783
}
793-
784+
794785
if nnx_path_str in global_map:
795786
return global_map[nnx_path_str]
796787

@@ -807,5 +798,3 @@ def translate_ltx2_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
807798
return f"diffusion_model.transformer_blocks.{idx}.{suffix_map[inner_suffix]}"
808799

809800
return None
810-
811-

src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def translate_fn(nnx_path_str):
6464
max_logging.log(f"Merging LoRA into connectors with rank={rank}")
6565
if h_state_dict is None and transformer_weight_name:
6666
h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs)
67-
67+
6868
if h_state_dict is not None:
6969
# Filter state dict for connector keys to avoid confusing warnings
7070
connector_state_dict = {k: v for k, v in h_state_dict.items() if k.startswith("text_embedding_projection")}

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ def __call__(
455455
query = self.to_q(hidden_states)
456456
key = self.to_k(context)
457457
value = self.to_v(context)
458-
458+
459459
with jax.named_scope("QKV Norm"):
460460
query = self.norm_q(query)
461461
key = self.norm_k(key)

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,24 +1026,24 @@ def scan_fn(carry, block):
10261026
)(carry, self.transformer_blocks)
10271027
else:
10281028
for block in self.transformer_blocks:
1029-
hidden_states, audio_hidden_states = block(
1030-
hidden_states=hidden_states,
1031-
audio_hidden_states=audio_hidden_states,
1032-
encoder_hidden_states=encoder_hidden_states,
1033-
audio_encoder_hidden_states=audio_encoder_hidden_states,
1034-
temb=temb,
1035-
temb_audio=temb_audio,
1036-
temb_ca_scale_shift=video_cross_attn_scale_shift,
1037-
temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
1038-
temb_ca_gate=video_cross_attn_a2v_gate,
1039-
temb_ca_audio_gate=audio_cross_attn_v2a_gate,
1040-
video_rotary_emb=video_rotary_emb,
1041-
audio_rotary_emb=audio_rotary_emb,
1042-
ca_video_rotary_emb=video_cross_attn_rotary_emb,
1043-
ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
1044-
encoder_attention_mask=encoder_attention_mask,
1045-
audio_encoder_attention_mask=audio_encoder_attention_mask,
1046-
)
1029+
hidden_states, audio_hidden_states = block(
1030+
hidden_states=hidden_states,
1031+
audio_hidden_states=audio_hidden_states,
1032+
encoder_hidden_states=encoder_hidden_states,
1033+
audio_encoder_hidden_states=audio_encoder_hidden_states,
1034+
temb=temb,
1035+
temb_audio=temb_audio,
1036+
temb_ca_scale_shift=video_cross_attn_scale_shift,
1037+
temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
1038+
temb_ca_gate=video_cross_attn_a2v_gate,
1039+
temb_ca_audio_gate=audio_cross_attn_v2a_gate,
1040+
video_rotary_emb=video_rotary_emb,
1041+
audio_rotary_emb=audio_rotary_emb,
1042+
ca_video_rotary_emb=video_cross_attn_rotary_emb,
1043+
ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
1044+
encoder_attention_mask=encoder_attention_mask,
1045+
audio_encoder_attention_mask=audio_encoder_attention_mask,
1046+
)
10471047

10481048
# 6. Output layers
10491049
with jax.named_scope("Output Projection & Norm"):

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,10 +1233,11 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
12331233
)
12341234

12351235
import time
1236+
12361237
timesteps_jax = jnp.array(timesteps, dtype=jnp.float32)
12371238
for i, t_val in enumerate(timesteps):
12381239
t = timesteps_jax[i]
1239-
1240+
12401241
# Isolate input sharding to scan_layers=False to avoid affecting the standard path
12411242
latents_jax_sharded = latents_jax
12421243
audio_latents_jax_sharded = audio_latents_jax
@@ -1340,12 +1341,11 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
13401341
mesh = latents.sharding.mesh
13411342
replicated_sharding = NamedSharding(mesh, P())
13421343
latents = jax.lax.with_sharding_constraint(latents, replicated_sharding)
1343-
1344+
13441345
# Replicate VAE weights
13451346
graphdef, state = nnx.split(self.vae)
13461347
state = jax.tree_util.tree_map(
1347-
lambda x: jax.lax.with_sharding_constraint(x, replicated_sharding) if isinstance(x, jax.Array) else x,
1348-
state
1348+
lambda x: jax.lax.with_sharding_constraint(x, replicated_sharding) if isinstance(x, jax.Array) else x, state
13491349
)
13501350
self.vae = nnx.merge(graphdef, state)
13511351
except Exception as e:

0 commit comments

Comments
 (0)