@@ -23,9 +23,11 @@ def get_dummy_ltx2_inputs(batch_size, dtype):
     timestep = jnp.array(500.0, dtype=jnp.float32)
     # Gemma dim=3072, sequence=128
     prompt_embeds = jax.random.normal(rng, (batch_size, 128, 3072), dtype=dtype)
-    audio_prompt_embeds = None
-    encoder_attention_mask = jnp.ones((batch_size, 128), dtype=jnp.int32)
-    audio_encoder_attention_mask = None
+    # LTX-2 Audio latents default channels = 128
+    audio_latents = jax.random.normal(rng, (batch_size, 64, 128), dtype=dtype)
+    audio_prompt_embeds = jax.random.normal(rng, (batch_size, 64, 3072), dtype=dtype)
+    encoder_attention_mask = jnp.ones((batch_size, 128), dtype=jnp.int32)
+    audio_encoder_attention_mask = jnp.ones((batch_size, 64), dtype=jnp.int32)
 
     return latents, audio_latents, timestep, prompt_embeds, audio_prompt_embeds, encoder_attention_mask, audio_encoder_attention_mask
 
@@ -83,13 +84,19 @@ def calibrate_fbs(config):
 
     # Add unconditional latents for CFG
     double_latents = jnp.concatenate([latents, latents], axis=0)
+    double_audio_latents = jnp.concatenate([audio_latents, audio_latents], axis=0)
     double_prompt_embeds = jnp.concatenate([prompt_embeds, prompt_embeds], axis=0)
+    double_audio_prompt_embeds = jnp.concatenate([audio_prompt_embeds, audio_prompt_embeds], axis=0)
     double_encoder_attention_mask = jnp.concatenate([encoder_attention_mask, encoder_attention_mask], axis=0)
+    double_audio_encoder_attention_mask = jnp.concatenate([audio_encoder_attention_mask, audio_encoder_attention_mask], axis=0)
 
     double_latents = jax.device_put(double_latents, data_sharding)
+    double_audio_latents = jax.device_put(double_audio_latents, data_sharding)
     timestep = jax.device_put(timestep, data_sharding)
     double_prompt_embeds = jax.device_put(double_prompt_embeds, data_sharding)
+    double_audio_prompt_embeds = jax.device_put(double_audio_prompt_embeds, data_sharding)
     double_encoder_attention_mask = jax.device_put(double_encoder_attention_mask, data_sharding)
+    double_audio_encoder_attention_mask = jax.device_put(double_audio_encoder_attention_mask, data_sharding)
 
     print("Compiling transformer forward pass...")
     start_compile = time.perf_counter()
@@ -101,16 +108,16 @@ def calibrate_fbs(config):
     latent_num_frames = 16
     latent_height = 16
     latent_width = 24
-    audio_num_frames = 0
+    audio_num_frames = 64
     fps = 24.0
 
     _ = transformer_forward_pass(
         graphdef, sharded_state, double_latents,
-        None,  # audio_latents
+        double_audio_latents,
         timestep, double_prompt_embeds,
-        None,  # audio_encoder_hidden_states
+        double_audio_prompt_embeds,
         double_encoder_attention_mask,
-        None,  # audio_encoder_attention_mask
+        double_audio_encoder_attention_mask,
         do_classifier_free_guidance=True,
         guidance_scale=1.5,
         latent_num_frames=latent_num_frames,
@@ -135,11 +142,11 @@ def calibrate_fbs(config):
     start = time.perf_counter()
     _ = transformer_forward_pass(
         graphdef, sharded_state, double_latents,
-        None,
+        double_audio_latents,
         timestep, double_prompt_embeds,
-        None,
+        double_audio_prompt_embeds,
         double_encoder_attention_mask,
-        None,
+        double_audio_encoder_attention_mask,
         do_classifier_free_guidance=True,
         guidance_scale=1.5,
         latent_num_frames=latent_num_frames,
0 commit comments