Skip to content

Commit b7ee035

Browse files
committed
trying reorder: dispatch VAE latent preparation before text/CLIP encoding to overlap work
1 parent f23746b commit b7ee035

1 file changed

Lines changed: 24 additions & 12 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -164,17 +164,11 @@ def __call__(
164164
max_logging.log(f"Adjusted num_frames to: {num_frames}")
165165
num_frames = max(num_frames, 1)
166166

167-
prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size = self._prepare_model_inputs_i2v(
168-
prompt,
169-
image,
170-
negative_prompt,
171-
num_videos_per_prompt,
172-
max_sequence_length,
173-
prompt_embeds,
174-
negative_prompt_embeds,
175-
image_embeds,
176-
last_image,
177-
)
167+
# Calculate batch size early for prepare_latents dispatch overlap
168+
if prompt is not None and isinstance(prompt, str):
169+
prompt = [prompt]
170+
batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0] // num_videos_per_prompt
171+
effective_batch_size = batch_size * num_videos_per_prompt
178172

179173
def _process_image_input(img_input, height, width, num_videos_per_prompt):
180174
if img_input is None:
@@ -187,26 +181,44 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt):
187181
jax_array = jnp.repeat(jax_array, num_videos_per_prompt, axis=0)
188182
return jax_array
189183

184+
# 1. Dispatch VAE (Step 3) - Prepare Latents
190185
image_tensor = _process_image_input(image, height, width, effective_batch_size)
191186
last_image_tensor = _process_image_input(last_image, height, width, effective_batch_size)
192187

193188
if rng is None:
194189
rng = jax.random.key(self.config.seed)
195190
latents_rng, inference_rng = jax.random.split(rng)
196191

192+
# Use config.activations_dtype since image_embeds is not yet available
193+
latents_dtype = self.config.activations_dtype
194+
197195
latents, condition, first_frame_mask = self.prepare_latents(
198196
image=image_tensor,
199197
batch_size=effective_batch_size,
200198
height=height,
201199
width=width,
202200
num_frames=num_frames,
203-
dtype=image_embeds.dtype,
201+
dtype=latents_dtype,
204202
rng=latents_rng,
205203
latents=latents,
206204
last_image=last_image_tensor,
207205
num_videos_per_prompt=num_videos_per_prompt,
208206
)
209207

208+
# 2. Dispatch Text & CLIP (Steps 1 & 2)
209+
# This might block on CPU/Text encoding, but VAE is already dispatched!
210+
prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size = self._prepare_model_inputs_i2v(
211+
prompt,
212+
image,
213+
negative_prompt,
214+
num_videos_per_prompt,
215+
max_sequence_length,
216+
prompt_embeds,
217+
negative_prompt_embeds,
218+
image_embeds,
219+
last_image,
220+
)
221+
210222
scheduler_state = self.scheduler.set_timesteps(
211223
self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape
212224
)

0 commit comments

Comments (0)