Skip to content

Commit a00b8d9

Browse files
committed
denormalisation fix 2.1
1 parent b32cda7 commit a00b8d9

1 file changed

Lines changed: 8 additions & 3 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -236,11 +236,16 @@ def __call__(
236236
)
237237
if self.config.expand_timesteps:
238238
latents = (1 - first_frame_mask) * condition + first_frame_mask * latents
239-
latents = self._denormalize_latents(latents)
239+
latents_bcthw = jnp.transpose(latents, (0, 4, 1, 2, 3))
240+
max_logging.log(f"[DEBUG CALL] latents shape before denorm: {latents_bcthw.shape}")
241+
242+
latents_denorm_bcthw = self._denormalize_latents(latents_bcthw)
243+
max_logging.log(f"[DEBUG CALL] latents shape after denorm: {latents_denorm_bcthw.shape}")
244+
240245

241246
if output_type == "latent":
242-
return latents
243-
return self._decode_latents_to_video(latents)
247+
return jnp.transpose(latents_denorm_bcthw, (0, 2, 3, 4, 1))
248+
return self._decode_latents_to_video(latents_denorm_bcthw)
244249

245250
def run_inference_2_1_i2v(
246251
graphdef, sharded_state, rest_of_state,

0 commit comments

Comments
 (0)