Skip to content

Commit b595666

Browse files
committed with message: "reformatted"
1 parent 3f4cfc3 commit b595666

3 files changed

Lines changed: 23 additions & 23 deletions

File tree

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ def __call__(
455455
query = self.to_q(hidden_states)
456456
key = self.to_k(context)
457457
value = self.to_v(context)
458-
458+
459459
with jax.named_scope("QKV Norm"):
460460
query = self.norm_q(query)
461461
key = self.norm_k(key)

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,24 +1026,24 @@ def scan_fn(carry, block):
10261026
)(carry, self.transformer_blocks)
10271027
else:
10281028
for block in self.transformer_blocks:
1029-
hidden_states, audio_hidden_states = block(
1030-
hidden_states=hidden_states,
1031-
audio_hidden_states=audio_hidden_states,
1032-
encoder_hidden_states=encoder_hidden_states,
1033-
audio_encoder_hidden_states=audio_encoder_hidden_states,
1034-
temb=temb,
1035-
temb_audio=temb_audio,
1036-
temb_ca_scale_shift=video_cross_attn_scale_shift,
1037-
temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
1038-
temb_ca_gate=video_cross_attn_a2v_gate,
1039-
temb_ca_audio_gate=audio_cross_attn_v2a_gate,
1040-
video_rotary_emb=video_rotary_emb,
1041-
audio_rotary_emb=audio_rotary_emb,
1042-
ca_video_rotary_emb=video_cross_attn_rotary_emb,
1043-
ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
1044-
encoder_attention_mask=encoder_attention_mask,
1045-
audio_encoder_attention_mask=audio_encoder_attention_mask,
1046-
)
1029+
hidden_states, audio_hidden_states = block(
1030+
hidden_states=hidden_states,
1031+
audio_hidden_states=audio_hidden_states,
1032+
encoder_hidden_states=encoder_hidden_states,
1033+
audio_encoder_hidden_states=audio_encoder_hidden_states,
1034+
temb=temb,
1035+
temb_audio=temb_audio,
1036+
temb_ca_scale_shift=video_cross_attn_scale_shift,
1037+
temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
1038+
temb_ca_gate=video_cross_attn_a2v_gate,
1039+
temb_ca_audio_gate=audio_cross_attn_v2a_gate,
1040+
video_rotary_emb=video_rotary_emb,
1041+
audio_rotary_emb=audio_rotary_emb,
1042+
ca_video_rotary_emb=video_cross_attn_rotary_emb,
1043+
ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
1044+
encoder_attention_mask=encoder_attention_mask,
1045+
audio_encoder_attention_mask=audio_encoder_attention_mask,
1046+
)
10471047

10481048
# 6. Output layers
10491049
with jax.named_scope("Output Projection & Norm"):

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,10 +1233,11 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
12331233
)
12341234

12351235
import time
1236+
12361237
timesteps_jax = jnp.array(timesteps, dtype=jnp.float32)
12371238
for i, t_val in enumerate(timesteps):
12381239
t = timesteps_jax[i]
1239-
1240+
12401241
# Isolate input sharding to scan_layers=False to avoid affecting the standard path
12411242
latents_jax_sharded = latents_jax
12421243
audio_latents_jax_sharded = audio_latents_jax
@@ -1340,12 +1341,11 @@ def run_connectors(graphdef, state, hidden_states, attention_mask):
13401341
mesh = latents.sharding.mesh
13411342
replicated_sharding = NamedSharding(mesh, P())
13421343
latents = jax.lax.with_sharding_constraint(latents, replicated_sharding)
1343-
1344+
13441345
# Replicate VAE weights
13451346
graphdef, state = nnx.split(self.vae)
13461347
state = jax.tree_util.tree_map(
1347-
lambda x: jax.lax.with_sharding_constraint(x, replicated_sharding) if isinstance(x, jax.Array) else x,
1348-
state
1348+
lambda x: jax.lax.with_sharding_constraint(x, replicated_sharding) if isinstance(x, jax.Array) else x, state
13491349
)
13501350
self.vae = nnx.merge(graphdef, state)
13511351
except Exception as e:

Comments (0) — this commit has no comments.