Skip to content

Commit be231da

Browse files
committed
sharding fix
1 parent 8659856 commit be231da

2 files changed

Lines changed: 3 additions & 3 deletions

File tree

src/maxdiffusion/configs/ltx2_3_video.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ guidance_scale: 3.0
3131
guidance_rescale: 0.7
3232
audio_guidance_scale: 7.0
3333
audio_guidance_rescale: 0.7
34-
stg_scale: 0.0
35-
audio_stg_scale: 0.0
34+
stg_scale: 1.0
35+
audio_stg_scale: 1.0
3636
modality_scale: 1.0
3737
audio_modality_scale: 1.0
3838
use_cross_timestep: true

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ def __call__(
388388
if encoder_hidden_states is not None:
389389
encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names)
390390
if audio_encoder_hidden_states is not None:
391-
audio_encoder_hidden_states = jax.lax.with_sharding_constraint(audio_encoder_hidden_states, axis_names)
391+
audio_encoder_hidden_states = jax.lax.with_sharding_constraint(audio_encoder_hidden_states, axis_names_audio)
392392

393393
# 1. Video and Audio Self-Attention
394394
norm_hidden_states = self.norm1(hidden_states)

0 commit comments

Comments
 (0)