Skip to content

Commit dfdbcc0

Browse files
committed
Changed loop body in wan i2v 2.2
1 parent 7610f13 commit dfdbcc0

1 file changed

Lines changed: 9 additions & 3 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -274,29 +274,35 @@ def loop_body(step, vals):
274274
timestep = jnp.expand_dims(temp_ts, axis=0)
275275
timestep = jnp.broadcast_to(timestep, (latents_input.shape[0], temp_ts.shape[0]))
276276
else:
277-
latent_model_input = jnp.concatenate([latents_input, condition], axis=1)
277+
latent_model_input = jnp.concatenate([latents_input, condition], axis=-1)
278278
timestep = jnp.broadcast_to(t, latents_input.shape[0])
279279

280280

281281
use_high_noise = jnp.greater_equal(t, boundary)
282282

283283
def high_noise_branch(operands):
284284
latents_input, ts_input, pe_input, ie_input = operands
285-
return transformer_forward_pass(
285+
latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3))
286+
noise_pred, latents_out = transformer_forward_pass(
286287
high_noise_graphdef, high_noise_state, high_noise_rest,
287288
latents_input, ts_input, pe_input,
288289
do_classifier_free_guidance=do_classifier_free_guidance, guidance_scale=guidance_scale,
289290
encoder_hidden_states_image=ie_input
290291
)
292+
noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
293+
return noise_pred, latents_out
291294

292295
def low_noise_branch(operands):
293296
latents_input, ts_input, pe_input, ie_input = operands
294-
return transformer_forward_pass(
297+
latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3))
298+
noise_pred, latents_out = transformer_forward_pass(
295299
low_noise_graphdef, low_noise_state, low_noise_rest,
296300
latents_input, ts_input, pe_input,
297301
do_classifier_free_guidance=do_classifier_free_guidance, guidance_scale=guidance_scale_2,
298302
encoder_hidden_states_image=ie_input
299303
)
304+
noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
305+
return noise_pred, latents_out
300306

301307
noise_pred, latents = jax.lax.cond(
302308
use_high_noise,

0 commit comments

Comments (0)