Skip to content

Commit 351d345

Browse files
committed
removed redundant function to get dummy inputs
1 parent 18e06e8 commit 351d345

1 file changed

Lines changed: 0 additions & 30 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -162,36 +162,6 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
162162
return transformer
163163

164164

165-
def get_dummy_ltx2_inputs(config, pipeline, batch_size):
166-
# 1. Latents
167-
latents = pipeline.prepare_latents(
168-
batch_size=batch_size,
169-
height=config.resolution,
170-
width=config.resolution,
171-
num_frames=getattr(config, "num_frames", 121),
172-
)
173-
174-
# 2. Audio Latents
175-
audio_latents = pipeline.prepare_audio_latents(
176-
batch_size=batch_size,
177-
audio_latent_length=8,
178-
)
179-
180-
# 3. Embeddings
181-
text_encoder_dim = getattr(pipeline.transformer, "cross_attention_dim", 4096)
182-
encoder_hidden_states = jax.random.normal(jax.random.key(0), (batch_size, 128, text_encoder_dim))
183-
184-
audio_context_dim = getattr(pipeline.transformer, "audio_cross_attention_dim", 2048)
185-
audio_encoder_hidden_states = jax.random.normal(jax.random.key(0), (batch_size, 128, audio_context_dim))
186-
187-
timesteps = jnp.array([0] * batch_size, dtype=jnp.int32)
188-
189-
encoder_attention_mask = jnp.ones((batch_size, 128))
190-
audio_encoder_attention_mask = jnp.ones((batch_size, 128))
191-
192-
return (latents, audio_latents, timesteps, encoder_hidden_states, audio_encoder_hidden_states, encoder_attention_mask, audio_encoder_attention_mask)
193-
194-
195165

196166
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
197167
def calculate_shift(

0 commit comments

Comments
 (0)