@@ -162,36 +162,6 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
162162 return transformer
163163
164164
def get_dummy_ltx2_inputs(config, pipeline, batch_size):
    """Build a tuple of dummy inputs for smoke-testing the LTX2 transformer.

    Args:
        config: Run config; must expose ``resolution`` and may expose
            ``num_frames`` (defaults to 121 when absent).
        pipeline: Pipeline exposing ``prepare_latents`` and
            ``prepare_audio_latents``, plus a ``transformer`` whose optional
            ``cross_attention_dim`` / ``audio_cross_attention_dim`` attributes
            size the text/audio context embeddings (defaults 4096 / 2048).
        batch_size: Leading batch dimension for every returned tensor.

    Returns:
        Tuple of ``(latents, audio_latents, timesteps, encoder_hidden_states,
        audio_encoder_hidden_states, encoder_attention_mask,
        audio_encoder_attention_mask)``.
    """
    # 1. Video latents, shaped by the pipeline itself.
    latents = pipeline.prepare_latents(
        batch_size=batch_size,
        height=config.resolution,
        width=config.resolution,
        num_frames=getattr(config, "num_frames", 121),
    )

    # 2. Audio latents (fixed dummy length of 8).
    audio_latents = pipeline.prepare_audio_latents(
        batch_size=batch_size,
        audio_latent_length=8,
    )

    # 3. Context embeddings. Split one root key into two so the text and
    # audio embeddings are uncorrelated (the original reused key 0 for both,
    # a known JAX PRNG anti-pattern).
    text_key, audio_key = jax.random.split(jax.random.key(0))

    text_encoder_dim = getattr(pipeline.transformer, "cross_attention_dim", 4096)
    encoder_hidden_states = jax.random.normal(text_key, (batch_size, 128, text_encoder_dim))

    audio_context_dim = getattr(pipeline.transformer, "audio_cross_attention_dim", 2048)
    audio_encoder_hidden_states = jax.random.normal(audio_key, (batch_size, 128, audio_context_dim))

    timesteps = jnp.array([0] * batch_size, dtype=jnp.int32)

    # Full (all-ones) attention masks over the 128 dummy context tokens.
    encoder_attention_mask = jnp.ones((batch_size, 128))
    audio_encoder_attention_mask = jnp.ones((batch_size, 128))

    return (
        latents,
        audio_latents,
        timesteps,
        encoder_hidden_states,
        audio_encoder_hidden_states,
        encoder_attention_mask,
        audio_encoder_attention_mask,
    )
194-
195165
196166# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
197167def calculate_shift (
0 commit comments