Skip to content

Commit 712df7c

Browse files
committed
fix in calibration
1 parent 001651b commit 712df7c

1 file changed

Lines changed: 17 additions & 10 deletions

File tree

src/maxdiffusion/scripts/calibrate_ltx2_fbs.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@ def get_dummy_ltx2_inputs(batch_size, dtype):
2323
timestep = jnp.array(500.0, dtype=jnp.float32)
2424
# Gemma dim=3072, sequence=128
2525
prompt_embeds = jax.random.normal(rng, (batch_size, 128, 3072), dtype=dtype)
26-
audio_prompt_embeds = None
27-
encoder_attention_mask = jnp.ones((batch_size, 128), dtype=jnp.int32)
28-
audio_encoder_attention_mask = None
26+
# LTX-2 Audio latents default channels = 128
27+
audio_latents = jax.random.normal(rng, (batch_size, 64, 128), dtype=dtype)
28+
audio_prompt_embeds = jax.random.normal(rng, (batch_size, 64, 3072), dtype=dtype)
29+
audio_encoder_attention_mask = jnp.ones((batch_size, 64), dtype=jnp.int32)
2930

3031
return latents, audio_latents, timestep, prompt_embeds, audio_prompt_embeds, encoder_attention_mask, audio_encoder_attention_mask
3132

@@ -83,13 +84,19 @@ def calibrate_fbs(config):
8384

8485
# Add unconditional latents for CFG
8586
double_latents = jnp.concatenate([latents, latents], axis=0)
87+
double_audio_latents = jnp.concatenate([audio_latents, audio_latents], axis=0)
8688
double_prompt_embeds = jnp.concatenate([prompt_embeds, prompt_embeds], axis=0)
89+
double_audio_prompt_embeds = jnp.concatenate([audio_prompt_embeds, audio_prompt_embeds], axis=0)
8790
double_encoder_attention_mask = jnp.concatenate([encoder_attention_mask, encoder_attention_mask], axis=0)
91+
double_audio_encoder_attention_mask = jnp.concatenate([audio_encoder_attention_mask, audio_encoder_attention_mask], axis=0)
8892

8993
double_latents = jax.device_put(double_latents, data_sharding)
94+
double_audio_latents = jax.device_put(double_audio_latents, data_sharding)
9095
timestep = jax.device_put(timestep, data_sharding)
9196
double_prompt_embeds = jax.device_put(double_prompt_embeds, data_sharding)
97+
double_audio_prompt_embeds = jax.device_put(double_audio_prompt_embeds, data_sharding)
9298
double_encoder_attention_mask = jax.device_put(double_encoder_attention_mask, data_sharding)
99+
double_audio_encoder_attention_mask = jax.device_put(double_audio_encoder_attention_mask, data_sharding)
93100

94101
print("Compiling transformer forward pass...")
95102
start_compile = time.perf_counter()
@@ -101,16 +108,16 @@ def calibrate_fbs(config):
101108
latent_num_frames = 16
102109
latent_height = 16
103110
latent_width = 24
104-
audio_num_frames = 0
111+
audio_num_frames = 64
105112
fps = 24.0
106113

107114
_ = transformer_forward_pass(
108115
graphdef, sharded_state, double_latents,
109-
None, # audio_latents
116+
double_audio_latents,
110117
timestep, double_prompt_embeds,
111-
None, # audio_encoder_hidden_states
118+
double_audio_prompt_embeds,
112119
double_encoder_attention_mask,
113-
None, # audio_encoder_attention_mask
120+
double_audio_encoder_attention_mask,
114121
do_classifier_free_guidance=True,
115122
guidance_scale=1.5,
116123
latent_num_frames=latent_num_frames,
@@ -135,11 +142,11 @@ def calibrate_fbs(config):
135142
start = time.perf_counter()
136143
_ = transformer_forward_pass(
137144
graphdef, sharded_state, double_latents,
138-
None,
145+
double_audio_latents,
139146
timestep, double_prompt_embeds,
140-
None,
147+
double_audio_prompt_embeds,
141148
double_encoder_attention_mask,
142-
None,
149+
double_audio_encoder_attention_mask,
143150
do_classifier_free_guidance=True,
144151
guidance_scale=1.5,
145152
latent_num_frames=latent_num_frames,

0 commit comments

Comments
 (0)